zqiao11 committed on
Commit
003ab1c
·
1 Parent(s): e580cf0

perf: Add cache for predictions

Browse files
Files changed (2) hide show
  1. src/leaderboard.py +30 -0
  2. src/tab.py +39 -47
src/leaderboard.py CHANGED
@@ -191,6 +191,36 @@ def _load_metrics_cached(model_name, dataset_term, horizon):
191
  return result
192
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  def load_test_windows(display_name, horizon, model_name="moirai_small", series=None, variate=None, window_id=None, parse_series=False):
195
  """
196
  Load test window results from TIME NPZ files.
 
191
  return result
192
 
193
 
194
+ @lru_cache(maxsize=10)
195
+ def _load_predictions_cached(model_name, dataset_term, horizon):
196
+ """
197
+ Load and cache predictions.npz for a specific (model, dataset, horizon).
198
+
199
+ predictions.npz is the heaviest file loaded in the Per Test Window tab
200
+ (contains quantile predictions for all series / windows / variates).
201
+ On HF Space the disk I/O is much slower than a local SSD, so caching
202
+ avoids redundant reads when cascading Gradio change events re-trigger
203
+ plot_window_series with identical parameters.
204
+
205
+ maxsize=10 keeps memory bounded — at most 10 (model, dataset, horizon)
206
+ combos stay in RAM at a time.
207
+
208
+ Returns:
209
+ dict with keys from the npz file (typically "predictions_quantiles"
210
+ and "quantile_levels"), or None if the file does not exist.
211
+ """
212
+ results_root = str(RESULTS_ROOT)
213
+ predictions_path = os.path.join(
214
+ results_root, model_name, dataset_term, horizon, "predictions.npz"
215
+ )
216
+ if not os.path.exists(predictions_path):
217
+ return None
218
+ npz = np.load(predictions_path)
219
+ result = {k: npz[k] for k in npz.files}
220
+ npz.close()
221
+ return result
222
+
223
+
224
  def load_test_windows(display_name, horizon, model_name="moirai_small", series=None, variate=None, window_id=None, parse_series=False):
225
  """
226
  Load test window results from TIME NPZ files.
src/tab.py CHANGED
@@ -32,7 +32,7 @@ import gradio as gr
32
  from src.about import DATASET_CHOICES, ALL_MODELS, RESULTS_ROOT, FEATURES_DF, FEATURES_BOOL_DF, PATTERN_MAP
33
  from src.leaderboard import (get_overall_leaderboard, get_dataset_multilevel_leaderboard,
34
  get_window_leaderboard, get_pattern_leaderboard, resolve_dataset_id,
35
- _get_dataset_metadata)
36
  from src.about import DATASETS_DF, ALL_HORIZONS
37
  # get_datasets_root, get_config_root no longer needed here — handled by _get_dataset_metadata
38
  import numpy as np
@@ -276,20 +276,16 @@ def plot_window_series(display_name, series, variate, window_id, horizon, select
276
  fig.update_layout(title="Dataset not found")
277
  return fig, ""
278
 
279
- predictions_path = os.path.join(results_root, model, dataset_term, horizon, "predictions.npz")
280
- print(f"πŸ“ predictions_path: {predictions_path}, exists: {os.path.exists(predictions_path)}")
281
-
282
- if not os.path.exists(predictions_path):
283
- print("❌ Predictions file not found")
284
  fig = go.Figure()
285
- fig.update_layout(title="Predictions file not found")
286
  return fig, ""
287
 
288
-
289
- predictions = np.load(predictions_path)
290
- # Load pre-computed quantiles (new format only)
291
- predictions_quantiles = predictions["predictions_quantiles"] # (num_series, num_windows, 9, num_variates, prediction_length)
292
- quantile_levels = predictions["quantile_levels"] # [0.1, 0.2, ..., 0.9]
293
 
294
  # Load prediction scale factor from config.json (for float16 overflow prevention)
295
  model_config_path = os.path.join(results_root, model, dataset_term, horizon, "config.json")
@@ -300,6 +296,7 @@ def plot_window_series(display_name, series, variate, window_id, horizon, select
300
  prediction_scale_factor = model_config.get("prediction_scale_factor", 1.0)
301
  if prediction_scale_factor != 1.0:
302
  print(f"📊 Applying inverse scale factor: {prediction_scale_factor}")
 
303
  predictions_quantiles = predictions_quantiles.astype(np.float32) * prediction_scale_factor
304
 
305
  # Use cached metadata for name-to-index mappings and Dataset object
@@ -923,7 +920,20 @@ def init_per_window_tab(demo):
923
 
924
  table_window = gr.DataFrame(elem_classes="custom-table", interactive=False)
925
 
926
- # When dataset changes: first update horizon choices, then update dropdowns
 
 
 
 
 
 
 
 
 
 
 
 
 
927
  dataset_dropdown.change(
928
  fn=update_horizon_choices,
929
  inputs=[dataset_dropdown],
@@ -933,62 +943,44 @@ def init_per_window_tab(demo):
933
  inputs=[dataset_dropdown, horizons],
934
  outputs=[series_dropdown, variate_dropdown, window_dropdown],
935
  ).then(
936
- # After dropdowns are updated, refresh the visualization and table
937
- fn=plot_window_series,
938
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
939
- outputs=[ts_visualization, prediction_info]
940
  ).then(
941
- fn=get_window_leaderboard,
942
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
943
- outputs=table_window
944
  )
945
 
946
- # When horizon changes: update dropdowns, then refresh visualization
947
  horizons.change(
948
  fn=update_series_variate_and_window,
949
  inputs=[dataset_dropdown, horizons],
950
  outputs=[series_dropdown, variate_dropdown, window_dropdown],
951
  ).then(
952
- fn=plot_window_series,
953
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
954
- outputs=[ts_visualization, prediction_info]
955
  ).then(
956
- fn=get_window_leaderboard,
957
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
958
- outputs=table_window
959
  )
960
 
961
- # For series, variate, window changes - update visualization and table
 
 
962
  for comp in [series_dropdown, variate_dropdown, window_dropdown]:
963
  comp.change(
964
- fn=get_window_leaderboard,
965
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
966
- outputs=table_window
967
- )
968
- comp.change(
969
- fn=plot_window_series,
970
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
971
- outputs=[ts_visualization, prediction_info]
972
  )
973
 
974
- # For quantiles and model changes - only update visualization (no table change needed)
975
  for comp in [quantiles, model]:
976
  comp.change(
977
- fn=plot_window_series,
978
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
979
- outputs=[ts_visualization, prediction_info]
980
  )
981
 
982
- # Load initial visualization and table on page load
983
  demo.load(
984
- fn=plot_window_series,
985
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons, quantiles, model],
986
- outputs=[ts_visualization, prediction_info]
987
  )
988
  demo.load(
989
- fn=get_window_leaderboard,
990
- inputs=[dataset_dropdown, series_dropdown, variate_dropdown, window_dropdown, horizons],
991
- outputs=table_window
992
  )
993
 
994
  # CSV Export
 
32
  from src.about import DATASET_CHOICES, ALL_MODELS, RESULTS_ROOT, FEATURES_DF, FEATURES_BOOL_DF, PATTERN_MAP
33
  from src.leaderboard import (get_overall_leaderboard, get_dataset_multilevel_leaderboard,
34
  get_window_leaderboard, get_pattern_leaderboard, resolve_dataset_id,
35
+ _get_dataset_metadata, _load_predictions_cached)
36
  from src.about import DATASETS_DF, ALL_HORIZONS
37
  # get_datasets_root, get_config_root no longer needed here — handled by _get_dataset_metadata
38
  import numpy as np
 
276
  fig.update_layout(title="Dataset not found")
277
  return fig, ""
278
 
279
+ # --- Cached predictions loading (biggest I/O in Per Test Window) ---
280
+ pred_data = _load_predictions_cached(model, dataset_term, horizon)
281
+ if pred_data is None:
282
+ print(f"❌ Predictions file not found for {model}/{dataset_term}/{horizon}")
 
283
  fig = go.Figure()
284
+ fig.update_layout(title="Predictions file not found for this horizon")
285
  return fig, ""
286
 
287
+ predictions_quantiles = pred_data["predictions_quantiles"] # (num_series, num_windows, 9, num_variates, prediction_length)
288
+ quantile_levels = pred_data["quantile_levels"] # [0.1, 0.2, ..., 0.9]
 
 
 
289
 
290
  # Load prediction scale factor from config.json (for float16 overflow prevention)
291
  model_config_path = os.path.join(results_root, model, dataset_term, horizon, "config.json")
 
296
  prediction_scale_factor = model_config.get("prediction_scale_factor", 1.0)
297
  if prediction_scale_factor != 1.0:
298
  print(f"📊 Applying inverse scale factor: {prediction_scale_factor}")
299
+ # Copy to avoid mutating the cached array
300
  predictions_quantiles = predictions_quantiles.astype(np.float32) * prediction_scale_factor
301
 
302
  # Use cached metadata for name-to-index mappings and Dataset object
 
920
 
921
  table_window = gr.DataFrame(elem_classes="custom-table", interactive=False)
922
 
923
+ # ── Shared input / output lists ────────────────────────────────────
924
+ _plot_in = [dataset_dropdown, series_dropdown, variate_dropdown,
925
+ window_dropdown, horizons, quantiles, model]
926
+ _plot_out = [ts_visualization, prediction_info]
927
+ _tbl_in = [dataset_dropdown, series_dropdown, variate_dropdown,
928
+ window_dropdown, horizons]
929
+ _tbl_out = table_window
930
+
931
+ # ── dataset changes ─────────────────────────────────────────────────
932
+ # Chain: update horizons → update dropdowns → refresh plot → refresh table.
933
+ # The chain already calls plot & table at the end, so we do NOT bind
934
+ # separate .change() on series/variate/window for this trigger path —
935
+ # otherwise updating the 3 dropdowns cascades into 3 extra duplicate
936
+ # plot_window_series calls (the #1 cause of slowness on HF Space).
937
  dataset_dropdown.change(
938
  fn=update_horizon_choices,
939
  inputs=[dataset_dropdown],
 
943
  inputs=[dataset_dropdown, horizons],
944
  outputs=[series_dropdown, variate_dropdown, window_dropdown],
945
  ).then(
946
+ fn=plot_window_series, inputs=_plot_in, outputs=_plot_out,
 
 
 
947
  ).then(
948
+ fn=get_window_leaderboard, inputs=_tbl_in, outputs=_tbl_out,
 
 
949
  )
950
 
951
+ # ── horizon changes ─────────────────────────────────────────────────
952
  horizons.change(
953
  fn=update_series_variate_and_window,
954
  inputs=[dataset_dropdown, horizons],
955
  outputs=[series_dropdown, variate_dropdown, window_dropdown],
956
  ).then(
957
+ fn=plot_window_series, inputs=_plot_in, outputs=_plot_out,
 
 
958
  ).then(
959
+ fn=get_window_leaderboard, inputs=_tbl_in, outputs=_tbl_out,
 
 
960
  )
961
 
962
+ # ── series / variate / window manual changes ────────────────────────
963
+ # Use a single .then() chain per dropdown so each user-initiated
964
+ # change fires plot + table exactly ONCE instead of 2 separate events.
965
  for comp in [series_dropdown, variate_dropdown, window_dropdown]:
966
  comp.change(
967
+ fn=plot_window_series, inputs=_plot_in, outputs=_plot_out,
968
+ ).then(
969
+ fn=get_window_leaderboard, inputs=_tbl_in, outputs=_tbl_out,
 
 
 
 
 
970
  )
971
 
972
+ # ── quantiles / model changes ───────────────────────────────────────
973
  for comp in [quantiles, model]:
974
  comp.change(
975
+ fn=plot_window_series, inputs=_plot_in, outputs=_plot_out,
 
 
976
  )
977
 
978
+ # ── initial page load ───────────────────────────────────────────────
979
  demo.load(
980
+ fn=plot_window_series, inputs=_plot_in, outputs=_plot_out,
 
 
981
  )
982
  demo.load(
983
+ fn=get_window_leaderboard, inputs=_tbl_in, outputs=_tbl_out,
 
 
984
  )
985
 
986
  # CSV Export