Udayshankar Ravikumar committed on
Commit
e131d6a
·
unverified ·
1 Parent(s): 43586c3

Preloading

Browse files
Files changed (1) hide show
  1. app.py +62 -21
app.py CHANGED
@@ -41,6 +41,11 @@ REQUIRED_COLS = [
41
  "l2_assoc",
42
  ]
43
 
 
 
 
 
 
44
  # -------------------------------------------------
45
  # Model Download
46
  # -------------------------------------------------
@@ -48,11 +53,9 @@ def ensure_models():
48
  if not os.path.exists(MODEL_DIR):
49
  snapshot_download(
50
  repo_id=HF_REPO_ID,
51
- local_dir='.',
52
- allow_patterns="*.pkl")
53
-
54
- print(f"Model dir: {MODEL_DIR}")
55
- print(os.listdir(MODEL_DIR))
56
 
57
  # -------------------------------------------------
58
  # Utilities
@@ -61,19 +64,39 @@ def resolve_workload(workload: str) -> str:
61
  return WORKLOAD_ALIAS.get(workload, workload)
62
 
63
  def load_model(workload: str, target: str):
64
- model_path = os.path.join(MODEL_DIR, f"model_{workload}_{target}.pkl")
65
- if not os.path.exists(model_path):
66
- raise FileNotFoundError(f"Model not found: {model_path}")
67
- payload = joblib.load(model_path)
68
- return payload["model"], payload["log_target"]
69
 
70
  def physical_sanity_check(ipc, miss_rate):
71
- warnings = []
72
  if ipc < 0 or ipc > 3.5:
73
- warnings.append(f"IPC={ipc:.3f} out of physical range")
74
  if miss_rate < 0 or miss_rate > 1:
75
- warnings.append(f"L2 miss rate={miss_rate:.3f} out of [0,1]")
76
- return warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # -------------------------------------------------
79
  # Inference Core
@@ -85,8 +108,12 @@ def run_inference(df: pd.DataFrame) -> pd.DataFrame:
85
 
86
  # Feature engineering
87
  for col in [
88
- "l1d_size", "l1i_size", "l2_size",
89
- "l1d_assoc", "l1i_assoc", "l2_assoc",
 
 
 
 
90
  ]:
91
  df[f"{col}_log2"] = np.log2(df[col])
92
 
@@ -127,8 +154,6 @@ def run_inference(df: pd.DataFrame) -> pd.DataFrame:
127
  # Gradio Wrapper
128
  # -------------------------------------------------
129
  def infer_from_csv(file):
130
- ensure_models()
131
-
132
  df = pd.read_csv(file.name)
133
  out_df = run_inference(df)
134
 
@@ -148,6 +173,9 @@ def infer_from_csv(file):
148
  # UI
149
  # -------------------------------------------------
150
  with gr.Blocks(title="AIDE Chip Surrogate Inference") as demo:
 
 
 
151
  gr.Markdown(
152
  """
153
  # AIDE Chip Surrogate Inference
@@ -160,16 +188,29 @@ with gr.Blocks(title="AIDE Chip Surrogate Inference") as demo:
160
  )
161
 
162
  csv_input = gr.File(label="Input CSV", file_types=[".csv"])
163
- run_btn = gr.Button("Run Inference")
164
 
165
  preview = gr.Dataframe(label="Preview (first 20 rows)")
166
  output_csv = gr.File(label="Download Full Output CSV")
167
- warnings = gr.Textbox(label="Sanity Check Summary")
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  run_btn.click(
170
  infer_from_csv,
171
  inputs=csv_input,
172
- outputs=[preview, output_csv, warnings],
173
  )
174
 
175
  if __name__ == "__main__":
 
41
  "l2_assoc",
42
  ]
43
 
44
+ # -------------------------------------------------
45
+ # Global model cache
46
+ # -------------------------------------------------
47
+ MODEL_CACHE = {}
48
+
49
  # -------------------------------------------------
50
  # Model Download
51
  # -------------------------------------------------
 
53
  if not os.path.exists(MODEL_DIR):
54
  snapshot_download(
55
  repo_id=HF_REPO_ID,
56
+ local_dir=".",
57
+ allow_patterns="*.pkl",
58
+ )
 
 
59
 
60
  # -------------------------------------------------
61
  # Utilities
 
64
  return WORKLOAD_ALIAS.get(workload, workload)
65
 
66
  def load_model(workload: str, target: str):
67
+ try:
68
+ return MODEL_CACHE[(workload, target)]
69
+ except KeyError:
70
+ raise RuntimeError(f"Model not preloaded: {workload}, {target}")
 
71
 
72
  def physical_sanity_check(ipc, miss_rate):
73
+ warnings_out = []
74
  if ipc < 0 or ipc > 3.5:
75
+ warnings_out.append(f"IPC={ipc:.3f} out of physical range")
76
  if miss_rate < 0 or miss_rate > 1:
77
+ warnings_out.append(f"L2 miss rate={miss_rate:.3f} out of [0,1]")
78
+ return warnings_out
79
+
80
+ # -------------------------------------------------
81
+ # Preload all models at startup
82
+ # -------------------------------------------------
83
+ def preload_models():
84
+ ensure_models()
85
+
86
+ workloads = set(WORKLOAD_ALIAS.values()) | {"matrix_mul"}
87
+
88
+ for workload in workloads:
89
+ for target in TARGETS:
90
+ model_path = os.path.join(
91
+ MODEL_DIR, f"model_{workload}_{target}.pkl"
92
+ )
93
+ payload = joblib.load(model_path)
94
+ MODEL_CACHE[(workload, target)] = (
95
+ payload["model"],
96
+ payload["log_target"],
97
+ )
98
+
99
+ return "✅ Models loaded successfully."
100
 
101
  # -------------------------------------------------
102
  # Inference Core
 
108
 
109
  # Feature engineering
110
  for col in [
111
+ "l1d_size",
112
+ "l1i_size",
113
+ "l2_size",
114
+ "l1d_assoc",
115
+ "l1i_assoc",
116
+ "l2_assoc",
117
  ]:
118
  df[f"{col}_log2"] = np.log2(df[col])
119
 
 
154
  # Gradio Wrapper
155
  # -------------------------------------------------
156
  def infer_from_csv(file):
 
 
157
  df = pd.read_csv(file.name)
158
  out_df = run_inference(df)
159
 
 
173
  # UI
174
  # -------------------------------------------------
175
  with gr.Blocks(title="AIDE Chip Surrogate Inference") as demo:
176
+ loading_md = gr.Markdown("## ⏳ Loading surrogate models…", visible=True)
177
+ ready_md = gr.Markdown("## ✅ Models ready", visible=False)
178
+
179
  gr.Markdown(
180
  """
181
  # AIDE Chip Surrogate Inference
 
188
  )
189
 
190
  csv_input = gr.File(label="Input CSV", file_types=[".csv"])
191
+ run_btn = gr.Button("Run Inference", interactive=False)
192
 
193
  preview = gr.Dataframe(label="Preview (first 20 rows)")
194
  output_csv = gr.File(label="Download Full Output CSV")
195
+ warnings_box = gr.Textbox(label="Sanity Check Summary")
196
+
197
+ demo.load(
198
+ preload_models,
199
+ inputs=None,
200
+ outputs=ready_md,
201
+ ).then(
202
+ lambda: (
203
+ gr.update(visible=False),
204
+ gr.update(visible=True),
205
+ gr.update(interactive=True),
206
+ ),
207
+ outputs=[loading_md, ready_md, run_btn],
208
+ )
209
 
210
  run_btn.click(
211
  infer_from_csv,
212
  inputs=csv_input,
213
+ outputs=[preview, output_csv, warnings_box],
214
  )
215
 
216
  if __name__ == "__main__":