Spaces:
Sleeping
Sleeping
update
Browse files- app.py +164 -172
- models/CSDI/tiffusion.py +1 -59
app.py
CHANGED
|
@@ -43,8 +43,11 @@ class TimeSeriesEditor:
|
|
| 43 |
# Add frequency band multipliers
|
| 44 |
self.freq_bands = np.ones(5) # 5 frequency bands, initially all set to 1.0
|
| 45 |
self.function_parser = FunctionParser()
|
| 46 |
-
self.trending_controls = [
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
| 48 |
def format_value(self, value: float, feature_idx: int) -> str:
|
| 49 |
"""Format value with appropriate units and notation"""
|
| 50 |
if self.show_normalized:
|
|
@@ -377,7 +380,7 @@ class TimeSeriesEditor:
|
|
| 377 |
peak_alpha: float,
|
| 378 |
auc_weight: float,
|
| 379 |
peak_weight: float,
|
| 380 |
-
enable_trending: bool =
|
| 381 |
enable_trending_with_diff: bool = False,
|
| 382 |
trending_params: str = ""
|
| 383 |
) -> Tuple[List[go.Figure], str, str, Dict]:
|
|
@@ -436,15 +439,16 @@ class TimeSeriesEditor:
|
|
| 436 |
# model_control_signal["selected_areas"] = areas
|
| 437 |
|
| 438 |
# Run prediction
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
|
|
|
| 448 |
|
| 449 |
# Store latest results
|
| 450 |
self.latest_sample = sample
|
|
@@ -610,10 +614,10 @@ class TimeSeriesEditor:
|
|
| 610 |
def create_gradio_interface(editor: TimeSeriesEditor):
|
| 611 |
with gr.Blocks() as app:
|
| 612 |
gr.Markdown("# Time Series Editor")
|
| 613 |
-
gr.Markdown("## Instruction: Scroll Down + Click
|
| 614 |
|
| 615 |
metrics_display = gr.JSON(label="Metrics", value={})
|
| 616 |
-
|
| 617 |
with gr.Row():
|
| 618 |
with gr.Column(scale=1):
|
| 619 |
# with Tab():
|
|
@@ -642,103 +646,102 @@ def create_gradio_interface(editor: TimeSeriesEditor):
|
|
| 642 |
|
| 643 |
# TS Section
|
| 644 |
gr.Markdown("## Time Series Control Panel")
|
| 645 |
-
with gr.Accordion("Open for More Detail"):
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 732 |
|
| 733 |
-
|
| 734 |
-
with gr.Group(visible=False):
|
| 735 |
-
gr.Markdown("### Peak Control")
|
| 736 |
-
enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
|
| 737 |
-
peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
|
| 738 |
-
peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
|
| 739 |
-
peak_weight_input = gr.Number(label="Peak Weight", value=1.0)
|
| 740 |
-
|
| 741 |
-
update_model_btn = gr.Button("Update Figure")
|
| 742 |
|
| 743 |
gr.Markdown("## Extend Edit", visible=False)
|
| 744 |
with gr.Tab("Range Shift", visible=False):
|
|
@@ -905,26 +908,26 @@ def create_gradio_interface(editor: TimeSeriesEditor):
|
|
| 905 |
outputs=[*plots, metrics_display]
|
| 906 |
)
|
| 907 |
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
|
| 929 |
return app
|
| 930 |
|
|
@@ -1059,63 +1062,52 @@ class FunctionParser:
|
|
| 1059 |
|
| 1060 |
except Exception as e:
|
| 1061 |
print(f"Error: {str(e)}")
|
| 1062 |
-
|
| 1063 |
# Example usage:
|
| 1064 |
if __name__ == "__main__":
|
| 1065 |
-
# Initialize with example data points
|
| 1066 |
-
# example_data_points = "0,0,0.04;2,0,0.58;6,0,0.27;58,0,1.0;-1,0,0.05"
|
| 1067 |
-
|
| 1068 |
import os
|
| 1069 |
import torch
|
| 1070 |
import numpy as np
|
| 1071 |
-
from engine.solver import Trainer
|
| 1072 |
-
from utils.io_utils import load_yaml_config, instantiate_from_config
|
| 1073 |
|
| 1074 |
# assert torch.cuda.is_available(), "CUDA must be available"
|
| 1075 |
-
class Parameters:
|
| 1076 |
-
def __init__(self) -> None:
|
| 1077 |
-
self.gpu = 0
|
| 1078 |
-
self.config_path = "./config/modified/revenue-baseline-365.yaml"
|
| 1079 |
-
# self.config_path = "config/modified/96/fmri.yaml"
|
| 1080 |
-
# self.config_path = "./config/control/revenue-baseline-sine.yaml"
|
| 1081 |
-
# self.save_dir = (
|
| 1082 |
-
# "../../../data/" + os.path.basename(self.config_path).split(".")[0]
|
| 1083 |
-
# )
|
| 1084 |
-
self.mode = "infill"
|
| 1085 |
-
self.missing_ratio = 0.95
|
| 1086 |
-
self.milestone = "10"
|
| 1087 |
-
# os.makedirs(self.save_dir, exist_ok=True)
|
| 1088 |
-
|
| 1089 |
os.environ["WANDB_ENABLED"] = "false"
|
| 1090 |
-
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 1091 |
-
# print working directory
|
| 1092 |
print(os.getcwd())
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
|
| 1102 |
-
|
| 1103 |
-
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
-
|
| 1108 |
-
|
| 1109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1110 |
feature_dim = 3
|
| 1111 |
print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")
|
| 1112 |
|
| 1113 |
-
|
| 1114 |
-
editor = TimeSeriesEditor(seq_length, feature_dim, trainer)
|
| 1115 |
editor.coef = coef
|
| 1116 |
editor.stepsize = stepsize
|
| 1117 |
editor.sampling_steps = sampling_steps
|
| 1118 |
|
| 1119 |
app = create_gradio_interface(editor)
|
| 1120 |
-
# app.launch(server_name="0.0.0.0", server_port=8888, share=True)
|
| 1121 |
app.launch(show_api=False)
|
|
|
|
| 43 |
# Add frequency band multipliers
|
| 44 |
self.freq_bands = np.ones(5) # 5 frequency bands, initially all set to 1.0
|
| 45 |
self.function_parser = FunctionParser()
|
| 46 |
+
self.trending_controls = [
|
| 47 |
+
(200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
|
| 48 |
+
# 200,250,0,sin(2*pi*x),0.05
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
def format_value(self, value: float, feature_idx: int) -> str:
|
| 52 |
"""Format value with appropriate units and notation"""
|
| 53 |
if self.show_normalized:
|
|
|
|
| 380 |
peak_alpha: float,
|
| 381 |
auc_weight: float,
|
| 382 |
peak_weight: float,
|
| 383 |
+
enable_trending: bool = True,
|
| 384 |
enable_trending_with_diff: bool = False,
|
| 385 |
trending_params: str = ""
|
| 386 |
) -> Tuple[List[go.Figure], str, str, Dict]:
|
|
|
|
| 439 |
# model_control_signal["selected_areas"] = areas
|
| 440 |
|
| 441 |
# Run prediction
|
| 442 |
+
with torch.no_grad():
|
| 443 |
+
sample = self.trainer.predict_weighted_points(
|
| 444 |
+
observed_points, # (seq_length, feature_dim)
|
| 445 |
+
observed_mask, # (seq_length, feature_dim)
|
| 446 |
+
self.coef, # fixed
|
| 447 |
+
self.stepsize, # fixed
|
| 448 |
+
self.sampling_steps, # fixed
|
| 449 |
+
# model_control_signal=model_control_signal,
|
| 450 |
+
gradient_control_signal=gradient_control_signal
|
| 451 |
+
)
|
| 452 |
|
| 453 |
# Store latest results
|
| 454 |
self.latest_sample = sample
|
|
|
|
| 614 |
def create_gradio_interface(editor: TimeSeriesEditor):
|
| 615 |
with gr.Blocks() as app:
|
| 616 |
gr.Markdown("# Time Series Editor")
|
| 617 |
+
gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s-30s] [Running on CPU...]")
|
| 618 |
|
| 619 |
metrics_display = gr.JSON(label="Metrics", value={})
|
| 620 |
+
|
| 621 |
with gr.Row():
|
| 622 |
with gr.Column(scale=1):
|
| 623 |
# with Tab():
|
|
|
|
| 646 |
|
| 647 |
# TS Section
|
| 648 |
gr.Markdown("## Time Series Control Panel")
|
| 649 |
+
# with gr.Accordion("Open for More Detail"):
|
| 650 |
+
with gr.Group():
|
| 651 |
+
gr.Markdown("### Fixed Point Control")
|
| 652 |
+
data_points_df = gr.Dataframe(
|
| 653 |
+
headers=["time", "feature", "value"],
|
| 654 |
+
datatype=["number", "number", "number"],
|
| 655 |
+
# label="Anchor Point Control",
|
| 656 |
+
value=[[0, 0, 0.04], [2, 0, 0.58], [6, 0, 0.27], [58, 0, 1.0], [60, 0, 0.5]],
|
| 657 |
+
col_count=(3, "fixed"), # Fix number of columns
|
| 658 |
+
interactive=True
|
| 659 |
+
)
|
| 660 |
+
add_data_point_btn = gr.Button("Add Data Point")
|
| 661 |
+
|
| 662 |
+
def add_data_point(df):
|
| 663 |
+
new_row = pd.DataFrame([[None, 0, None]],
|
| 664 |
+
columns=["time", "feature", "value"])
|
| 665 |
+
return pd.concat([df, new_row], ignore_index=True)
|
| 666 |
+
|
| 667 |
+
add_data_point_btn.click(
|
| 668 |
+
fn=add_data_point,
|
| 669 |
+
inputs=[data_points_df],
|
| 670 |
+
outputs=[data_points_df]
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
+
with gr.Group():
|
| 674 |
+
gr.Markdown("### Group of Anchor Point Control with Confidence")
|
| 675 |
+
point_groups_df = gr.Dataframe(
|
| 676 |
+
headers=["start", "end", "interval", "feature", "value", "weight"],
|
| 677 |
+
datatype=["number", "number", "number", "number", "number", "number"],
|
| 678 |
+
# label="Group of Anchor Point Control",
|
| 679 |
+
value=[[0, 50, 10, 0, 0.5, 0.1], [100, 150, 50, 0, 0.1, 0.5]],
|
| 680 |
+
col_count=(6, "fixed"), # Fix number of columns
|
| 681 |
+
interactive=True
|
| 682 |
+
)
|
| 683 |
+
add_point_group_btn = gr.Button("Add Point Group")
|
| 684 |
+
|
| 685 |
+
def add_point_group(df):
|
| 686 |
+
new_row = pd.DataFrame([[None, None, None, 0, None, None]],
|
| 687 |
+
columns=["start", "end", "interval", "feature", "value", "weight"])
|
| 688 |
+
return pd.concat([df, new_row], ignore_index=True)
|
| 689 |
+
|
| 690 |
+
add_point_group_btn.click(
|
| 691 |
+
fn=add_point_group,
|
| 692 |
+
inputs=[point_groups_df],
|
| 693 |
+
outputs=[point_groups_df]
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
with gr.Group():
|
| 697 |
+
# with gr.Tab("Trending Control"):
|
| 698 |
+
gr.Markdown("### Trending Control")
|
| 699 |
+
gr.Markdown("""
|
| 700 |
+
Enter trending control parameters in the format:
|
| 701 |
+
```
|
| 702 |
+
start_time,end_time,feature,function,confidence
|
| 703 |
+
```
|
| 704 |
+
Examples:
|
| 705 |
+
- Linear trend: `0,100,0,x`
|
| 706 |
+
- Sine wave: `0,100,0,sin(2*pi*x)`
|
| 707 |
+
- Exponential: `0,100,0,exp(-x)`
|
| 708 |
+
|
| 709 |
+
Separate multiple trends with semicolons.
|
| 710 |
+
""")
|
| 711 |
+
enable_trending_control = gr.Checkbox(label="Enable Trending Control", value=True)
|
| 712 |
+
enable_trending_control_with_diff = gr.Checkbox(label="Consider Last Generated", value=False)
|
| 713 |
+
trending_control = gr.Textbox(
|
| 714 |
+
label="Trending Control Parameters",
|
| 715 |
+
lines=2,
|
| 716 |
+
placeholder="Enter parameters: start_time,end_time,feature,function,condifdence; separated by semicolons",
|
| 717 |
+
value="200,250,0,sin(2*pi*x),0.05"
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
# Area Control Parameters
|
| 721 |
+
with gr.Group(visible=False):
|
| 722 |
+
gr.Markdown("### Area Control")
|
| 723 |
+
enable_area_control = gr.Checkbox(label="Enable Area Control", value=False)
|
| 724 |
+
area_selections = gr.Textbox(
|
| 725 |
+
label="Area Selections (format: start_time,end_time,feature,target_value)",
|
| 726 |
+
lines=2,
|
| 727 |
+
placeholder="Enter areas: start,end,feature,target; separated by semicolons",
|
| 728 |
+
)
|
| 729 |
+
|
| 730 |
+
# AUC Parameters
|
| 731 |
+
gr.Markdown("### Statistics Control")
|
| 732 |
+
enable_auc = gr.Checkbox(label="Enable Total Sum Control", value=True)
|
| 733 |
+
auc_input = gr.Number(label="Target Sum Value", value=-150)
|
| 734 |
+
auc_weight_input = gr.Number(label="Sum Weight", value=10.0)
|
| 735 |
+
|
| 736 |
+
# Peak Parameters
|
| 737 |
+
with gr.Group(visible=False):
|
| 738 |
+
gr.Markdown("### Peak Control")
|
| 739 |
+
enable_peaks = gr.Checkbox(label="Enable Peak Control", value=False)
|
| 740 |
+
peak_points_input = gr.Textbox(label="Peak Points (comma-separated)", value="100,200")
|
| 741 |
+
peak_alpha_input = gr.Number(label="Peak Alpha", value=10)
|
| 742 |
+
peak_weight_input = gr.Number(label="Peak Weight", value=1.0)
|
| 743 |
|
| 744 |
+
update_model_btn = gr.Button("Update Figure")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 745 |
|
| 746 |
gr.Markdown("## Extend Edit", visible=False)
|
| 747 |
with gr.Tab("Range Shift", visible=False):
|
|
|
|
| 908 |
outputs=[*plots, metrics_display]
|
| 909 |
)
|
| 910 |
|
| 911 |
+
app.load(
|
| 912 |
+
fn=update_model_callback,
|
| 913 |
+
inputs=[
|
| 914 |
+
data_points_df,
|
| 915 |
+
point_groups_df,
|
| 916 |
+
enable_area_control,
|
| 917 |
+
area_selections,
|
| 918 |
+
enable_auc,
|
| 919 |
+
auc_input,
|
| 920 |
+
auc_weight_input,
|
| 921 |
+
enable_peaks,
|
| 922 |
+
peak_points_input,
|
| 923 |
+
peak_alpha_input,
|
| 924 |
+
peak_weight_input,
|
| 925 |
+
enable_trending_control,
|
| 926 |
+
enable_trending_control_with_diff,
|
| 927 |
+
trending_control
|
| 928 |
+
],
|
| 929 |
+
outputs=[*plots, metrics_display]
|
| 930 |
+
)
|
| 931 |
|
| 932 |
return app
|
| 933 |
|
|
|
|
| 1062 |
|
| 1063 |
except Exception as e:
|
| 1064 |
print(f"Error: {str(e)}")
|
|
|
|
| 1065 |
# Example usage:
|
| 1066 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
| 1067 |
import os
|
| 1068 |
import torch
|
| 1069 |
import numpy as np
|
|
|
|
|
|
|
| 1070 |
|
| 1071 |
# assert torch.cuda.is_available(), "CUDA must be available"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
os.environ["WANDB_ENABLED"] = "false"
|
|
|
|
|
|
|
| 1073 |
print(os.getcwd())
|
| 1074 |
+
|
| 1075 |
+
device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
|
| 1076 |
+
|
| 1077 |
+
from models.Tiffusion import tiffusion
|
| 1078 |
+
|
| 1079 |
+
model = tiffusion.Tiffusion(
|
| 1080 |
+
seq_length=365,
|
| 1081 |
+
feature_size=3,
|
| 1082 |
+
n_layer_enc=6,
|
| 1083 |
+
n_layer_dec=4,
|
| 1084 |
+
d_model=128,
|
| 1085 |
+
timesteps=500,
|
| 1086 |
+
sampling_timesteps=200,
|
| 1087 |
+
loss_type='l1',
|
| 1088 |
+
beta_schedule='cosine',
|
| 1089 |
+
n_heads=8,
|
| 1090 |
+
mlp_hidden_times=4,
|
| 1091 |
+
attn_pd=0.0,
|
| 1092 |
+
resid_pd=0.0,
|
| 1093 |
+
kernel_size=1,
|
| 1094 |
+
padding_size=0,
|
| 1095 |
+
control_signal=[]
|
| 1096 |
+
).to(device)
|
| 1097 |
+
|
| 1098 |
+
model.load_state_dict(torch.load("./weight/checkpoint-10.pt", map_location=device, weights_only=True)["model"])
|
| 1099 |
+
|
| 1100 |
+
coef = 1.0e-2
|
| 1101 |
+
stepsize = 5.0e-2
|
| 1102 |
+
sampling_steps = 100 # Adjustable between 100-500 for speed/accuracy tradeoff
|
| 1103 |
+
seq_length = 365
|
| 1104 |
feature_dim = 3
|
| 1105 |
print(f"seq_length: {seq_length}, feature_dim: {feature_dim}")
|
| 1106 |
|
| 1107 |
+
editor = TimeSeriesEditor(seq_length, feature_dim, model)
|
|
|
|
| 1108 |
editor.coef = coef
|
| 1109 |
editor.stepsize = stepsize
|
| 1110 |
editor.sampling_steps = sampling_steps
|
| 1111 |
|
| 1112 |
app = create_gradio_interface(editor)
|
|
|
|
| 1113 |
app.launch(show_api=False)
|
models/CSDI/tiffusion.py
CHANGED
|
@@ -33,7 +33,7 @@ def cosine_beta_schedule(timesteps, s=0.008):
|
|
| 33 |
return torch.clip(betas, 0, 0.999)
|
| 34 |
|
| 35 |
|
| 36 |
-
class Tiffusion(
|
| 37 |
def __init__(
|
| 38 |
self,
|
| 39 |
seq_length,
|
|
@@ -111,12 +111,9 @@ class Tiffusion(CSDI_base):
|
|
| 111 |
config_diff["beta_start"], config_diff["beta_end"], self.num_steps
|
| 112 |
)
|
| 113 |
|
| 114 |
-
|
| 115 |
self.alpha_hat = 1 - self.beta
|
| 116 |
self.alpha = np.cumprod(self.alpha_hat)
|
| 117 |
self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
|
| 118 |
-
# self.beta = torch.from_numpy(self.beta).float().to(self.device)
|
| 119 |
-
# self.alpha = torch.from_numpy(self.alpha).float().to(self.device)
|
| 120 |
|
| 121 |
self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
|
| 122 |
if self.is_unconditional == False:
|
|
@@ -127,63 +124,8 @@ class Tiffusion(CSDI_base):
|
|
| 127 |
num_embeddings=self.target_dim
|
| 128 |
, embedding_dim=self.emb_feature_dim
|
| 129 |
)
|
| 130 |
-
# self.model: Transformer = Transformer(
|
| 131 |
-
# n_feat=feature_size,
|
| 132 |
-
# n_channel=seq_length,
|
| 133 |
-
# n_layer_enc=n_layer_enc,
|
| 134 |
-
# n_layer_dec=n_layer_dec,
|
| 135 |
-
# n_heads=n_heads,
|
| 136 |
-
# attn_pdrop=attn_pd,
|
| 137 |
-
# resid_pdrop=resid_pd,
|
| 138 |
-
# mlp_hidden_times=mlp_hidden_times,
|
| 139 |
-
# max_len=seq_length,
|
| 140 |
-
# n_embd=d_model,
|
| 141 |
-
# conv_params=[kernel_size, padding_size],
|
| 142 |
-
# **kwargs,
|
| 143 |
-
# )
|
| 144 |
-
class Config:
|
| 145 |
-
def __init__(self, **kwargs):
|
| 146 |
-
self.__dict__.update(kwargs)
|
| 147 |
|
| 148 |
-
# type: CSDI
|
| 149 |
-
# layers: 3
|
| 150 |
-
# channels: 64
|
| 151 |
-
# nheads: 8
|
| 152 |
-
# diffusion_embedding_dim: 128
|
| 153 |
-
# is_linear: False # linear transformer
|
| 154 |
-
|
| 155 |
-
# beta_start: 0.0001
|
| 156 |
-
# beta_end: 0.5
|
| 157 |
-
# schedule: "quad"
|
| 158 |
-
|
| 159 |
-
# num_steps: 50
|
| 160 |
-
|
| 161 |
-
# # edit
|
| 162 |
-
# edit_steps: 50 # the number of steps to perform editing
|
| 163 |
-
# bootstrap_ratio: 0.5 # [0,1]
|
| 164 |
-
|
| 165 |
-
# is_attr_proj: False
|
| 166 |
-
|
| 167 |
-
# side:
|
| 168 |
-
# num_var: 1
|
| 169 |
-
# var_emb: 16
|
| 170 |
-
# time_emb: 128
|
| 171 |
-
|
| 172 |
-
# attrs:
|
| 173 |
-
# attr_emb: 64
|
| 174 |
-
# config_diff["side_dim"] = self.emb_total_dim
|
| 175 |
self.diffmodel = diff_CSDI(
|
| 176 |
-
# config=Config(
|
| 177 |
-
# layers=3,
|
| 178 |
-
# channels=64,
|
| 179 |
-
# nheads=8,
|
| 180 |
-
# diffusion_embedding_dim=128,
|
| 181 |
-
# is_linear=False,
|
| 182 |
-
# beta_start=0.0001,
|
| 183 |
-
# beta_end=0.5,
|
| 184 |
-
# schedule="quad",
|
| 185 |
-
# num_steps=50,
|
| 186 |
-
# )
|
| 187 |
{
|
| 188 |
"layers": 3,
|
| 189 |
"channels": 64,
|
|
|
|
| 33 |
return torch.clip(betas, 0, 0.999)
|
| 34 |
|
| 35 |
|
| 36 |
+
class Tiffusion(nn.Module):
|
| 37 |
def __init__(
|
| 38 |
self,
|
| 39 |
seq_length,
|
|
|
|
| 111 |
config_diff["beta_start"], config_diff["beta_end"], self.num_steps
|
| 112 |
)
|
| 113 |
|
|
|
|
| 114 |
self.alpha_hat = 1 - self.beta
|
| 115 |
self.alpha = np.cumprod(self.alpha_hat)
|
| 116 |
self.alpha_torch = torch.tensor(self.alpha).float().to(self.device).unsqueeze(1).unsqueeze(1)
|
|
|
|
|
|
|
| 117 |
|
| 118 |
self.emb_total_dim = self.emb_time_dim + self.emb_feature_dim
|
| 119 |
if self.is_unconditional == False:
|
|
|
|
| 124 |
num_embeddings=self.target_dim
|
| 125 |
, embedding_dim=self.emb_feature_dim
|
| 126 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
self.diffmodel = diff_CSDI(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
{
|
| 130 |
"layers": 3,
|
| 131 |
"channels": 64,
|