Lefei commited on
Commit
da5af17
·
verified ·
1 Parent(s): 0d5823b

update app.py, add choice button for VisionTSpp base and large

Browse files
Files changed (1) hide show
  1. app.py +46 -33
app.py CHANGED
@@ -221,7 +221,13 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
221
  # pred_range = np.arange(context_len, context_len + pred_len)
222
  pred_range = np.arange(context_len-1, context_len + pred_len)
223
 
 
 
 
 
224
  pred_median_visual = true_data[context_len-1:context_len, i] + pred_median[:, i]
 
 
225
  ax.plot(pred_range, pred_median_visual, label='Prediction (Median)', color='red', linewidth=1.5)
226
  # ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
227
 
@@ -229,7 +235,12 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
229
  lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
230
  lower_quantile_pred_visual = true_data[context_len-1:context_len, i] + lower_quantile_pred
231
  upper_quantile_pred_visual = true_data[context_len-1:context_len, i] + upper_quantile_pred
232
-
 
 
 
 
 
233
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
234
  ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
235
  # ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
@@ -336,6 +347,7 @@ def predict_at_index(df, index, context_len, pred_len, session_dir, model_size):
336
  full_true_context = data[start_idx : start_idx + context_len]
337
  full_true_series = np.concatenate([full_true_context, y_true], axis=0)
338
 
 
339
  ts_fig = visual_ts_with_quantiles(
340
  true_data=full_true_series, pred_median=pred_median,
341
  pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
@@ -375,39 +387,40 @@ def get_session_dir(session_id: gr.State):
375
  def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
376
  session_dir = get_session_dir(session_id)
377
 
378
- try:
379
- if data_source == "Upload CSV":
380
- if upload_file is None:
381
- raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
382
- uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
383
- shutil.copy(upload_file.name, uploaded_file_path)
384
- df = pd.read_csv(uploaded_file_path)
385
- else:
386
- df = load_preset_data(data_source)
387
-
388
- index, context_len, pred_len = int(index), int(context_len), int(pred_len)
389
- # --- Pass model_size to predict_at_index ---
390
- result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
391
-
392
- final_index = min(index, result.total_samples - 1)
393
-
394
- return (
395
- result.ts_fig,
396
- result.input_img_fig,
397
- result.recon_img_fig,
398
- result.csv_path,
399
- gr.update(maximum=result.total_samples - 1, value=final_index),
400
- gr.update(value=result.inferred_freq),
401
- session_dir
402
- )
403
 
404
- except Exception as e:
405
- print(f"Error during forecast: {e}")
406
- error_fig = plt.figure(figsize=(10, 5))
407
- plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
408
- plt.axis('off')
409
- plt.close(error_fig)
410
- return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
 
412
 
413
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
 
221
  # pred_range = np.arange(context_len, context_len + pred_len)
222
  pred_range = np.arange(context_len-1, context_len + pred_len)
223
 
224
+ print(true_data[:, i].shape)
225
+ print(pred_median[:, i].shape)
226
+ print(pred_range.shape)
227
+
228
  pred_median_visual = true_data[context_len-1:context_len, i] + pred_median[:, i]
229
+ print(pred_median_visual.shape)
230
+
231
  ax.plot(pred_range, pred_median_visual, label='Prediction (Median)', color='red', linewidth=1.5)
232
  # ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
233
 
 
235
  lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
236
  lower_quantile_pred_visual = true_data[context_len-1:context_len, i] + lower_quantile_pred
237
  upper_quantile_pred_visual = true_data[context_len-1:context_len, i] + upper_quantile_pred
238
+
239
+ print(lower_quantile_pred.shape)
240
+ print(upper_quantile_pred.shape)
241
+ print(lower_quantile_pred_visual.shape)
242
+ print(upper_quantile_pred_visual.shape)
243
+
244
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
245
  ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
246
  # ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
 
347
  full_true_context = data[start_idx : start_idx + context_len]
348
  full_true_series = np.concatenate([full_true_context, y_true], axis=0)
349
 
350
+
351
  ts_fig = visual_ts_with_quantiles(
352
  true_data=full_true_series, pred_median=pred_median,
353
  pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
 
387
  def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
388
  session_dir = get_session_dir(session_id)
389
 
390
+ # try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
 
392
+ if data_source == "Upload CSV":
393
+ if upload_file is None:
394
+ raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
395
+ uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
396
+ shutil.copy(upload_file.name, uploaded_file_path)
397
+ df = pd.read_csv(uploaded_file_path)
398
+ else:
399
+ df = load_preset_data(data_source)
400
+
401
+ index, context_len, pred_len = int(index), int(context_len), int(pred_len)
402
+ # --- Pass model_size to predict_at_index ---
403
+ result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
404
+
405
+ final_index = min(index, result.total_samples - 1)
406
+
407
+ return (
408
+ result.ts_fig,
409
+ result.input_img_fig,
410
+ result.recon_img_fig,
411
+ result.csv_path,
412
+ gr.update(maximum=result.total_samples - 1, value=final_index),
413
+ gr.update(value=result.inferred_freq),
414
+ session_dir
415
+ )
416
+
417
+ # except Exception as e:
418
+ # print(f"Error during forecast: {e}")
419
+ # error_fig = plt.figure(figsize=(10, 5))
420
+ # plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
421
+ # plt.axis('off')
422
+ # plt.close(error_fig)
423
+ # return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
424
 
425
 
426
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo: