Spaces:
Running
Running
update app.py, add choice button for VisionTSpp base and large
Browse files
app.py
CHANGED
|
@@ -221,7 +221,13 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
|
|
| 221 |
# pred_range = np.arange(context_len, context_len + pred_len)
|
| 222 |
pred_range = np.arange(context_len-1, context_len + pred_len)
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
pred_median_visual = true_data[context_len-1:context_len, i] + pred_median[:, i]
|
|
|
|
|
|
|
| 225 |
ax.plot(pred_range, pred_median_visual, label='Prediction (Median)', color='red', linewidth=1.5)
|
| 226 |
# ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
|
| 227 |
|
|
@@ -229,7 +235,12 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
|
|
| 229 |
lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
|
| 230 |
lower_quantile_pred_visual = true_data[context_len-1:context_len, i] + lower_quantile_pred
|
| 231 |
upper_quantile_pred_visual = true_data[context_len-1:context_len, i] + upper_quantile_pred
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
|
| 234 |
ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
| 235 |
# ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
|
@@ -336,6 +347,7 @@ def predict_at_index(df, index, context_len, pred_len, session_dir, model_size):
|
|
| 336 |
full_true_context = data[start_idx : start_idx + context_len]
|
| 337 |
full_true_series = np.concatenate([full_true_context, y_true], axis=0)
|
| 338 |
|
|
|
|
| 339 |
ts_fig = visual_ts_with_quantiles(
|
| 340 |
true_data=full_true_series, pred_median=pred_median,
|
| 341 |
pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
|
|
@@ -375,39 +387,40 @@ def get_session_dir(session_id: gr.State):
|
|
| 375 |
def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
|
| 376 |
session_dir = get_session_dir(session_id)
|
| 377 |
|
| 378 |
-
try:
|
| 379 |
-
if data_source == "Upload CSV":
|
| 380 |
-
if upload_file is None:
|
| 381 |
-
raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
|
| 382 |
-
uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
|
| 383 |
-
shutil.copy(upload_file.name, uploaded_file_path)
|
| 384 |
-
df = pd.read_csv(uploaded_file_path)
|
| 385 |
-
else:
|
| 386 |
-
df = load_preset_data(data_source)
|
| 387 |
-
|
| 388 |
-
index, context_len, pred_len = int(index), int(context_len), int(pred_len)
|
| 389 |
-
# --- Pass model_size to predict_at_index ---
|
| 390 |
-
result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
|
| 391 |
-
|
| 392 |
-
final_index = min(index, result.total_samples - 1)
|
| 393 |
-
|
| 394 |
-
return (
|
| 395 |
-
result.ts_fig,
|
| 396 |
-
result.input_img_fig,
|
| 397 |
-
result.recon_img_fig,
|
| 398 |
-
result.csv_path,
|
| 399 |
-
gr.update(maximum=result.total_samples - 1, value=final_index),
|
| 400 |
-
gr.update(value=result.inferred_freq),
|
| 401 |
-
session_dir
|
| 402 |
-
)
|
| 403 |
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
|
| 413 |
with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
|
|
|
|
| 221 |
# pred_range = np.arange(context_len, context_len + pred_len)
|
| 222 |
pred_range = np.arange(context_len-1, context_len + pred_len)
|
| 223 |
|
| 224 |
+
print(true_data[:, i].shape)
|
| 225 |
+
print(pred_median[:, i].shape)
|
| 226 |
+
print(pred_range.shape)
|
| 227 |
+
|
| 228 |
pred_median_visual = true_data[context_len-1:context_len, i] + pred_median[:, i]
|
| 229 |
+
print(pred_median_visual.shape)
|
| 230 |
+
|
| 231 |
ax.plot(pred_range, pred_median_visual, label='Prediction (Median)', color='red', linewidth=1.5)
|
| 232 |
# ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
|
| 233 |
|
|
|
|
| 235 |
lower_quantile_pred, upper_quantile_pred = quantile_preds[j][:, i], quantile_preds[-(j+1)][:, i]
|
| 236 |
lower_quantile_pred_visual = true_data[context_len-1:context_len, i] + lower_quantile_pred
|
| 237 |
upper_quantile_pred_visual = true_data[context_len-1:context_len, i] + upper_quantile_pred
|
| 238 |
+
|
| 239 |
+
print(lower_quantile_pred.shape)
|
| 240 |
+
print(upper_quantile_pred.shape)
|
| 241 |
+
print(lower_quantile_pred_visual.shape)
|
| 242 |
+
print(upper_quantile_pred_visual.shape)
|
| 243 |
+
|
| 244 |
q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
|
| 245 |
ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
| 246 |
# ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
|
|
|
| 347 |
full_true_context = data[start_idx : start_idx + context_len]
|
| 348 |
full_true_series = np.concatenate([full_true_context, y_true], axis=0)
|
| 349 |
|
| 350 |
+
|
| 351 |
ts_fig = visual_ts_with_quantiles(
|
| 352 |
true_data=full_true_series, pred_median=pred_median,
|
| 353 |
pred_quantiles_list=pred_quantiles, model_quantiles=list(all_preds.keys()),
|
|
|
|
| 387 |
def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
|
| 388 |
session_dir = get_session_dir(session_id)
|
| 389 |
|
| 390 |
+
# try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
+
if data_source == "Upload CSV":
|
| 393 |
+
if upload_file is None:
|
| 394 |
+
raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
|
| 395 |
+
uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
|
| 396 |
+
shutil.copy(upload_file.name, uploaded_file_path)
|
| 397 |
+
df = pd.read_csv(uploaded_file_path)
|
| 398 |
+
else:
|
| 399 |
+
df = load_preset_data(data_source)
|
| 400 |
+
|
| 401 |
+
index, context_len, pred_len = int(index), int(context_len), int(pred_len)
|
| 402 |
+
# --- Pass model_size to predict_at_index ---
|
| 403 |
+
result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
|
| 404 |
+
|
| 405 |
+
final_index = min(index, result.total_samples - 1)
|
| 406 |
+
|
| 407 |
+
return (
|
| 408 |
+
result.ts_fig,
|
| 409 |
+
result.input_img_fig,
|
| 410 |
+
result.recon_img_fig,
|
| 411 |
+
result.csv_path,
|
| 412 |
+
gr.update(maximum=result.total_samples - 1, value=final_index),
|
| 413 |
+
gr.update(value=result.inferred_freq),
|
| 414 |
+
session_dir
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
# except Exception as e:
|
| 418 |
+
# print(f"Error during forecast: {e}")
|
| 419 |
+
# error_fig = plt.figure(figsize=(10, 5))
|
| 420 |
+
# plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
|
| 421 |
+
# plt.axis('off')
|
| 422 |
+
# plt.close(error_fig)
|
| 423 |
+
# return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
|
| 424 |
|
| 425 |
|
| 426 |
with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
|