Spaces:
Sleeping
Sleeping
update T4 env
Browse files
app.py
CHANGED
|
@@ -47,6 +47,7 @@ class TimeSeriesEditor:
|
|
| 47 |
(200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
|
| 48 |
# 200,250,0,sin(2*pi*x),0.05
|
| 49 |
]
|
|
|
|
| 50 |
|
| 51 |
def format_value(self, value: float, feature_idx: int) -> str:
|
| 52 |
"""Format value with appropriate units and notation"""
|
|
@@ -440,6 +441,10 @@ class TimeSeriesEditor:
|
|
| 440 |
|
| 441 |
# Run prediction
|
| 442 |
with torch.no_grad():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 443 |
sample = self.trainer.predict_weighted_points(
|
| 444 |
observed_points, # (seq_length, feature_dim)
|
| 445 |
observed_mask, # (seq_length, feature_dim)
|
|
@@ -614,7 +619,7 @@ class TimeSeriesEditor:
|
|
| 614 |
def create_gradio_interface(editor: TimeSeriesEditor):
|
| 615 |
with gr.Blocks() as app:
|
| 616 |
gr.Markdown("# Time Series Editor")
|
| 617 |
-
gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s
|
| 618 |
|
| 619 |
metrics_display = gr.JSON(label="Metrics", value={})
|
| 620 |
|
|
@@ -1073,7 +1078,9 @@ if __name__ == "__main__":
|
|
| 1073 |
print(os.getcwd())
|
| 1074 |
|
| 1075 |
device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
|
| 1076 |
-
|
|
|
|
|
|
|
| 1077 |
from models.Tiffusion import tiffusion
|
| 1078 |
|
| 1079 |
model = tiffusion.Tiffusion(
|
|
|
|
| 47 |
(200, 250, 0, self.function_parser.string_to_function("sin(2*pi*x)"), 0.05)
|
| 48 |
# 200,250,0,sin(2*pi*x),0.05
|
| 49 |
]
|
| 50 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 51 |
|
| 52 |
def format_value(self, value: float, feature_idx: int) -> str:
|
| 53 |
"""Format value with appropriate units and notation"""
|
|
|
|
| 441 |
|
| 442 |
# Run prediction
|
| 443 |
with torch.no_grad():
|
| 444 |
+
# to cuda
|
| 445 |
+
observed_points = observed_points.to(self.device)
|
| 446 |
+
observed_mask = observed_mask.to(self.device)
|
| 447 |
+
|
| 448 |
sample = self.trainer.predict_weighted_points(
|
| 449 |
observed_points, # (seq_length, feature_dim)
|
| 450 |
observed_mask, # (seq_length, feature_dim)
|
|
|
|
| 619 |
def create_gradio_interface(editor: TimeSeriesEditor):
|
| 620 |
with gr.Blocks() as app:
|
| 621 |
gr.Markdown("# Time Series Editor")
|
| 622 |
+
gr.Markdown("## Instruction: Scroll Down + Click [Update Figure] [~20s]")
|
| 623 |
|
| 624 |
metrics_display = gr.JSON(label="Metrics", value={})
|
| 625 |
|
|
|
|
| 1078 |
print(os.getcwd())
|
| 1079 |
|
| 1080 |
device = torch.device(f"cuda:0") if torch.cuda.is_available() else "cpu"
|
| 1081 |
+
print(f"Device: {device}")
|
| 1082 |
+
print(f"Using device: {device}")
|
| 1083 |
+
|
| 1084 |
from models.Tiffusion import tiffusion
|
| 1085 |
|
| 1086 |
model = tiffusion.Tiffusion(
|