sumitrwk commited on
Commit
ab3f651
·
verified ·
1 Parent(s): 0c0fef5

Upload app.py

Browse files
Files changed (1) hide show
  1. src/app.py +181 -0
src/app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.graph_objects as go
4
+ import torch
5
+ from src.backend import ModelResearcher, ModelManager
6
+ from src.benchmarks import BenchmarkSuite
7
+ from src.model_diagnostics import ModelDiagnostics
8
+ import src.ablation_lab as ab # Import the new module
9
+
10
+ # --- Styling & Config ---
11
+ st.set_page_config(page_title="DeepBench: AI Researcher Workbench", layout="wide", page_icon="🧪")
12
+
13
+ st.markdown("""
14
+ <style>
15
+ .stApp { background-color: #0e1117; color: #FAFAFA; }
16
+ h1, h2, h3 { color: #00d4ff; }
17
+ .metric-card {
18
+ background-color: #262730; border: 1px solid #41424C;
19
+ border-radius: 8px; padding: 15px; margin-bottom: 10px;
20
+ text-align: center;
21
+ }
22
+ .metric-val { font-size: 24px; font-weight: bold; }
23
+ .stButton>button { width: 100%; border-radius: 5px; }
24
+ </style>
25
+ """, unsafe_allow_html=True)
26
+
27
+ # --- State Management ---
28
+ if 'manager' not in st.session_state:
29
+ st.session_state['manager'] = ModelManager(device="cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ # --- Sidebar ---
32
+ with st.sidebar:
33
+ st.title("🧪 DeepBench")
34
+ st.markdown("### Researcher Control Panel")
35
+ task = st.selectbox("Domain", ["Language", "Vision"])
36
+ arch = st.radio("Architecture", ["All", "Transformer", "RNN/RWKV"])
37
+ st.markdown("---")
38
+ st.info(f"Device: {st.session_state['manager'].device.upper()}")
39
+ st.caption("v4.0 Full Suite | Ablation Lab Active")
40
+
41
+ # --- Tabs ---
42
+ # We add "✂️ Ablation Lab" as the last tab
43
+ tab_names = ["🔍 Discovery", "⚔️ Battle Arena", "💬 Playground", "💾 Hardware Forecast", "🩻 Model X-Ray", "✂️ Ablation Lab"]
44
+ tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(tab_names)
45
+
46
+ # ================= TAB 1: DISCOVERY =================
47
+ with tab1:
48
+ researcher = ModelResearcher()
49
+ col_search, col_res = st.columns([1, 4])
50
+ with col_search:
51
+ if st.button("Fetch Models", use_container_width=True):
52
+ st.session_state['models'] = researcher.search_models(task_domain=task, architecture_type=arch)
53
+ with col_res:
54
+ if 'models' in st.session_state:
55
+ st.dataframe(st.session_state['models'], column_config={"downloads": st.column_config.ProgressColumn("Downloads", format="%d", min_value=0, max_value=1000000)}, use_container_width=True)
56
+
57
+ # ================= TAB 2: BATTLE ARENA =================
58
+ with tab2:
59
+ if 'models' in st.session_state:
60
+ all_ids = st.session_state['models']['model_id'].tolist()
61
+ select_options = ["None"] + all_ids
62
+ c1, c2 = st.columns(2)
63
+
64
+ with c1:
65
+ st.subheader("Champion (Model A)")
66
+ model_a = st.selectbox("Select Model A", select_options, index=1 if len(all_ids)>0 else 0)
67
+ quant_a = st.selectbox("Quantization A", ["None (FP16)", "8-bit (Int8)"], key="q_a")
68
+ with c2:
69
+ st.subheader("Challenger (Model B)")
70
+ model_b = st.selectbox("Select Model B", select_options, index=0)
71
+ quant_b = st.selectbox("Quantization B", ["None (FP16)", "8-bit (Int8)"], key="q_b")
72
+
73
+ bench_opts = ["Perplexity", "MMLU", "GSM8K", "ARC-C", "ARC-E", "HellaSwag", "PIQA"]
74
+ selected_bench = st.multiselect("Benchmarks", bench_opts, default=["Perplexity", "MMLU"])
75
+
76
+ if st.button("⚔️ Run Comparison"):
77
+ col_a, col_mid, col_b = st.columns([1, 0.1, 1])
78
+ results_a, results_b = {}, {}
79
+ q_map_a = "8-bit" if "8-bit" in quant_a else "None"
80
+ q_map_b = "8-bit" if "8-bit" in quant_b else "None"
81
+
82
+ with col_a:
83
+ if model_a != "None":
84
+ succ, msg = st.session_state['manager'].load_model(model_a, quantization=q_map_a)
85
+ if succ:
86
+ mod, tok = st.session_state['manager'].get_components(model_a, quantization=q_map_a)
87
+ suite = BenchmarkSuite(mod, tok, model_id=f"{model_a}_{q_map_a}")
88
+ for b in selected_bench:
89
+ res = suite.run_benchmark(b, simulation_mode=True)
90
+ results_a[b] = res
91
+ st.markdown(f"""<div class='metric-card'><div style='color:#aaa;'>{b}</div><div class='metric-val'>{res['score']:.2f}</div><div>{res['rating']}</div></div>""", unsafe_allow_html=True)
92
+ with col_b:
93
+ if model_b != "None":
94
+ succ, msg = st.session_state['manager'].load_model(model_b, quantization=q_map_b)
95
+ if succ:
96
+ mod, tok = st.session_state['manager'].get_components(model_b, quantization=q_map_b)
97
+ suite = BenchmarkSuite(mod, tok, model_id=f"{model_b}_{q_map_b}")
98
+ for b in selected_bench:
99
+ res = suite.run_benchmark(b, simulation_mode=True)
100
+ results_b[b] = res
101
+ st.markdown(f"""<div class='metric-card'><div style='color:#aaa;'>{b}</div><div class='metric-val'>{res['score']:.2f}</div><div>{res['rating']}</div></div>""", unsafe_allow_html=True)
102
+
103
+ if results_a and results_b:
104
+ st.markdown("### 🕸️ Comparison Map")
105
+ categories = list(results_a.keys())
106
+ vals_a = [r['score'] if r['unit'] == "%" else (100-r['score']) for r in results_a.values()]
107
+ vals_b = [r['score'] if r['unit'] == "%" else (100-r['score']) for r in results_b.values()]
108
+ fig = go.Figure()
109
+ fig.add_trace(go.Scatterpolar(r=vals_a, theta=categories, fill='toself', name=f"{model_a} ({q_map_a})", line_color="#00d4ff"))
110
+ fig.add_trace(go.Scatterpolar(r=vals_b, theta=categories, fill='toself', name=f"{model_b} ({q_map_b})", line_color="#ff0055"))
111
+ fig.update_layout(polar=dict(radialaxis=dict(visible=True, range=[0, 100])), paper_bgcolor="rgba(0,0,0,0)", font_color="white")
112
+ st.plotly_chart(fig, use_container_width=True)
113
+ else:
114
+ st.warning("Go to Discovery tab first.")
115
+
116
+ # ================= TAB 3: PLAYGROUND =================
117
+ with tab3:
118
+ st.subheader("💬 Generation Playground")
119
+ if 'models' in st.session_state:
120
+ all_ids = st.session_state['models']['model_id'].tolist()
121
+ select_options_play = ["None"] + all_ids
122
+ pc1, pc2 = st.columns(2)
123
+ with pc1:
124
+ pm_a = st.selectbox("Generator A", select_options_play, index=1 if len(all_ids)>0 else 0, key="pm_a")
125
+ pq_a = st.selectbox("Quant A", ["None (FP16)", "8-bit (Int8)"], key="pq_a")
126
+ with pc2:
127
+ pm_b = st.selectbox("Generator B", select_options_play, index=0, key="pm_b")
128
+ pq_b = st.selectbox("Quant B", ["None (FP16)", "8-bit (Int8)"], key="pq_b")
129
+ user_prompt = st.text_area("Prompt", value="Explain quantum computing like I'm 5.")
130
+ if st.button("Generate Text"):
131
+ c1, c2 = st.columns(2)
132
+ pq_map_a = "8-bit" if "8-bit" in pq_a else "None"
133
+ pq_map_b = "8-bit" if "8-bit" in pq_b else "None"
134
+ with c1:
135
+ if pm_a != "None":
136
+ succ, msg = st.session_state['manager'].load_model(pm_a, quantization=pq_map_a)
137
+ if succ:
138
+ out = st.session_state['manager'].generate_text(pm_a, pq_map_a, user_prompt)
139
+ st.info(out)
140
+ with c2:
141
+ if pm_b != "None":
142
+ succ, msg = st.session_state['manager'].load_model(pm_b, quantization=pq_map_b)
143
+ if succ:
144
+ out = st.session_state['manager'].generate_text(pm_b, pq_map_b, user_prompt)
145
+ st.success(out)
146
+ else: st.warning("Please fetch models in Tab 1 first.")
147
+
148
+ # ================= TAB 4: HARDWARE FORECAST =================
149
+ with tab4:
150
+ st.header("💾 Hardware Forecast")
151
+ col1, col2 = st.columns(2)
152
+ with col1:
153
+ vram_input = st.text_input("Enter Model Size (e.g., 7B, 13B)", value="7B")
154
+ if st.button("Calculate VRAM"):
155
+ res = ModelDiagnostics.estimate_vram(vram_input)
156
+ if res: st.session_state['vram_res'] = res
157
+ else: st.error("Invalid format.")
158
+ with col2:
159
+ if 'vram_res' in st.session_state:
160
+ res = st.session_state['vram_res']
161
+ st.success(f"**Results for {res['params_in_billions']}B Params**")
162
+ st.markdown(f"- **Training (FP32):** `{res['FP32 (Training/Full)']}`\n- **Inference (FP16):** `{res['FP16 (Inference)']}`\n- **Quantized (Int8):** `{res['INT8 (Quantized)']}`")
163
+
164
+ # ================= TAB 5: MODEL X-RAY =================
165
+ with tab5:
166
+ st.header("🔍 Model X-Ray")
167
+ if 'models' in st.session_state:
168
+ all_ids = st.session_state['models']['model_id'].tolist()
169
+ xray_model = st.selectbox("Select Model to Inspect", all_ids)
170
+ if st.button("Scan Layers"):
171
+ succ, msg = st.session_state['manager'].load_model(xray_model, quantization="None")
172
+ if succ:
173
+ mod, _ = st.session_state['manager'].get_components(xray_model, quantization="None")
174
+ structure = ModelDiagnostics.get_layer_structure(mod)
175
+ st.text_area("Raw PyTorch Structure", value=structure, height=400)
176
+ else: st.warning("Go to Discovery tab first.")
177
+
178
+ # ================= TAB 6: ABLATION LAB (NEW) =================
179
+ with tab6:
180
+ # We delegate the entire rendering to the specialized module
181
+ ab.render_ablation_dashboard()