# Udayshankar Ravikumar
# Added data format notice.
# 77546b2 unverified
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import os
from huggingface_hub import snapshot_download
import tempfile
# -------------------------------------------------
# Configuration
# -------------------------------------------------
# Hub repository the model pickles are downloaded from, and the local
# directory they are loaded from.
HF_REPO_ID = "uralstech/AIDE-Chip-Surrogates"
MODEL_DIR = "surrogate_models_v2"

# Alternative workload spellings mapped to the canonical workload name.
WORKLOAD_ALIAS = dict(matrix="matrix_mul", matmul="matrix_mul")

# Quantities predicted for every input row (one model per workload/target).
TARGETS = ["ipc", "l2_miss_rate"]

# Model input features, in the exact order they are passed to the models.
FEATURE_COLS = (
    # log2 of the raw cache sizes
    ["l1d_size_log2", "l1i_size_log2", "l2_size_log2"]
    # log2 of the raw associativities
    + ["l1d_assoc_log2", "l1i_assoc_log2", "l2_assoc_log2"]
    # features derived from differences of the log2 columns above
    + ["l2_l1d_ratio_log2", "l1d_sets_log2", "l2_sets_log2"]
)

# Columns every input CSV has to provide.
REQUIRED_COLS = (
    ["workload"]
    + ["l1d_size", "l1i_size", "l2_size"]
    + ["l1d_assoc", "l1i_assoc", "l2_assoc"]
)

# -------------------------------------------------
# Global model cache
# -------------------------------------------------
# Keyed by (workload, target); values are (model, log_target) tuples.
MODEL_CACHE = dict()
# -------------------------------------------------
# Model Download
# -------------------------------------------------
def ensure_models():
    """Fetch the model pickles from the Hub unless MODEL_DIR already exists."""
    if os.path.exists(MODEL_DIR):
        # Directory already present; skip the download.
        return
    # Only ``*.pkl`` files are requested; they are placed relative to the
    # current working directory.
    snapshot_download(repo_id=HF_REPO_ID, local_dir=".", allow_patterns="*.pkl")
# -------------------------------------------------
# Utilities
# -------------------------------------------------
def resolve_workload(workload: str) -> str:
    """Translate a workload alias to its canonical name.

    Names that are not listed in WORKLOAD_ALIAS are returned unchanged.
    """
    if workload in WORKLOAD_ALIAS:
        return WORKLOAD_ALIAS[workload]
    return workload
def load_model(workload: str, target: str):
    """Look up the cached entry for a (workload, target) pair.

    Returns:
        The value stored in MODEL_CACHE for the pair (preload_models stores
        ``(model, log_target)`` tuples).

    Raises:
        RuntimeError: If the pair has no entry in MODEL_CACHE.
    """
    cache_key = (workload, target)
    try:
        entry = MODEL_CACHE[cache_key]
    except KeyError:
        raise RuntimeError(f"Model not preloaded: {workload}, {target}")
    return entry
def physical_sanity_check(ipc, miss_rate):
    """Return warning messages for predictions outside their plausible range.

    Args:
        ipc: Predicted IPC; accepted when within [0, 3.5].
        miss_rate: Predicted L2 miss rate; accepted when within [0, 1].

    Returns:
        A list of warning strings (IPC first, then miss rate); empty when both
        values are in range.
    """
    out = []
    # Both tests are written as negated "in range" checks: every ordering
    # comparison against NaN is False, so the form ``x < lo or x > hi`` would
    # silently accept a NaN prediction instead of reporting it.
    if not (0 <= ipc <= 3.5):
        out.append(f"IPC={ipc:.3f} out of physical range")
    if not (0 <= miss_rate <= 1):
        out.append(f"L2 miss rate={miss_rate:.3f} out of [0,1]")
    return out
# -------------------------------------------------
# Preload models (runs once at app start)
# -------------------------------------------------
def preload_models():
    """Make sure the model files exist and load them all into MODEL_CACHE.

    This function is attached to the Gradio load event in this file, so it
    may be invoked more than once per process. Entries that are already in
    MODEL_CACHE are therefore skipped, which makes repeated calls cheap
    instead of re-reading every pickle from disk.

    Returns:
        The string ``"ready"``.
    """
    ensure_models()
    # A tuple (rather than a set) keeps the load order deterministic.
    workloads = (
        "crc32",
        "dijkstra",
        "fft",
        "matrix_mul",
        "qsort",
        "sha",
    )
    for workload in workloads:
        for target in TARGETS:
            key = (workload, target)
            if key in MODEL_CACHE:
                # Loaded by an earlier call; nothing to do.
                continue
            path = os.path.join(MODEL_DIR, f"model_{workload}_{target}.pkl")
            # NOTE(security): joblib.load unpickles the file, which can run
            # arbitrary code. Only point HF_REPO_ID at a trusted repository.
            payload = joblib.load(path)
            MODEL_CACHE[key] = (payload["model"], payload["log_target"])
    return "ready"
# -------------------------------------------------
# Inference Core
# -------------------------------------------------
def run_inference(df: pd.DataFrame) -> pd.DataFrame:
    """Predict IPC and L2 miss rate for every cache configuration in *df*.

    The frame is modified in place: the log2 feature columns, ``pred_ipc``,
    ``pred_l2_miss_rate`` and ``warnings`` are added to it, and the same
    object is returned.

    Args:
        df: One row per configuration. Must contain every column listed in
            REQUIRED_COLS.

    Returns:
        The input DataFrame with feature, prediction and warning columns.

    Raises:
        ValueError: If a required column is missing, or a size/associativity
            column holds a non-numeric, non-finite or non-positive value
            (log2 would be undefined for it).
        RuntimeError: If a row names a workload without a preloaded model
            (raised by load_model).
    """
    missing = set(REQUIRED_COLS) - set(df.columns)
    if missing:
        raise ValueError(f"Missing required columns: {missing}")

    # Every required column except the workload label is a numeric parameter.
    numeric_cols = [col for col in REQUIRED_COLS if col != "workload"]

    # Validate all columns before touching df so a bad input does not leave
    # the caller's frame half-modified.
    log2_cols = {}
    for col in numeric_cols:
        values = pd.to_numeric(df[col], errors="coerce").astype(float)
        # NaN (including coerced non-numeric text) fails ``> 0``; the
        # isfinite test additionally rejects +inf.
        if not ((values > 0) & np.isfinite(values)).all():
            raise ValueError(
                f"Column '{col}' must contain only positive, finite numbers"
            )
        log2_cols[f"{col}_log2"] = np.log2(values)
    for name, series in log2_cols.items():
        df[name] = series

    df["l2_l1d_ratio_log2"] = df["l2_size_log2"] - df["l1d_size_log2"]
    df["l1d_sets_log2"] = df["l1d_size_log2"] - df["l1d_assoc_log2"]
    df["l2_sets_log2"] = df["l2_size_log2"] - df["l2_assoc_log2"]

    # Build one float matrix up front. Slicing rows out of iterrows() would
    # yield object-dtype arrays because of the string ``workload`` column.
    features = df[FEATURE_COLS].to_numpy(dtype=float)
    predictions = {target: np.full(len(df), np.nan) for target in TARGETS}

    # Group row positions by canonical workload so each model is called once
    # per workload instead of once per row. Positions (not index labels) are
    # used so frames with duplicate index labels are handled correctly.
    rows_by_workload = {}
    for pos, name in enumerate(df["workload"]):
        rows_by_workload.setdefault(resolve_workload(name), []).append(pos)

    for workload, positions in rows_by_workload.items():
        X = features[positions]
        for target in TARGETS:
            model, is_log = load_model(workload, target)
            pred_raw = np.asarray(model.predict(X), dtype=float).ravel()
            # Models flagged as log-target are inverted with expm1.
            pred = np.expm1(pred_raw) if is_log else pred_raw
            if target == "l2_miss_rate":
                pred = np.clip(pred, 0, 1)
            predictions[target][positions] = pred

    df["pred_ipc"] = predictions["ipc"]
    df["pred_l2_miss_rate"] = predictions["l2_miss_rate"]
    df["warnings"] = [
        "; ".join(physical_sanity_check(ipc, miss_rate))
        for ipc, miss_rate in zip(
            predictions["ipc"], predictions["l2_miss_rate"]
        )
    ]
    return df
# -------------------------------------------------
# Gradio Wrapper
# -------------------------------------------------
def infer_from_csv(file):
    """Gradio callback: run inference on an uploaded CSV file.

    Args:
        file: The value delivered by the ``gr.File`` input. Either an object
            exposing the file path as ``.name`` or a plain path string;
            ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(preview, output_path, summary)``: the first 20 rows of the
        result frame, the path of a CSV file holding the full result, and a
        one-line summary of how many rows triggered sanity warnings.

    Raises:
        gr.Error: If no file was provided.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file before running inference.")

    # Accept both file-like wrappers (path in ``.name``) and bare paths.
    csv_path = getattr(file, "name", file)
    df = pd.read_csv(csv_path)
    out_df = run_inference(df)

    # Reserve a unique path and close the handle immediately: leaving it open
    # leaks a file descriptor per request and prevents reopening the file for
    # writing on Windows. delete=False keeps the file for Gradio to serve.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        out_path = tmp.name
    out_df.to_csv(out_path, index=False)

    flagged = int((out_df["warnings"] != "").sum())
    warning_text = (
        f"{flagged} rows triggered sanity warnings."
        if flagged
        else "No sanity warnings detected."
    )
    return out_df.head(20), out_path, warning_text
# -------------------------------------------------
# UI
# -------------------------------------------------
# Intro / format description rendered at the top of the main app column.
# The text is kept flush-left on purpose: leading spaces would change how the
# Markdown is rendered.
_APP_INTRO_MD = """
# AIDE Chip Surrogate Inference
Upload a CSV describing cache configurations and workloads.
The app will run surrogate models to predict:
- IPC
- L2 Miss Rate
## Expected CSV Format
The input CSV **must** contain the following columns:
**Required columns**
- `workload` — one of: `crc32`, `dijkstra`, `fft`, `matrix_mul`, `qsort`, `sha`
- `l1d_size` — L1 data cache size (kibibytes, power of two)
- `l1i_size` — L1 instruction cache size (kibibytes, power of two)
- `l2_size` — L2 cache size (kibibytes, power of two)
- `l1d_assoc` — L1D associativity (power of two)
- `l1i_assoc` — L1I associativity (power of two)
- `l2_assoc` — L2 associativity (power of two)
**Notes**
- All size and associativity values must be positive and powers of two.
- One row corresponds to one cache configuration.
**Example**
```
workload,l1d_size,l1i_size,l2_size,l1d_assoc,l1i_assoc,l2_assoc
matrix_mul,128,64,1024,16,8,16
fft,128,64,2048,16,8,32
```
"""


def _reveal_app():
    """Return the updates that hide the loading notice and show the app."""
    hide_loading = gr.update(visible=False)
    show_app = gr.update(visible=True)
    return (hide_loading, show_app)


with gr.Blocks(title="AIDE Chip Surrogate Inference") as demo:
    # Loading notice: the only thing visible until the models are preloaded.
    loading_md = gr.Markdown(
        "## Downloading surrogate models…\n\nThis may take a while.",
        visible=True,
    )

    # Main application, hidden until the startup hook below has finished.
    with gr.Column(visible=False) as app_ui:
        gr.Markdown(_APP_INTRO_MD)

        # Input widgets
        csv_input = gr.File(label="Input CSV", file_types=[".csv"])
        run_btn = gr.Button("Run Inference")

        # Output widgets
        preview = gr.Dataframe(label="Preview (first 20 rows)")
        output_csv = gr.File(label="Download Full Output CSV")
        warnings_box = gr.Textbox(label="Sanity Check Summary")

        run_btn.click(
            infer_from_csv,
            inputs=csv_input,
            outputs=[preview, output_csv, warnings_box],
        )

    # Startup hook: preload the models, then swap the loading notice for the
    # main application column.
    startup = demo.load(preload_models, inputs=None, outputs=None)
    startup.then(_reveal_app, outputs=[loading_md, app_ui])


if __name__ == "__main__":
    demo.launch()