Lefei commited on
Commit
2f6dc9b
Β·
verified Β·
1 Parent(s): ae594ea

update app.py, add choice of base and large models

Browse files
Files changed (1) hide show
  1. app.py +95 -30
app.py CHANGED
@@ -10,10 +10,11 @@ import copy
10
  import uuid
11
  import shutil
12
  import time
13
- import threading # <-- NEW: Import for background tasks
14
  from pathlib import Path
15
 
16
  from huggingface_hub import snapshot_download
 
17
  from visionts import VisionTSpp, freq_to_seasonality_list
18
 
19
  # ========================
@@ -60,6 +61,7 @@ def periodic_cleanup_task():
60
  cleanup_old_sessions()
61
  time.sleep(CLEANUP_INTERVAL_SECONDS)
62
 
 
63
  # ========================
64
  # 1. Model Configuration
65
  # ========================
@@ -67,22 +69,68 @@ def periodic_cleanup_task():
67
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
68
  REPO_ID = "Lefei/VisionTSpp"
69
  LOCAL_DIR = "./hf_models/VisionTSpp"
70
- CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
71
- ARCH = 'mae_base'
72
 
73
- if not os.path.exists(CKPT_PATH):
74
- from huggingface_hub import snapshot_download
75
- print("Downloading model from Hugging Face Hub...")
76
- snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False, resume_download=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
79
- # Assuming VisionTSpp is defined in a separate file or installed package
80
- # from visionts import VisionTSpp, freq_to_seasonality_list # Placeholder for your model import
81
- model = VisionTSpp(ARCH, ckpt_path=CKPT_PATH, quantile=True, clip_input=True, complete_no_clip=False, color=True).to(DEVICE)
82
- print(f"Model loaded on {DEVICE}")
83
 
84
- imagenet_mean = np.array([0.485, 0.456, 0.406])
85
- imagenet_std = np.array([0.229, 0.224, 0.225])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  # ========================
@@ -182,7 +230,7 @@ class PredictionResult:
182
  self.total_samples = total_samples
183
  self.inferred_freq = inferred_freq
184
 
185
- def predict_at_index(df, index, context_len, pred_len, session_dir):
186
  if 'date' not in df.columns:
187
  raise gr.Error("❌ Input CSV must contain a 'date' column.")
188
 
@@ -204,7 +252,7 @@ def predict_at_index(df, index, context_len, pred_len, session_dir):
204
  total_samples = len(data) - context_len - pred_len + 1
205
  if total_samples <= 0:
206
  raise gr.Error(f"Data is too short. It needs at least context_len + pred_len = {context_len + pred_len} rows, but has {len(data)}.")
207
-
208
  index = max(0, min(index, total_samples - 1))
209
 
210
  train_len = int(len(data) * 0.7)
@@ -219,9 +267,17 @@ def predict_at_index(df, index, context_len, pred_len, session_dir):
219
 
220
  periodicity_list = freq_to_seasonality_list(inferred_freq)
221
  periodicity = periodicity_list[0] if periodicity_list else 1
222
-
223
  color_list = [i % 3 for i in range(nvars)]
224
- model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity,
 
 
 
 
 
 
 
 
225
  num_patch_input=7, padding_mode='constant')
226
 
227
  with torch.no_grad():
@@ -242,7 +298,7 @@ def predict_at_index(df, index, context_len, pred_len, session_dir):
242
 
243
  full_true_context = data[start_idx : start_idx + context_len]
244
  full_true_series = np.concatenate([full_true_context, y_true], axis=0)
245
-
246
  ts_fig = visual_ts_with_quantiles(
247
  true_data=full_true_series, pred_median=pred_median,
248
  pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
@@ -275,9 +331,9 @@ def get_session_dir(session_id: gr.State):
275
  session_id = str(session_dir)
276
  return session_id
277
 
278
- def run_forecast(data_source, upload_file, index, context_len, pred_len, session_id: gr.State):
279
  session_dir = get_session_dir(session_id)
280
-
281
  try:
282
  if data_source == "Upload CSV":
283
  if upload_file is None:
@@ -289,10 +345,11 @@ def run_forecast(data_source, upload_file, index, context_len, pred_len, session
289
  df = load_preset_data(data_source)
290
 
291
  index, context_len, pred_len = int(index), int(context_len), int(pred_len)
292
- result = predict_at_index(df, index, context_len, pred_len, session_dir)
293
-
 
294
  final_index = min(index, result.total_samples - 1)
295
-
296
  return (
297
  result.ts_fig,
298
  result.input_img_fig,
@@ -302,7 +359,7 @@ def run_forecast(data_source, upload_file, index, context_len, pred_len, session
302
  gr.update(value=result.inferred_freq),
303
  session_dir
304
  )
305
-
306
  except Exception as e:
307
  print(f"Error during forecast: {e}")
308
  error_fig = plt.figure(figsize=(10, 5))
@@ -325,12 +382,19 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
325
  - βœ… **Slide** through different samples of the dataset for real-time forecasting.
326
  - βœ… **Download** the prediction results as a CSV file.
327
  - βœ… **User Isolation**: Each user session has its own temporary storage to prevent file conflicts. Old files are automatically cleaned up.
 
328
  """
329
  )
330
-
331
  with gr.Row():
332
  with gr.Column(scale=1, min_width=300):
333
  gr.Markdown("### 1. Data & Model Configuration")
 
 
 
 
 
 
334
  data_source = gr.Dropdown(
335
  label="Select Data Source",
336
  choices=list(PRESET_DATASETS.keys()) + ["Upload CSV"],
@@ -344,13 +408,13 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
344
  2. Must contain a time column named `date` with a consistent frequency.
345
  """
346
  )
347
-
348
  context_len = gr.Number(label="Context Length (History)", value=336)
349
  pred_len = gr.Number(label="Prediction Length (Future)", value=96)
350
  freq_display = gr.Textbox(label="Detected Frequency", interactive=False)
351
 
352
  run_btn = gr.Button("πŸš€ Run Forecast", variant="primary")
353
-
354
  gr.Markdown("### 2. Sample Selection")
355
  sample_index = gr.Slider(
356
  label="Sample Index",
@@ -395,7 +459,8 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
395
 
396
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
397
 
398
- inputs = [data_source, upload_file, sample_index, context_len, pred_len, session_id_state]
 
399
  outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display, session_id_state]
400
 
401
  run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
@@ -408,10 +473,10 @@ with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes
408
  if __name__ == "__main__":
409
  # --- Run initial cleanup on startup ---
410
  cleanup_old_sessions()
411
-
412
  # --- NEW: Start the periodic cleanup in a background daemon thread ---
413
  cleanup_thread = threading.Thread(target=periodic_cleanup_task, daemon=True)
414
  cleanup_thread.start()
415
 
416
  # --- Launch the Gradio app ---
417
- demo.launch(debug=True)
 
10
  import uuid
11
  import shutil
12
  import time
13
+ import threading
14
  from pathlib import Path
15
 
16
  from huggingface_hub import snapshot_download
17
+ # Assuming visionts package is available
18
  from visionts import VisionTSpp, freq_to_seasonality_list
19
 
20
  # ========================
 
61
  cleanup_old_sessions()
62
  time.sleep(CLEANUP_INTERVAL_SECONDS)
63
 
64
+
65
  # ========================
66
  # 1. Model Configuration
67
  # ========================
 
69
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
70
  REPO_ID = "Lefei/VisionTSpp"
71
  LOCAL_DIR = "./hf_models/VisionTSpp"
 
 
72
 
73
+ # --- Define model configurations ---
74
+ MODEL_CONFIGS = {
75
+ "base": {
76
+ "arch": 'mae_base',
77
+ "ckpt_path": os.path.join(LOCAL_DIR, "visiontspp_base.ckpt")
78
+ },
79
+ "large": {
80
+ "arch": 'mae_large',
81
+ "ckpt_path": os.path.join(LOCAL_DIR, "visiontspp_large.ckpt")
82
+ }
83
+ }
84
+
85
+ # Download both checkpoints if they don't exist
86
+ for model_size, config in MODEL_CONFIGS.items():
87
+ if not os.path.exists(config["ckpt_path"]):
88
+ print(f"Downloading {model_size} model from Hugging Face Hub...")
89
+ snapshot_download(
90
+ repo_id=REPO_ID,
91
+ local_dir=LOCAL_DIR,
92
+ local_dir_use_symlinks=False,
93
+ resume_download=True,
94
+ allow_patterns=[f"*{model_size}*"] # Download only relevant files if possible
95
+ )
96
 
97
  QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
 
 
 
 
98
 
99
+ # --- NEW: Global variables to hold the currently loaded model and its size ---
100
+ CURRENT_MODEL_SIZE = None
101
+ CURRENT_MODEL = None
102
+
103
+ def load_model_for_size(model_size: str):
104
+ """Loads the specified VisionTS++ model."""
105
+ global CURRENT_MODEL, CURRENT_MODEL_SIZE
106
+ if model_size not in MODEL_CONFIGS:
107
+ raise ValueError(f"Invalid model size: {model_size}. Available: {list(MODEL_CONFIGS.keys())}")
108
+
109
+ config = MODEL_CONFIGS[model_size]
110
+ print(f"Loading {model_size} model...")
111
+ model = VisionTSpp(
112
+ config["arch"],
113
+ ckpt_path=config["ckpt_path"],
114
+ quantile=True,
115
+ clip_input=True,
116
+ complete_no_clip=False,
117
+ color=True
118
+ ).to(DEVICE)
119
+ print(f"Model {model_size} loaded on {DEVICE}")
120
+
121
+ # Unload the previous model to free memory if it was loaded
122
+ if CURRENT_MODEL is not None:
123
+ print(f"Unloading previous model ({CURRENT_MODEL_SIZE})...")
124
+ del CURRENT_MODEL
125
+ torch.cuda.empty_cache() # Clear GPU cache if using CUDA
126
+
127
+ CURRENT_MODEL = model
128
+ CURRENT_MODEL_SIZE = model_size
129
+ return model
130
+
131
+ # Load the default model (base) on startup
132
+ if CURRENT_MODEL is None:
133
+ CURRENT_MODEL = load_model_for_size("base") # Or "large" as default
134
 
135
 
136
  # ========================
 
230
  self.total_samples = total_samples
231
  self.inferred_freq = inferred_freq
232
 
233
+ def predict_at_index(df, index, context_len, pred_len, session_dir, model_size):
234
  if 'date' not in df.columns:
235
  raise gr.Error("❌ Input CSV must contain a 'date' column.")
236
 
 
252
  total_samples = len(data) - context_len - pred_len + 1
253
  if total_samples <= 0:
254
  raise gr.Error(f"Data is too short. It needs at least context_len + pred_len = {context_len + pred_len} rows, but has {len(data)}.")
255
+
256
  index = max(0, min(index, total_samples - 1))
257
 
258
  train_len = int(len(data) * 0.7)
 
267
 
268
  periodicity_list = freq_to_seasonality_list(inferred_freq)
269
  periodicity = periodicity_list[0] if periodicity_list else 1
270
+
271
  color_list = [i % 3 for i in range(nvars)]
272
+
273
+ # --- NEW: Load the requested model if it's not the current one ---
274
+ if CURRENT_MODEL_SIZE != model_size:
275
+ print(f"Switching model from {CURRENT_MODEL_SIZE} to {model_size}")
276
+ load_model_for_size(model_size)
277
+
278
+ # --- Use the currently loaded model ---
279
+ model = CURRENT_MODEL
280
+ model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity,
281
  num_patch_input=7, padding_mode='constant')
282
 
283
  with torch.no_grad():
 
298
 
299
  full_true_context = data[start_idx : start_idx + context_len]
300
  full_true_series = np.concatenate([full_true_context, y_true], axis=0)
301
+
302
  ts_fig = visual_ts_with_quantiles(
303
  true_data=full_true_series, pred_median=pred_median,
304
  pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
 
331
  session_id = str(session_dir)
332
  return session_id
333
 
334
+ def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
335
  session_dir = get_session_dir(session_id)
336
+
337
  try:
338
  if data_source == "Upload CSV":
339
  if upload_file is None:
 
345
  df = load_preset_data(data_source)
346
 
347
  index, context_len, pred_len = int(index), int(context_len), int(pred_len)
348
+ # --- Pass model_size to predict_at_index ---
349
+ result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
350
+
351
  final_index = min(index, result.total_samples - 1)
352
+
353
  return (
354
  result.ts_fig,
355
  result.input_img_fig,
 
359
  gr.update(value=result.inferred_freq),
360
  session_dir
361
  )
362
+
363
  except Exception as e:
364
  print(f"Error during forecast: {e}")
365
  error_fig = plt.figure(figsize=(10, 5))
 
382
  - βœ… **Slide** through different samples of the dataset for real-time forecasting.
383
  - βœ… **Download** the prediction results as a CSV file.
384
  - βœ… **User Isolation**: Each user session has its own temporary storage to prevent file conflicts. Old files are automatically cleaned up.
385
+ - βœ… **Model Selection**: Choose between 'base' and 'large' VisionTS++ models.
386
  """
387
  )
388
+
389
  with gr.Row():
390
  with gr.Column(scale=1, min_width=300):
391
  gr.Markdown("### 1. Data & Model Configuration")
392
+ # --- NEW: Add model selection dropdown ---
393
+ model_size = gr.Dropdown(
394
+ label="Select Model Size",
395
+ choices=["base", "large"],
396
+ value="base" # Default to base
397
+ )
398
  data_source = gr.Dropdown(
399
  label="Select Data Source",
400
  choices=list(PRESET_DATASETS.keys()) + ["Upload CSV"],
 
408
  2. Must contain a time column named `date` with a consistent frequency.
409
  """
410
  )
411
+
412
  context_len = gr.Number(label="Context Length (History)", value=336)
413
  pred_len = gr.Number(label="Prediction Length (Future)", value=96)
414
  freq_display = gr.Textbox(label="Detected Frequency", interactive=False)
415
 
416
  run_btn = gr.Button("πŸš€ Run Forecast", variant="primary")
417
+
418
  gr.Markdown("### 2. Sample Selection")
419
  sample_index = gr.Slider(
420
  label="Sample Index",
 
459
 
460
  data_source.change(fn=toggle_upload_visibility, inputs=data_source, outputs=upload_file)
461
 
462
+ # --- NEW: Include model_size in the inputs list ---
463
+ inputs = [data_source, upload_file, sample_index, context_len, pred_len, model_size, session_id_state]
464
  outputs = [ts_plot, input_img_plot, recon_img_plot, download_csv, sample_index, freq_display, session_id_state]
465
 
466
  run_btn.click(fn=run_forecast, inputs=inputs, outputs=outputs, api_name="run_forecast")
 
473
  if __name__ == "__main__":
474
  # --- Run initial cleanup on startup ---
475
  cleanup_old_sessions()
476
+
477
  # --- NEW: Start the periodic cleanup in a background daemon thread ---
478
  cleanup_thread = threading.Thread(target=periodic_cleanup_task, daemon=True)
479
  cleanup_thread.start()
480
 
481
  # --- Launch the Gradio app ---
482
+ demo.launch(debug=True)