Lefei commited on
Commit
fbd13aa
·
verified ·
1 Parent(s): 2f6dc9b

update app.py, add choice button for VisionTSpp base and large

Browse files
Files changed (1) hide show
  1. app.py +39 -9
app.py CHANGED
@@ -53,7 +53,7 @@ def cleanup_old_sessions():
53
  print("Cleanup complete. No old sessions found.")
54
 
55
 
56
- # --- NEW: Function to run the cleanup periodically in the background ---
57
  def periodic_cleanup_task():
58
  """Wrapper function to run cleanup in a loop with a sleep interval."""
59
  print("Starting background thread for periodic cleanup.")
@@ -96,18 +96,26 @@ for model_size, config in MODEL_CONFIGS.items():
96
 
97
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
98
 
99
- # --- NEW: Global variables to hold the currently loaded model and its size ---
 
 
 
 
 
 
100
  CURRENT_MODEL_SIZE = None
101
  CURRENT_MODEL = None
102
 
103
  def load_model_for_size(model_size: str):
104
  """Loads the specified VisionTS++ model."""
105
  global CURRENT_MODEL, CURRENT_MODEL_SIZE
 
106
  if model_size not in MODEL_CONFIGS:
107
  raise ValueError(f"Invalid model size: {model_size}. Available: {list(MODEL_CONFIGS.keys())}")
108
 
109
  config = MODEL_CONFIGS[model_size]
110
  print(f"Loading {model_size} model...")
 
111
  model = VisionTSpp(
112
  config["arch"],
113
  ckpt_path=config["ckpt_path"],
@@ -126,6 +134,7 @@ def load_model_for_size(model_size: str):
126
 
127
  CURRENT_MODEL = model
128
  CURRENT_MODEL_SIZE = model_size
 
129
  return model
130
 
131
  # Load the default model (base) on startup
@@ -160,34 +169,44 @@ def load_preset_data(name):
160
  def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
161
  if image_tensor is None:
162
  return None
 
163
  image = image_tensor.cpu()
164
  cur_image = torch.zeros_like(image)
 
165
  height_per_var = image.shape[0] // cur_nvars
166
  for i in range(cur_nvars):
167
  cur_color_idx = cur_color_list[i]
168
  var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
169
  unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
170
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
 
171
  cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
172
  fig, ax = plt.subplots(figsize=(6, 6))
 
173
  ax.imshow(cur_image)
174
  ax.set_title(title, fontsize=14)
175
  ax.axis('off')
176
  plt.tight_layout()
177
  plt.close(fig)
 
178
  return fig
179
 
180
  def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
181
- if isinstance(true_data, torch.Tensor): true_data = true_data.cpu().numpy()
182
- if isinstance(pred_median, torch.Tensor): pred_median = pred_median.cpu().numpy()
 
 
 
183
  for i, q in enumerate(pred_quantiles_list):
184
  if isinstance(q, torch.Tensor):
185
  pred_quantiles_list[i] = q.cpu().numpy()
186
 
187
  nvars = true_data.shape[1]
188
  FIG_WIDTH, FIG_HEIGHT_PER_VAR = 15, 2.0
 
189
  fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
190
- if nvars == 1: axes = [axes]
 
191
 
192
  pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
193
  sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
@@ -200,10 +219,12 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
200
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
201
  pred_range = np.arange(context_len, context_len + pred_len)
202
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
 
203
  for j in range(num_bands):
204
  lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
205
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
206
  ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
 
207
  y_min, y_max = ax.get_ylim()
208
  ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
209
  ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
@@ -212,9 +233,11 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
212
 
213
  handles, labels = axes[0].get_legend_handles_labels()
214
  unique_labels = dict(zip(labels, handles))
 
215
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
216
  plt.tight_layout(rect=[0, 0, 1, 0.95])
217
  plt.close(fig)
 
218
  return fig
219
 
220
 
@@ -237,11 +260,13 @@ def predict_at_index(df, index, context_len, pred_len, session_dir, model_size):
237
  try:
238
  df['date'] = pd.to_datetime(df['date'])
239
  df = df.sort_values('date').set_index('date')
 
240
  inferred_freq = pd.infer_freq(df.index)
241
  if inferred_freq is None:
242
  time_diff = df.index[1] - df.index[0]
243
  inferred_freq = pd.tseries.frequencies.to_offset(time_diff).freqstr
244
  gr.Warning(f"Could not reliably infer frequency. Using fallback based on first two timestamps: {inferred_freq}")
 
245
  print(f"Inferred frequency: {inferred_freq}")
246
  except Exception as e:
247
  raise gr.Error(f"❌ Date processing failed: {e}. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
@@ -270,7 +295,7 @@ def predict_at_index(df, index, context_len, pred_len, session_dir, model_size):
270
 
271
  color_list = [i % 3 for i in range(nvars)]
272
 
273
- # --- NEW: Load the requested model if it's not the current one ---
274
  if CURRENT_MODEL_SIZE != model_size:
275
  print(f"Switching model from {CURRENT_MODEL_SIZE} to {model_size}")
276
  load_model_for_size(model_size)
@@ -289,6 +314,7 @@ def predict_at_index(df, index, context_len, pred_len, session_dir, model_size):
289
  all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
290
  all_y_pred_list.insert(len(QUANTILES)//2, y_pred)
291
  all_preds = dict(zip(QUANTILES, all_y_pred_list))
 
292
  pred_median_norm = all_preds.pop(0.5)[0]
293
  pred_quantiles_norm = [q[0] for q in list(all_preds.values())]
294
 
@@ -304,15 +330,18 @@ def predict_at_index(df, index, context_len, pred_len, session_dir, model_size):
304
  pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
305
  context_len=context_len, pred_len=pred_len
306
  )
 
307
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
308
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
309
 
310
  csv_path = Path(session_dir) / "prediction_result.csv"
311
  time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
312
  result_data = {'date': time_index}
 
313
  for i in range(nvars):
314
  result_data[f'True_Var{i+1}'] = y_true[:, i]
315
  result_data[f'Pred_Median_Var{i+1}'] = pred_median[:, i]
 
316
  result_df = pd.DataFrame(result_data)
317
  result_df.to_csv(csv_path, index=False)
318
 
@@ -329,6 +358,7 @@ def get_session_dir(session_id: gr.State):
329
  session_dir = Path(SESSION_DIR_ROOT) / session_uuid
330
  session_dir.mkdir(exist_ok=True, parents=True)
331
  session_id = str(session_dir)
 
332
  return session_id
333
 
334
  def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
@@ -389,7 +419,7 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
389
  with gr.Row():
390
  with gr.Column(scale=1, min_width=300):
391
  gr.Markdown("### 1. Data & Model Configuration")
392
- # --- NEW: Add model selection dropdown ---
393
  model_size = gr.Dropdown(
394
  label="Select Model Size",
395
  choices=["base", "large"],
@@ -459,7 +489,7 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
459
 
460
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
461
 
462
- # --- NEW: Include model_size in the inputs list ---
463
  inputs = [data_source, upload_file, sample_index, context_len, pred_len, model_size, session_id_state]
464
  outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display, session_id_state]
465
 
@@ -474,7 +504,7 @@ if __name__ == "__main__":
474
  # --- Run initial cleanup on startup ---
475
  cleanup_old_sessions()
476
 
477
- # --- NEW: Start the periodic cleanup in a background daemon thread ---
478
  cleanup_thread = threading.Thread(target=periodic_cleanup_task, daemon=True)
479
  cleanup_thread.start()
480
 
 
53
  print("Cleanup complete. No old sessions found.")
54
 
55
 
56
+ # --- Function to run the cleanup periodically in the background ---
57
  def periodic_cleanup_task():
58
  """Wrapper function to run cleanup in a loop with a sleep interval."""
59
  print("Starting background thread for periodic cleanup.")
 
96
 
97
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
98
 
99
+
100
+ # Image normalization constants
101
+ imagenet_mean = np.array([0.485, 0.456, 0.406])
102
+ imagenet_std = np.array([0.229, 0.224, 0.225])
103
+
104
+
105
+ # --- Global variables to hold the currently loaded model and its size ---
106
  CURRENT_MODEL_SIZE = None
107
  CURRENT_MODEL = None
108
 
109
  def load_model_for_size(model_size: str):
110
  """Loads the specified VisionTS++ model."""
111
  global CURRENT_MODEL, CURRENT_MODEL_SIZE
112
+
113
  if model_size not in MODEL_CONFIGS:
114
  raise ValueError(f"Invalid model size: {model_size}. Available: {list(MODEL_CONFIGS.keys())}")
115
 
116
  config = MODEL_CONFIGS[model_size]
117
  print(f"Loading {model_size} model...")
118
+
119
  model = VisionTSpp(
120
  config["arch"],
121
  ckpt_path=config["ckpt_path"],
 
134
 
135
  CURRENT_MODEL = model
136
  CURRENT_MODEL_SIZE = model_size
137
+
138
  return model
139
 
140
  # Load the default model (base) on startup
 
169
  def show_image_tensor(image_tensor, title='', cur_nvars=1, cur_color_list=None):
170
  if image_tensor is None:
171
  return None
172
+
173
  image = image_tensor.cpu()
174
  cur_image = torch.zeros_like(image)
175
+
176
  height_per_var = image.shape[0] // cur_nvars
177
  for i in range(cur_nvars):
178
  cur_color_idx = cur_color_list[i]
179
  var_slice = image[i*height_per_var:(i+1)*height_per_var, :, :]
180
  unnormalized_channel = var_slice[:, :, cur_color_idx] * imagenet_std[cur_color_idx] + imagenet_mean[cur_color_idx]
181
  cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color_idx] = unnormalized_channel * 255
182
+
183
  cur_image = torch.clamp(cur_image, 0, 255).int().numpy()
184
  fig, ax = plt.subplots(figsize=(6, 6))
185
+
186
  ax.imshow(cur_image)
187
  ax.set_title(title, fontsize=14)
188
  ax.axis('off')
189
  plt.tight_layout()
190
  plt.close(fig)
191
+
192
  return fig
193
 
194
  def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_quantiles, context_len, pred_len):
195
+ if isinstance(true_data, torch.Tensor):
196
+ true_data = true_data.cpu().numpy()
197
+ if isinstance(pred_median, torch.Tensor):
198
+ pred_median = pred_median.cpu().numpy()
199
+
200
  for i, q in enumerate(pred_quantiles_list):
201
  if isinstance(q, torch.Tensor):
202
  pred_quantiles_list[i] = q.cpu().numpy()
203
 
204
  nvars = true_data.shape[1]
205
  FIG_WIDTH, FIG_HEIGHT_PER_VAR = 15, 2.0
206
+
207
  fig, axes = plt.subplots(nvars, 1, figsize=(FIG_WIDTH, nvars * FIG_HEIGHT_PER_VAR), sharex=True)
208
+ if nvars == 1:
209
+ axes = [axes]
210
 
211
  pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
212
  sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
 
219
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
220
  pred_range = np.arange(context_len, context_len + pred_len)
221
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
222
+
223
  for j in range(num_bands):
224
  lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
225
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
226
  ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
227
+
228
  y_min, y_max = ax.get_ylim()
229
  ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
230
  ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
 
233
 
234
  handles, labels = axes[0].get_legend_handles_labels()
235
  unique_labels = dict(zip(labels, handles))
236
+
237
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
238
  plt.tight_layout(rect=[0, 0, 1, 0.95])
239
  plt.close(fig)
240
+
241
  return fig
242
 
243
 
 
260
  try:
261
  df['date'] = pd.to_datetime(df['date'])
262
  df = df.sort_values('date').set_index('date')
263
+
264
  inferred_freq = pd.infer_freq(df.index)
265
  if inferred_freq is None:
266
  time_diff = df.index[1] - df.index[0]
267
  inferred_freq = pd.tseries.frequencies.to_offset(time_diff).freqstr
268
  gr.Warning(f"Could not reliably infer frequency. Using fallback based on first two timestamps: {inferred_freq}")
269
+
270
  print(f"Inferred frequency: {inferred_freq}")
271
  except Exception as e:
272
  raise gr.Error(f"❌ Date processing failed: {e}. Please check the date format (e.g., YYYY-MM-DD HH:MM:SS).")
 
295
 
296
  color_list = [i % 3 for i in range(nvars)]
297
 
298
+ # --- Load the requested model if it's not the current one ---
299
  if CURRENT_MODEL_SIZE != model_size:
300
  print(f"Switching model from {CURRENT_MODEL_SIZE} to {model_size}")
301
  load_model_for_size(model_size)
 
314
  all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
315
  all_y_pred_list.insert(len(QUANTILES)//2, y_pred)
316
  all_preds = dict(zip(QUANTILES, all_y_pred_list))
317
+
318
  pred_median_norm = all_preds.pop(0.5)[0]
319
  pred_quantiles_norm = [q[0] for q in list(all_preds.values())]
320
 
 
330
  pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
331
  context_len=context_len, pred_len=pred_len
332
  )
333
+
334
  input_img_fig = show_image_tensor(input_image[0, 0], f'Input Image (Sample {index})', nvars, color_list)
335
  recon_img_fig = show_image_tensor(reconstructed_image[0, 0], 'Reconstructed Image', nvars, color_list)
336
 
337
  csv_path = Path(session_dir) / "prediction_result.csv"
338
  time_index = df.index[start_idx + context_len : start_idx + context_len + pred_len]
339
  result_data = {'date': time_index}
340
+
341
  for i in range(nvars):
342
  result_data[f'True_Var{i+1}'] = y_true[:, i]
343
  result_data[f'Pred_Median_Var{i+1}'] = pred_median[:, i]
344
+
345
  result_df = pd.DataFrame(result_data)
346
  result_df.to_csv(csv_path, index=False)
347
 
 
358
  session_dir = Path(SESSION_DIR_ROOT) / session_uuid
359
  session_dir.mkdir(exist_ok=True, parents=True)
360
  session_id = str(session_dir)
361
+
362
  return session_id
363
 
364
  def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
 
419
  with gr.Row():
420
  with gr.Column(scale=1, min_width=300):
421
  gr.Markdown("### 1. Data & Model Configuration")
422
+ # --- Add model selection dropdown ---
423
  model_size = gr.Dropdown(
424
  label="Select Model Size",
425
  choices=["base", "large"],
 
489
 
490
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
491
 
492
+ # --- Include model_size in the inputs list ---
493
  inputs = [data_source, upload_file, sample_index, context_len, pred_len, model_size, session_id_state]
494
  outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display, session_id_state]
495
 
 
504
  # --- Run initial cleanup on startup ---
505
  cleanup_old_sessions()
506
 
507
+ # --- Start the periodic cleanup in a background daemon thread ---
508
  cleanup_thread = threading.Thread(target=periodic_cleanup_task, daemon=True)
509
  cleanup_thread.start()
510