Lefei commited on
Commit
ae4082b
·
verified ·
1 Parent(s): fbd13aa

update app.py, add choice button for VisionTSpp base and large

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -210,6 +210,7 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
210
 
211
  pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
212
  sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
 
213
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
214
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
215
  num_bands = len(quantile_preds) // 2
@@ -217,7 +218,9 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
217
 
218
  for i, ax in enumerate(axes):
219
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
220
- pred_range = np.arange(context_len, context_len + pred_len)
 
 
221
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
222
 
223
  for j in range(num_bands):
@@ -226,7 +229,9 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
226
  ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
227
 
228
  y_min, y_max = ax.get_ylim()
229
- ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
 
 
230
  ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
231
  ax.grid(True, which='both', linestyle='--', linewidth=0.5)
232
  ax.margins(x=0)
 
210
 
211
  pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
212
  sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
213
+
214
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
215
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
216
  num_bands = len(quantile_preds) // 2
 
218
 
219
  for i, ax in enumerate(axes):
220
  ax.plot(true_data[:, i], label='Ground Truth', color='black', linewidth=1.5)
221
+ # pred_range = np.arange(context_len, context_len + pred_len)
222
+ pred_range = np.arange(context_len-1, context_len + pred_len)
223
+
224
  ax.plot(pred_range, pred_median[:, i], label='Prediction (Median)', color='red', linewidth=1.5)
225
 
226
  for j in range(num_bands):
 
229
  ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
230
 
231
  y_min, y_max = ax.get_ylim()
232
+ # ax.vlines(x=context_len, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
233
+ ax.vlines(x=context_len-1, ymin=y_min, ymax=y_max, colors='gray', linestyles='--', alpha=0.7)
234
+
235
  ax.set_ylabel(f'Var {i+1}', rotation=0, labelpad=30, ha='right', va='center')
236
  ax.grid(True, which='both', linestyle='--', linewidth=0.5)
237
  ax.margins(x=0)