Lefei commited on
Commit
7c41556
Β·
verified Β·
1 Parent(s): 8fdbc9b

update app.py, adding user session, periodic cleanup, and more descriptions

Browse files
Files changed (1) hide show
  1. app.py +138 -112
app.py CHANGED
@@ -1,7 +1,5 @@
1
  # app.py
2
  import os
3
- os.environ["GRADIO_TEMP_DIR"] = "/home/mouxiangchen/VisionTSpp/gradio_tmp"
4
-
5
  import gradio as gr
6
  import torch
7
  import numpy as np
@@ -9,42 +7,80 @@ import pandas as pd
9
  import matplotlib.pyplot as plt
10
  import einops
11
  import copy
 
 
 
 
 
12
 
13
  from huggingface_hub import snapshot_download
14
  from visionts import VisionTSpp, freq_to_seasonality_list
15
 
16
  # ========================
17
- # 1. Configuration
18
  # ========================
19
 
20
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
21
- # DEVICE = 'cpu'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
23
  REPO_ID = "Lefei/VisionTSpp"
24
  LOCAL_DIR = "./hf_models/VisionTSpp"
25
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
26
  ARCH = 'mae_base'
27
 
28
- # Download the model from Hugging Face Hub
29
  if not os.path.exists(CKPT_PATH):
30
- os.makedirs(LOCAL_DIR, exist_ok=True)
31
  print("Downloading model from Hugging Face Hub...")
32
- snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
33
 
34
- # Load the model
35
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
36
- model = VisionTSpp(
37
- ARCH,
38
- ckpt_path=CKPT_PATH,
39
- # quantiles=QUANTILES,
40
- quantile=True,
41
- clip_input=True,
42
- complete_no_clip=False,
43
- color=True
44
- ).to(DEVICE)
45
  print(f"Model loaded on {DEVICE}")
46
 
47
- # Image normalization constants
48
  imagenet_mean = np.array([0.485, 0.456, 0.406])
49
  imagenet_std = np.array([0.229, 0.224, 0.225])
50
 
@@ -52,11 +88,7 @@ imagenet_std = np.array([0.229, 0.224, 0.225])
52
  # ========================
53
  # 2. Preset Datasets (Now Loaded Locally)
54
  # ========================
55
- # This dictionary maps user-friendly names to local file paths
56
- # ASSUMPTION: These files exist in a 'datasets' subfolder
57
-
58
  data_dir = "./datasets/"
59
- # data_dir = "./"
60
  PRESET_DATASETS = {
61
  "ETTm1": data_dir + "ETTm1.csv",
62
  "ETTm2": data_dir + "ETTm2.csv",
@@ -78,40 +110,28 @@ def load_preset_data(name):
78
  # 3. Visualization Functions (No changes needed)
79
  # ========================
80
  def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
81
- if image_tensor is None:
82
  return None
83
-
84
- # no need for permute?
85
- # image = image_tensor.permute(1, 2, 0).cpu()
86
  image = image_tensor.cpu()
87
-
88
  cur_image = torch.zeros_like(image)
89
-
90
  height_per_var = image.shape[0] // cur_nvars
91
  for i in range(cur_nvars):
92
  cur_color_idx = cur_color_list[i]
93
  var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
94
  unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
95
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
96
-
97
  cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
98
-
99
  fig, ax = plt.subplots(figsize=(6, 6))
100
  ax.imshow(cur_image)
101
  ax.set_title(title, fontsize=14)
102
  ax.axis('off')
103
-
104
  plt.tight_layout()
105
  plt.close(fig)
106
-
107
  return fig
108
 
109
  def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
110
- if isinstance(true_data, torch.Tensor):
111
- true_data = true_data.cpu().numpy()
112
- if isinstance(pred_median, torch.Tensor):
113
- pred_median = pred_median.cpu().numpy()
114
-
115
  for i, q in enumerate(pred_quantiles_list):
116
  if isinstance(q, torch.Tensor):
117
  pred_quantiles_list[i] = q.cpu().numpy()
@@ -119,23 +139,12 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
119
  nvars = true_data.shape[1]
120
  FIG_WIDTH, FIG_HEIGHT_PER_VAR = 15, 2.0
121
  fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
122
- if nvars == 1:
123
- axes = [axes]
124
-
125
- print(f"{len(pred_quantiles_list) = }")
126
- print(f"{len(model_quantiles) = }")
127
- print(f"{model_quantiles = }")
128
- print(f"{pred_quantiles_list[0].shape = }")
129
-
130
- # sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
131
- # sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list), key=lambda x: x[0])
132
 
133
  pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
134
  sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
135
-
136
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
137
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
138
-
139
  num_bands = len(quantile_preds) // 2
140
  quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))[::-1]
141
 
@@ -143,12 +152,10 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
143
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
144
  pred_range = np.arange(context_len, context_len + pred_len)
145
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
146
-
147
  for j in range(num_bands):
148
  lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
149
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
150
  ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
151
-
152
  y_min, y_max = ax.get_ylim()
153
  ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
154
  ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
@@ -158,10 +165,8 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
158
  handles, labels = axes[0].get_legend_handles_labels()
159
  unique_labels = dict(zip(labels, handles))
160
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
161
-
162
  plt.tight_layout(rect=[0, 0, 1, 0.95])
163
  plt.close(fig)
164
-
165
  return fig
166
 
167
 
@@ -177,23 +182,19 @@ class PredictionResult:
177
  self.total_samples = total_samples
178
  self.inferred_freq = inferred_freq
179
 
180
- def predict_at_index(df, index, context_len, pred_len):
181
- # === Data Validation & Frequency Inference ===
182
  if 'date' not in df.columns:
183
  raise gr.Error("❌ Input CSV must contain a 'date' column.")
184
 
185
  try:
186
  df['date'] = pd.to_datetime(df['date'])
187
  df = df.sort_values('date').set_index('date')
188
- # *** NEW: Infer frequency ***
189
  inferred_freq = pd.infer_freq(df.index)
190
  if inferred_freq is None:
191
- # Fallback if inference fails
192
  time_diff = df.index[1] - df.index[0]
193
  inferred_freq = pd.tseries.frequencies.to_offset(time_diff).freqstr
194
  gr.Warning(f"Could not reliably infer frequency. Using fallback based on first two timestamps: {inferred_freq}")
195
  print(f"Inferred frequency: {inferred_freq}")
196
-
197
  except Exception as e:
198
  raise gr.Error(f"❌ Date processing failed: {e}. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
199
 
@@ -216,12 +217,10 @@ def predict_at_index(df, index, context_len, pred_len):
216
  y_true_norm = data_norm[start_idx + context_len : start_idx + context_len + pred_len]
217
  x_tensor = torch.FloatTensor(x_norm).unsqueeze(0).to(DEVICE)
218
 
219
- # *** Use inferred frequency ***
220
  periodicity_list = freq_to_seasonality_list(inferred_freq)
221
  periodicity = periodicity_list[0] if periodicity_list else 1
222
 
223
  color_list = [i % 3 for i in range(nvars)]
224
- # model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
225
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity,
226
  num_patch_input=7, padding_mode='constant')
227
 
@@ -231,34 +230,12 @@ def predict_at_index(df, index, context_len, pred_len):
231
  )
232
  y_pred, y_pred_quantile_list = y_pred
233
 
234
- print(f"{x_tensor.shape = }")
235
- print(f"{y_pred.shape = }")
236
- print(f"{input_image.shape = }")
237
- print(f"{reconstructed_image.shape = }")
238
- print(f"{len(y_pred_quantile_list) = }")
239
-
240
- # print(f"{input_image[0,0,0, :, 0] = }")
241
- # print(f"{input_image[0,0,0, 50:70, 0] = }")
242
- # print(f"{input_image[0,0,0, 100:120, 0] = }")
243
-
244
  all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
245
-
246
- # insert in the place of 0.5 quantile, ie:len(QUANTILES)//2
247
  all_y_pred_list.insert(len(QUANTILES)//2, y_pred)
248
-
249
- print(f"{len(all_y_pred_list) = }")
250
- print(f"{all_y_pred_list[0].shape = }")
251
-
252
  all_preds = dict(zip(QUANTILES, all_y_pred_list))
253
-
254
- print(f"{all_preds.keys() = }")
255
-
256
  pred_median_norm = all_preds.pop(0.5)[0]
257
  pred_quantiles_norm = [q[0] for q in list(all_preds.values())]
258
 
259
- print(f"{pred_median_norm.shape = }")
260
- print(f"{len(pred_quantiles_norm) = }")
261
-
262
  y_true = y_true_norm * x_std + x_mean
263
  pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
264
  pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]
@@ -274,8 +251,7 @@ def predict_at_index(df, index, context_len, pred_len):
274
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
275
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
276
 
277
- os.makedirs("outputs", exist_ok=True)
278
- csv_path = "outputs/prediction_result.csv"
279
  time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
280
  result_data = {'date': time_index}
281
  for i in range(nvars):
@@ -284,28 +260,38 @@ def predict_at_index(df, index, context_len, pred_len):
284
  result_df = pd.DataFrame(result_data)
285
  result_df.to_csv(csv_path, index=False)
286
 
287
- return PredictionResult(ts_fig, input_img_fig, recon_img_fig, csv_path, total_samples, inferred_freq)
288
 
289
 
290
  # ========================
291
  # 5. Gradio Interface
292
  # ========================
293
- def run_forecast(data_source, upload_file, index, context_len, pred_len):
294
- if data_source == "Upload CSV":
295
- if upload_file is None:
296
- raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
297
- df = pd.read_csv(upload_file.name)
298
- else:
299
- df = load_preset_data(data_source)
300
-
 
 
 
 
301
  try:
 
 
 
 
 
 
 
 
 
302
  index, context_len, pred_len = int(index), int(context_len), int(pred_len)
303
- result = predict_at_index(df, index, context_len, pred_len)
304
 
305
- if index >= result.total_samples:
306
- final_index = result.total_samples - 1
307
- else:
308
- final_index = index
309
 
310
  return (
311
  result.ts_fig,
@@ -313,18 +299,22 @@ def run_forecast(data_source, upload_file, index, context_len, pred_len):
313
  result.recon_img_fig,
314
  result.csv_path,
315
  gr.update(maximum=result.total_samples - 1, value=final_index),
316
- gr.update(value=result.inferred_freq) # *** Update frequency textbox ***
 
317
  )
318
 
319
  except Exception as e:
 
320
  error_fig = plt.figure(figsize=(10, 5))
321
  plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
322
  plt.axis('off')
323
  plt.close(error_fig)
324
- return error_fig, None, None, None, gr.update(), gr.update(value="Error")
 
325
 
326
- # UI Layout
327
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
 
 
328
  gr.Markdown("# πŸ•°οΈ VisionTS++: Multivariate Time Series Forecasting")
329
  gr.Markdown(
330
  """
@@ -334,6 +324,7 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
334
  - βœ… **Visualize** predictions with multiple **quantile uncertainty bands**.
335
  - βœ… **Slide** through different samples of the dataset for real-time forecasting.
336
  - βœ… **Download** the prediction results as a CSV file.
 
337
  """
338
  )
339
 
@@ -356,36 +347,71 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
356
 
357
  context_len = gr.Number(label="Context Length (History)", value=336)
358
  pred_len = gr.Number(label="Prediction Length (Future)", value=96)
359
- # *** Changed to non-interactive textbox to display freq ***
360
- freq_display = gr.Textbox(label="Detected Frequency", interactive=True)
361
 
362
  run_btn = gr.Button("πŸš€ Run Forecast", variant="primary")
363
 
364
  gr.Markdown("### 2. Sample Selection")
365
- sample_index = gr.Slider(label="Sample Index", minimum=0, maximum=100000, step=1, value=100000)
 
 
 
 
 
 
 
366
 
367
  with gr.Column(scale=3):
368
  gr.Markdown("### 3. Prediction Results")
369
  ts_plot = gr.Plot(label="Time Series Forecast with Quantile Bands")
 
 
 
 
 
 
 
 
370
  with gr.Row():
371
  input_img_plot = gr.Plot(label="Input as Image")
372
  recon_img_plot = gr.Plot(label="Reconstructed Image")
 
 
 
 
 
 
 
373
  download_csv = gr.File(label="Download Prediction CSV")
 
 
 
 
 
 
374
 
375
- # --- Event Handlers ---
376
  def toggle_upload_visibility(choice):
377
  return gr.update(visible=(choice == "Upload CSV"))
378
 
379
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
380
 
381
- inputs = [data_source, upload_file, sample_index, context_len, pred_len]
382
- outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display]
383
 
384
  run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
385
  sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")
 
 
 
 
 
 
 
 
386
 
387
- # Remove Examples block to avoid startup issues and rely on the button.
388
- # If you still want examples, ensure `cache_examples=False`.
389
- # For simplicity, we'll remove it as the 'Run' button is clear.
390
 
391
- demo.launch(debug=True)
 
 
1
  # app.py
2
  import os
 
 
3
  import gradio as gr
4
  import torch
5
  import numpy as np
 
7
  import matplotlib.pyplot as plt
8
  import einops
9
  import copy
10
+ import uuid
11
+ import shutil
12
+ import time
13
+ import threading # <-- NEW: Import for background tasks
14
+ from pathlib import Path
15
 
16
  from huggingface_hub import snapshot_download
17
  from visionts import VisionTSpp, freq_to_seasonality_list
18
 
19
  # ========================
20
+ # 0. Environment & Cleanup Configuration
21
  # ========================
22
 
23
+ # --- Configuration for Session Cleanup ---
24
+ SESSION_DIR_ROOT = Path("user_sessions")
25
+ SESSION_DIR_ROOT.mkdir(exist_ok=True)
26
+ MAX_FILE_AGE_SECONDS = 24 * 60 * 60 # 24 hours
27
+ CLEANUP_INTERVAL_SECONDS = 60 * 60 # Run cleanup check every 1 hour
28
+
29
+ # set the gradio tmp dir
30
+ os.environ["GRADIO_TEMP_DIR"] = "./user_sessions"
31
+
32
+
33
+ def cleanup_old_sessions():
34
+ """Deletes session folders older than MAX_FILE_AGE_SECONDS."""
35
+ print(f"Running periodic cleanup of old session directories, with periodicity of {CLEANUP_INTERVAL_SECONDS} seconds...")
36
+ now = time.time()
37
+ deleted_count = 0
38
+ for session_dir in SESSION_DIR_ROOT.iterdir():
39
+ if session_dir.is_dir():
40
+ try:
41
+ # Use modification time of the directory as an indicator of last activity
42
+ dir_mod_time = session_dir.stat().st_mtime
43
+ if (now - dir_mod_time) > MAX_FILE_AGE_SECONDS:
44
+ print(f"Cleaning up old session directory: {session_dir}, over {MAX_FILE_AGE_SECONDS} seconds.")
45
+ shutil.rmtree(session_dir)
46
+ deleted_count += 1
47
+ except Exception as e:
48
+ print(f"Error cleaning up directory {session_dir}: {e}")
49
+ if deleted_count > 0:
50
+ print(f"Cleanup complete. Removed {deleted_count} old session(s).")
51
+ else:
52
+ print("Cleanup complete. No old sessions found.")
53
+
54
+
55
+ # --- NEW: Function to run the cleanup periodically in the background ---
56
+ def periodic_cleanup_task():
57
+ """Wrapper function to run cleanup in a loop with a sleep interval."""
58
+ print("Starting background thread for periodic cleanup.")
59
+ while True:
60
+ cleanup_old_sessions()
61
+ time.sleep(CLEANUP_INTERVAL_SECONDS)
62
+
63
+ # ========================
64
+ # 1. Model Configuration
65
+ # ========================
66
 
67
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
68
  REPO_ID = "Lefei/VisionTSpp"
69
  LOCAL_DIR = "./hf_models/VisionTSpp"
70
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
71
  ARCH = 'mae_base'
72
 
 
73
  if not os.path.exists(CKPT_PATH):
74
+ from huggingface_hub import snapshot_download
75
  print("Downloading model from Hugging Face Hub...")
76
+ snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False, resume_download=True)
77
 
 
78
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
79
+ # Assuming VisionTSpp is defined in a separate file or installed package
80
+ # from visionts import VisionTSpp, freq_to_seasonality_list # Placeholder for your model import
81
+ model = VisionTSpp(ARCH, ckpt_path=CKPT_PATH, quantile=True, clip_input=True, complete_no_clip=False, color=True).to(DEVICE)
 
 
 
 
 
 
82
  print(f"Model loaded on {DEVICE}")
83
 
 
84
  imagenet_mean = np.array([0.485, 0.456, 0.406])
85
  imagenet_std = np.array([0.229, 0.224, 0.225])
86
 
 
88
  # ========================
89
  # 2. Preset Datasets (Now Loaded Locally)
90
  # ========================
 
 
 
91
  data_dir = "./datasets/"
 
92
  PRESET_DATASETS = {
93
  "ETTm1": data_dir + "ETTm1.csv",
94
  "ETTm2": data_dir + "ETTm2.csv",
 
110
  # 3. Visualization Functions (No changes needed)
111
  # ========================
112
  def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
113
+ if image_tensor is None:
114
  return None
 
 
 
115
  image = image_tensor.cpu()
 
116
  cur_image = torch.zeros_like(image)
 
117
  height_per_var = image.shape[0] // cur_nvars
118
  for i in range(cur_nvars):
119
  cur_color_idx = cur_color_list[i]
120
  var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
121
  unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
122
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
 
123
  cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
 
124
  fig, ax = plt.subplots(figsize=(6, 6))
125
  ax.imshow(cur_image)
126
  ax.set_title(title, fontsize=14)
127
  ax.axis('off')
 
128
  plt.tight_layout()
129
  plt.close(fig)
 
130
  return fig
131
 
132
  def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
133
+ if isinstance(true_data, torch.Tensor): true_data = true_data.cpu().numpy()
134
+ if isinstance(pred_median, torch.Tensor): pred_median = pred_median.cpu().numpy()
 
 
 
135
  for i, q in enumerate(pred_quantiles_list):
136
  if isinstance(q, torch.Tensor):
137
  pred_quantiles_list[i] = q.cpu().numpy()
 
139
  nvars = true_data.shape[1]
140
  FIG_WIDTH, FIG_HEIGHT_PER_VAR = 15, 2.0
141
  fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
142
+ if nvars == 1: axes = [axes]
 
 
 
 
 
 
 
 
 
143
 
144
  pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
145
  sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
 
146
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
147
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
 
148
  num_bands = len(quantile_preds) // 2
149
  quantile_colors = plt.cm.Blues(np.linspace(0.3, 0.8, num_bands))[::-1]
150
 
 
152
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
153
  pred_range = np.arange(context_len, context_len + pred_len)
154
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
 
155
  for j in range(num_bands):
156
  lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
157
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
158
  ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
 
159
  y_min, y_max = ax.get_ylim()
160
  ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
161
  ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
 
165
  handles, labels = axes[0].get_legend_handles_labels()
166
  unique_labels = dict(zip(labels, handles))
167
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
 
168
  plt.tight_layout(rect=[0, 0, 1, 0.95])
169
  plt.close(fig)
 
170
  return fig
171
 
172
 
 
182
  self.total_samples = total_samples
183
  self.inferred_freq = inferred_freq
184
 
185
+ def predict_at_index(df, index, context_len, pred_len, session_dir):
 
186
  if 'date' not in df.columns:
187
  raise gr.Error("❌ Input CSV must contain a 'date' column.")
188
 
189
  try:
190
  df['date'] = pd.to_datetime(df['date'])
191
  df = df.sort_values('date').set_index('date')
 
192
  inferred_freq = pd.infer_freq(df.index)
193
  if inferred_freq is None:
 
194
  time_diff = df.index[1] - df.index[0]
195
  inferred_freq = pd.tseries.frequencies.to_offset(time_diff).freqstr
196
  gr.Warning(f"Could not reliably infer frequency. Using fallback based on first two timestamps: {inferred_freq}")
197
  print(f"Inferred frequency: {inferred_freq}")
 
198
  except Exception as e:
199
  raise gr.Error(f"❌ Date processing failed: {e}. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
200
 
 
217
  y_true_norm = data_norm[start_idx + context_len : start_idx + context_len + pred_len]
218
  x_tensor = torch.FloatTensor(x_norm).unsqueeze(0).to(DEVICE)
219
 
 
220
  periodicity_list = freq_to_seasonality_list(inferred_freq)
221
  periodicity = periodicity_list[0] if periodicity_list else 1
222
 
223
  color_list = [i % 3 for i in range(nvars)]
 
224
  model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity,
225
  num_patch_input=7, padding_mode='constant')
226
 
 
230
  )
231
  y_pred, y_pred_quantile_list = y_pred
232
 
 
 
 
 
 
 
 
 
 
 
233
  all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
 
 
234
  all_y_pred_list.insert(len(QUANTILES)//2, y_pred)
 
 
 
 
235
  all_preds = dict(zip(QUANTILES, all_y_pred_list))
 
 
 
236
  pred_median_norm = all_preds.pop(0.5)[0]
237
  pred_quantiles_norm = [q[0] for q in list(all_preds.values())]
238
 
 
 
 
239
  y_true = y_true_norm * x_std + x_mean
240
  pred_median = pred_median_norm.cpu().numpy() * x_std + x_mean
241
  pred_quantiles = [q.cpu().numpy() * x_std + x_mean for q in pred_quantiles_norm]
 
251
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
252
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
253
 
254
+ csv_path = Path(session_dir) / "prediction_result.csv"
 
255
  time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
256
  result_data = {'date': time_index}
257
  for i in range(nvars):
 
260
  result_df = pd.DataFrame(result_data)
261
  result_df.to_csv(csv_path, index=False)
262
 
263
+ return PredictionResult(ts_fig, input_img_fig, recon_img_fig, str(csv_path), total_samples, inferred_freq)
264
 
265
 
266
  # ========================
267
  # 5. Gradio Interface
268
  # ========================
269
+ def get_session_dir(session_id: gr.State):
270
+ """Creates and returns a unique directory for the user session."""
271
+ if session_id is None or not Path(session_id).exists():
272
+ session_uuid = str(uuid.uuid4())
273
+ session_dir = Path(SESSION_DIR_ROOT) / session_uuid
274
+ session_dir.mkdir(exist_ok=True, parents=True)
275
+ session_id = str(session_dir)
276
+ return session_id
277
+
278
+ def run_forecast(data_source, upload_file, index, context_len, pred_len, session_id: gr.State):
279
+ session_dir = get_session_dir(session_id)
280
+
281
  try:
282
+ if data_source == "Upload CSV":
283
+ if upload_file is None:
284
+ raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
285
+ uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
286
+ shutil.copy(upload_file.name, uploaded_file_path)
287
+ df = pd.read_csv(uploaded_file_path)
288
+ else:
289
+ df = load_preset_data(data_source)
290
+
291
  index, context_len, pred_len = int(index), int(context_len), int(pred_len)
292
+ result = predict_at_index(df, index, context_len, pred_len, session_dir)
293
 
294
+ final_index = min(index, result.total_samples - 1)
 
 
 
295
 
296
  return (
297
  result.ts_fig,
 
299
  result.recon_img_fig,
300
  result.csv_path,
301
  gr.update(maximum=result.total_samples - 1, value=final_index),
302
+ gr.update(value=result.inferred_freq),
303
+ session_dir
304
  )
305
 
306
  except Exception as e:
307
+ print(f"Error during forecast: {e}")
308
  error_fig = plt.figure(figsize=(10, 5))
309
  plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
310
  plt.axis('off')
311
  plt.close(error_fig)
312
+ return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
313
+
314
 
 
315
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
316
+ session_id_state = gr.State(None)
317
+
318
  gr.Markdown("# πŸ•°οΈ VisionTS++: Multivariate Time Series Forecasting")
319
  gr.Markdown(
320
  """
 
324
  - βœ… **Visualize** predictions with multiple **quantile uncertainty bands**.
325
  - βœ… **Slide** through different samples of the dataset for real-time forecasting.
326
  - βœ… **Download** the prediction results as a CSV file.
327
+ - βœ… **User Isolation**: Each user session has its own temporary storage to prevent file conflicts. Old files are automatically cleaned up.
328
  """
329
  )
330
 
 
347
 
348
  context_len = gr.Number(label="Context Length (History)", value=336)
349
  pred_len = gr.Number(label="Prediction Length (Future)", value=96)
350
+ freq_display = gr.Textbox(label="Detected Frequency", interactive=False)
 
351
 
352
  run_btn = gr.Button("πŸš€ Run Forecast", variant="primary")
353
 
354
  gr.Markdown("### 2. Sample Selection")
355
+ sample_index = gr.Slider(
356
+ label="Sample Index",
357
+ minimum=0,
358
+ maximum=1000000,
359
+ step=1,
360
+ value=1000000,
361
+ info="Drag the slider to select different starting points from the dataset for prediction."
362
+ )
363
 
364
  with gr.Column(scale=3):
365
  gr.Markdown("### 3. Prediction Results")
366
  ts_plot = gr.Plot(label="Time Series Forecast with Quantile Bands")
367
+ gr.Markdown(
368
+ """
369
+ **Plot Explanation:**
370
+ - **⚫ Black Line:** Ground truth data. The left side is the input context, and the right side is the actual future value.
371
+ - **πŸ”΄ Red Line:** The model's median prediction for the future.
372
+ - **πŸ”΅ Blue Shaded Areas:** Represent the model's uncertainty. The darker the blue, the wider the prediction interval, indicating more uncertainty.
373
+ """
374
+ )
375
  with gr.Row():
376
  input_img_plot = gr.Plot(label="Input as Image")
377
  recon_img_plot = gr.Plot(label="Reconstructed Image")
378
+ gr.Markdown(
379
+ """
380
+ **Image Explanation:**
381
+ - **Input as Image:** The historical time series data (look-back window) transformed into an image format that the VisionTS++ model uses as input.
382
+ - **Reconstructed Image:** The model's internal reconstruction of the input image. This helps to visualize what features the model is focusing on.
383
+ """
384
+ )
385
  download_csv = gr.File(label="Download Prediction CSV")
386
+ gr.Markdown(
387
+ """
388
+ **Download Prediction CSV:**
389
+ - You can download the prediction results of VisionTS++ here!
390
+ """
391
+ )
392
 
 
393
  def toggle_upload_visibility(choice):
394
  return gr.update(visible=(choice == "Upload CSV"))
395
 
396
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
397
 
398
+ inputs = [data_source, upload_file, sample_index, context_len, pred_len, session_id_state]
399
+ outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display, session_id_state]
400
 
401
  run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
402
  sample_index.release(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast_on_slide")
403
+
404
+
405
+ # ========================
406
+ # 6. Main Execution Block
407
+ # ========================
408
+ if __name__ == "__main__":
409
+ # --- Run initial cleanup on startup ---
410
+ cleanup_old_sessions()
411
 
412
+ # --- NEW: Start the periodic cleanup in a background daemon thread ---
413
+ cleanup_thread = threading.Thread(target=periodic_cleanup_task, daemon=True)
414
+ cleanup_thread.start()
415
 
416
+ # --- Launch the Gradio app ---
417
+ demo.launch(debug=True)