sumitrwk committed on
Commit
8ae248a
·
verified ·
1 Parent(s): a0672f2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +180 -180
src/streamlit_app.py CHANGED
@@ -1,181 +1,181 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import plotly.graph_objects as go
4
- import torch
5
- from src.backend import ModelResearcher, ModelManager
6
- from src.benchmarks import BenchmarkSuite
7
- from src.model_diagnostics import ModelDiagnostics
8
- import src.ablation_lab as ab # Import the new module
9
-
10
- # --- Styling & Config ---
11
- st.set_page_config(page_title="DeepBench: AI Researcher Workbench", layout="wide", page_icon="πŸ§ͺ")
12
-
13
- st.markdown("""
14
- <style>
15
- .stApp { background-color: #0e1117; color: #FAFAFA; }
16
- h1, h2, h3 { color: #00d4ff; }
17
- .metric-card {
18
- background-color: #262730; border: 1px solid #41424C;
19
- border-radius: 8px; padding: 15px; margin-bottom: 10px;
20
- text-align: center;
21
- }
22
- .metric-val { font-size: 24px; font-weight: bold; }
23
- .stButton>button { width: 100%; border-radius: 5px; }
24
- </style>
25
- """, unsafe_allow_html=True)
26
-
27
- # --- State Management ---
28
- if 'manager' not in st.session_state:
29
- st.session_state['manager'] = ModelManager(device="cuda" if torch.cuda.is_available() else "cpu")
30
-
31
- # --- Sidebar ---
32
- with st.sidebar:
33
- st.title("πŸ§ͺ DeepBench")
34
- st.markdown("### Researcher Control Panel")
35
- task = st.selectbox("Domain", ["Language", "Vision"])
36
- arch = st.radio("Architecture", ["All", "Transformer", "RNN/RWKV"])
37
- st.markdown("---")
38
- st.info(f"Device: {st.session_state['manager'].device.upper()}")
39
- st.caption("v4.0 Full Suite | Ablation Lab Active")
40
-
41
- # --- Tabs ---
42
- # We add "βœ‚οΈ Ablation Lab" as the last tab
43
- tab_names = ["πŸ” Discovery", "βš”οΈ Battle Arena", "πŸ’¬ Playground", "πŸ’Ύ Hardware Forecast", "🩻 Model X-Ray", "βœ‚οΈ Ablation Lab"]
44
- tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(tab_names)
45
-
46
- # ================= TAB 1: DISCOVERY =================
47
- with tab1:
48
- researcher = ModelResearcher()
49
- col_search, col_res = st.columns([1, 4])
50
- with col_search:
51
- if st.button("Fetch Models", use_container_width=True):
52
- st.session_state['models'] = researcher.search_models(task_domain=task, architecture_type=arch)
53
- with col_res:
54
- if 'models' in st.session_state:
55
- st.dataframe(st.session_state['models'], column_config={"downloads": st.column_config.ProgressColumn("Downloads", format="%d", min_value=0, max_value=1000000)}, use_container_width=True)
56
-
57
- # ================= TAB 2: BATTLE ARENA =================
58
- with tab2:
59
- if 'models' in st.session_state:
60
- all_ids = st.session_state['models']['model_id'].tolist()
61
- select_options = ["None"] + all_ids
62
- c1, c2 = st.columns(2)
63
-
64
- with c1:
65
- st.subheader("Champion (Model A)")
66
- model_a = st.selectbox("Select Model A", select_options, index=1 if len(all_ids)>0 else 0)
67
- quant_a = st.selectbox("Quantization A", ["None (FP16)", "8-bit (Int8)"], key="q_a")
68
- with c2:
69
- st.subheader("Challenger (Model B)")
70
- model_b = st.selectbox("Select Model B", select_options, index=0)
71
- quant_b = st.selectbox("Quantization B", ["None (FP16)", "8-bit (Int8)"], key="q_b")
72
-
73
- bench_opts = ["Perplexity", "MMLU", "GSM8K", "ARC-C", "ARC-E", "HellaSwag", "PIQA"]
74
- selected_bench = st.multiselect("Benchmarks", bench_opts, default=["Perplexity", "MMLU"])
75
-
76
- if st.button("βš”οΈ Run Comparison"):
77
- col_a, col_mid, col_b = st.columns([1, 0.1, 1])
78
- results_a, results_b = {}, {}
79
- q_map_a = "8-bit" if "8-bit" in quant_a else "None"
80
- q_map_b = "8-bit" if "8-bit" in quant_b else "None"
81
-
82
- with col_a:
83
- if model_a != "None":
84
- succ, msg = st.session_state['manager'].load_model(model_a, quantization=q_map_a)
85
- if succ:
86
- mod, tok = st.session_state['manager'].get_components(model_a, quantization=q_map_a)
87
- suite = BenchmarkSuite(mod, tok, model_id=f"{model_a}_{q_map_a}")
88
- for b in selected_bench:
89
- res = suite.run_benchmark(b, simulation_mode=True)
90
- results_a[b] = res
91
- st.markdown(f"""<div class='metric-card'><div style='color:#aaa;'>{b}</div><div class='metric-val'>{res['score']:.2f}</div><div>{res['rating']}</div></div>""", unsafe_allow_html=True)
92
- with col_b:
93
- if model_b != "None":
94
- succ, msg = st.session_state['manager'].load_model(model_b, quantization=q_map_b)
95
- if succ:
96
- mod, tok = st.session_state['manager'].get_components(model_b, quantization=q_map_b)
97
- suite = BenchmarkSuite(mod, tok, model_id=f"{model_b}_{q_map_b}")
98
- for b in selected_bench:
99
- res = suite.run_benchmark(b, simulation_mode=True)
100
- results_b[b] = res
101
- st.markdown(f"""<div class='metric-card'><div style='color:#aaa;'>{b}</div><div class='metric-val'>{res['score']:.2f}</div><div>{res['rating']}</div></div>""", unsafe_allow_html=True)
102
-
103
- if results_a and results_b:
104
- st.markdown("### πŸ•ΈοΈ Comparison Map")
105
- categories = list(results_a.keys())
106
- vals_a = [r['score'] if r['unit'] == "%" else (100-r['score']) for r in results_a.values()]
107
- vals_b = [r['score'] if r['unit'] == "%" else (100-r['score']) for r in results_b.values()]
108
- fig = go.Figure()
109
- fig.add_trace(go.Scatterpolar(r=vals_a, theta=categories, fill='toself', name=f"{model_a} ({q_map_a})", line_color="#00d4ff"))
110
- fig.add_trace(go.Scatterpolar(r=vals_b, theta=categories, fill='toself', name=f"{model_b} ({q_map_b})", line_color="#ff0055"))
111
- fig.update_layout(polar=dict(radialaxis=dict(visible=True, range=[0, 100])), paper_bgcolor="rgba(0,0,0,0)", font_color="white")
112
- st.plotly_chart(fig, use_container_width=True)
113
- else:
114
- st.warning("Go to Discovery tab first.")
115
-
116
- # ================= TAB 3: PLAYGROUND =================
117
- with tab3:
118
- st.subheader("πŸ’¬ Generation Playground")
119
- if 'models' in st.session_state:
120
- all_ids = st.session_state['models']['model_id'].tolist()
121
- select_options_play = ["None"] + all_ids
122
- pc1, pc2 = st.columns(2)
123
- with pc1:
124
- pm_a = st.selectbox("Generator A", select_options_play, index=1 if len(all_ids)>0 else 0, key="pm_a")
125
- pq_a = st.selectbox("Quant A", ["None (FP16)", "8-bit (Int8)"], key="pq_a")
126
- with pc2:
127
- pm_b = st.selectbox("Generator B", select_options_play, index=0, key="pm_b")
128
- pq_b = st.selectbox("Quant B", ["None (FP16)", "8-bit (Int8)"], key="pq_b")
129
- user_prompt = st.text_area("Prompt", value="Explain quantum computing like I'm 5.")
130
- if st.button("Generate Text"):
131
- c1, c2 = st.columns(2)
132
- pq_map_a = "8-bit" if "8-bit" in pq_a else "None"
133
- pq_map_b = "8-bit" if "8-bit" in pq_b else "None"
134
- with c1:
135
- if pm_a != "None":
136
- succ, msg = st.session_state['manager'].load_model(pm_a, quantization=pq_map_a)
137
- if succ:
138
- out = st.session_state['manager'].generate_text(pm_a, pq_map_a, user_prompt)
139
- st.info(out)
140
- with c2:
141
- if pm_b != "None":
142
- succ, msg = st.session_state['manager'].load_model(pm_b, quantization=pq_map_b)
143
- if succ:
144
- out = st.session_state['manager'].generate_text(pm_b, pq_map_b, user_prompt)
145
- st.success(out)
146
- else: st.warning("Please fetch models in Tab 1 first.")
147
-
148
- # ================= TAB 4: HARDWARE FORECAST =================
149
- with tab4:
150
- st.header("πŸ’Ύ Hardware Forecast")
151
- col1, col2 = st.columns(2)
152
- with col1:
153
- vram_input = st.text_input("Enter Model Size (e.g., 7B, 13B)", value="7B")
154
- if st.button("Calculate VRAM"):
155
- res = ModelDiagnostics.estimate_vram(vram_input)
156
- if res: st.session_state['vram_res'] = res
157
- else: st.error("Invalid format.")
158
- with col2:
159
- if 'vram_res' in st.session_state:
160
- res = st.session_state['vram_res']
161
- st.success(f"**Results for {res['params_in_billions']}B Params**")
162
- st.markdown(f"- **Training (FP32):** `{res['FP32 (Training/Full)']}`\n- **Inference (FP16):** `{res['FP16 (Inference)']}`\n- **Quantized (Int8):** `{res['INT8 (Quantized)']}`")
163
-
164
- # ================= TAB 5: MODEL X-RAY =================
165
- with tab5:
166
- st.header("πŸ” Model X-Ray")
167
- if 'models' in st.session_state:
168
- all_ids = st.session_state['models']['model_id'].tolist()
169
- xray_model = st.selectbox("Select Model to Inspect", all_ids)
170
- if st.button("Scan Layers"):
171
- succ, msg = st.session_state['manager'].load_model(xray_model, quantization="None")
172
- if succ:
173
- mod, _ = st.session_state['manager'].get_components(xray_model, quantization="None")
174
- structure = ModelDiagnostics.get_layer_structure(mod)
175
- st.text_area("Raw PyTorch Structure", value=structure, height=400)
176
- else: st.warning("Go to Discovery tab first.")
177
-
178
- # ================= TAB 6: ABLATION LAB (NEW) =================
179
- with tab6:
180
- # We delegate the entire rendering to the specialized module
181
  ab.render_ablation_dashboard()
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.graph_objects as go
4
+ import torch
5
+ from backend import ModelResearcher, ModelManager
6
+ from benchmarks import BenchmarkSuite
7
+ from model_diagnostics import ModelDiagnostics
8
+ import ablation_lab as ab # Import the new module
9
+
10
# --- Styling & Config ---
st.set_page_config(page_title="DeepBench: AI Researcher Workbench", layout="wide", page_icon="🧪")

# Dark theme plus the reusable "metric-card" look used by the Battle Arena tab.
# Injected once per rerun via raw HTML (hence unsafe_allow_html).
_CUSTOM_CSS = """
<style>
.stApp { background-color: #0e1117; color: #FAFAFA; }
h1, h2, h3 { color: #00d4ff; }
.metric-card {
    background-color: #262730; border: 1px solid #41424C;
    border-radius: 8px; padding: 15px; margin-bottom: 10px;
    text-align: center;
}
.metric-val { font-size: 24px; font-weight: bold; }
.stButton>button { width: 100%; border-radius: 5px; }
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
26
+
27
# --- State Management ---
# A single ModelManager is shared across reruns via session state; it is
# created only on the first run of the session.
if 'manager' not in st.session_state:
    backend_device = "cuda" if torch.cuda.is_available() else "cpu"
    st.session_state['manager'] = ModelManager(device=backend_device)
30
+
31
# --- Sidebar ---
# Global filters (`task`, `arch`) selected here are read by the Discovery tab.
with st.sidebar:
    st.title("🧪 DeepBench")
    st.markdown("### Researcher Control Panel")
    task = st.selectbox("Domain", ["Language", "Vision"])
    arch = st.radio("Architecture", ["All", "Transformer", "RNN/RWKV"])
    st.markdown("---")
    # Surface which accelerator the shared ModelManager was bound to.
    active_device = st.session_state['manager'].device.upper()
    st.info(f"Device: {active_device}")
    st.caption("v4.0 Full Suite | Ablation Lab Active")
40
+
41
# --- Tabs ---
# "✂️ Ablation Lab" is the newest pane; it stays last so earlier tab
# positions remain stable for users.
tab_names = [
    "🔍 Discovery",
    "⚔️ Battle Arena",
    "💬 Playground",
    "💾 Hardware Forecast",
    "🩻 Model X-Ray",
    "✂️ Ablation Lab",
]
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(tab_names)
45
+
46
# ================= TAB 1: DISCOVERY =================
with tab1:
    researcher = ModelResearcher()
    col_search, col_res = st.columns([1, 4])
    with col_search:
        # Query models using the sidebar filters; the result table is cached
        # in session state so the other tabs can reuse it.
        if st.button("Fetch Models", use_container_width=True):
            st.session_state['models'] = researcher.search_models(task_domain=task, architecture_type=arch)
    with col_res:
        if 'models' in st.session_state:
            download_bar = st.column_config.ProgressColumn("Downloads", format="%d", min_value=0, max_value=1000000)
            st.dataframe(st.session_state['models'], column_config={"downloads": download_bar}, use_container_width=True)
56
+
57
# ================= TAB 2: BATTLE ARENA =================
# Side-by-side benchmark comparison of two models (optionally quantized),
# finished with a radar chart overlaying both score profiles.
with tab2:
    if 'models' in st.session_state:
        all_ids = st.session_state['models']['model_id'].tolist()
        select_options = ["None"] + all_ids
        c1, c2 = st.columns(2)

        with c1:
            st.subheader("Champion (Model A)")
            model_a = st.selectbox("Select Model A", select_options, index=1 if len(all_ids) > 0 else 0)
            quant_a = st.selectbox("Quantization A", ["None (FP16)", "8-bit (Int8)"], key="q_a")
        with c2:
            st.subheader("Challenger (Model B)")
            model_b = st.selectbox("Select Model B", select_options, index=0)
            quant_b = st.selectbox("Quantization B", ["None (FP16)", "8-bit (Int8)"], key="q_b")

        bench_opts = ["Perplexity", "MMLU", "GSM8K", "ARC-C", "ARC-E", "HellaSwag", "PIQA"]
        selected_bench = st.multiselect("Benchmarks", bench_opts, default=["Perplexity", "MMLU"])

        def _benchmark_and_render(model_id, quant, benches):
            """Load *model_id* at *quant*, run *benches*, and render one metric card per result.

            Returns {benchmark_name: result_dict}; empty if the model failed to load.
            Was previously duplicated verbatim for columns A and B.
            """
            results = {}
            succ, _msg = st.session_state['manager'].load_model(model_id, quantization=quant)
            if succ:
                mod, tok = st.session_state['manager'].get_components(model_id, quantization=quant)
                suite = BenchmarkSuite(mod, tok, model_id=f"{model_id}_{quant}")
                for b in benches:
                    res = suite.run_benchmark(b, simulation_mode=True)
                    results[b] = res
                    st.markdown(f"""<div class='metric-card'><div style='color:#aaa;'>{b}</div><div class='metric-val'>{res['score']:.2f}</div><div>{res['rating']}</div></div>""", unsafe_allow_html=True)
            return results

        if st.button("⚔️ Run Comparison"):
            # Middle column is only a visual spacer.
            col_a, _col_mid, col_b = st.columns([1, 0.1, 1])
            results_a, results_b = {}, {}
            # Map the UI labels onto the quantization keys the manager expects.
            q_map_a = "8-bit" if "8-bit" in quant_a else "None"
            q_map_b = "8-bit" if "8-bit" in quant_b else "None"

            with col_a:
                if model_a != "None":
                    results_a = _benchmark_and_render(model_a, q_map_a, selected_bench)
            with col_b:
                if model_b != "None":
                    results_b = _benchmark_and_render(model_b, q_map_b, selected_bench)

            if results_a and results_b:
                st.markdown("### 🕸️ Comparison Map")
                categories = list(results_a.keys())
                # Percentage benchmarks plot as-is; non-% scores (e.g. perplexity,
                # lower-is-better) are flipped onto the same 0-100 radial axis.
                vals_a = [r['score'] if r['unit'] == "%" else (100 - r['score']) for r in results_a.values()]
                vals_b = [r['score'] if r['unit'] == "%" else (100 - r['score']) for r in results_b.values()]
                fig = go.Figure()
                fig.add_trace(go.Scatterpolar(r=vals_a, theta=categories, fill='toself', name=f"{model_a} ({q_map_a})", line_color="#00d4ff"))
                fig.add_trace(go.Scatterpolar(r=vals_b, theta=categories, fill='toself', name=f"{model_b} ({q_map_b})", line_color="#ff0055"))
                fig.update_layout(polar=dict(radialaxis=dict(visible=True, range=[0, 100])), paper_bgcolor="rgba(0,0,0,0)", font_color="white")
                st.plotly_chart(fig, use_container_width=True)
    else:
        st.warning("Go to Discovery tab first.")
115
+
116
# ================= TAB 3: PLAYGROUND =================
# Free-form prompt comparison: the same prompt is sent to two independently
# selected (and independently quantized) generators.
with tab3:
    st.subheader("💬 Generation Playground")
    if 'models' in st.session_state:
        all_ids = st.session_state['models']['model_id'].tolist()
        select_options_play = ["None"] + all_ids
        pc1, pc2 = st.columns(2)
        with pc1:
            pm_a = st.selectbox("Generator A", select_options_play, index=1 if len(all_ids) > 0 else 0, key="pm_a")
            pq_a = st.selectbox("Quant A", ["None (FP16)", "8-bit (Int8)"], key="pq_a")
        with pc2:
            pm_b = st.selectbox("Generator B", select_options_play, index=0, key="pm_b")
            pq_b = st.selectbox("Quant B", ["None (FP16)", "8-bit (Int8)"], key="pq_b")
        user_prompt = st.text_area("Prompt", value="Explain quantum computing like I'm 5.")

        def _load_and_generate(model_id, quant, prompt):
            """Load *model_id* at *quant* and return the generated text, or None if the load failed.

            Was previously duplicated verbatim for generators A and B.
            """
            ok, _msg = st.session_state['manager'].load_model(model_id, quantization=quant)
            if ok:
                return st.session_state['manager'].generate_text(model_id, quant, prompt)
            return None

        if st.button("Generate Text"):
            c1, c2 = st.columns(2)
            # Map the UI labels onto the quantization keys the manager expects.
            pq_map_a = "8-bit" if "8-bit" in pq_a else "None"
            pq_map_b = "8-bit" if "8-bit" in pq_b else "None"
            with c1:
                if pm_a != "None":
                    out = _load_and_generate(pm_a, pq_map_a, user_prompt)
                    if out is not None:
                        st.info(out)
            with c2:
                if pm_b != "None":
                    out = _load_and_generate(pm_b, pq_map_b, user_prompt)
                    if out is not None:
                        st.success(out)
    else:
        st.warning("Please fetch models in Tab 1 first.")
147
+
148
# ================= TAB 4: HARDWARE FORECAST =================
# VRAM estimator: parse a parameter-count string ("7B", "13B", ...) and show
# memory requirements at several precisions.
with tab4:
    st.header("💾 Hardware Forecast")
    col1, col2 = st.columns(2)
    with col1:
        vram_input = st.text_input("Enter Model Size (e.g., 7B, 13B)", value="7B")
        if st.button("Calculate VRAM"):
            estimate = ModelDiagnostics.estimate_vram(vram_input)
            if estimate:
                # Stash the result so it survives the next rerun.
                st.session_state['vram_res'] = estimate
            else:
                st.error("Invalid format.")
    with col2:
        if 'vram_res' in st.session_state:
            res = st.session_state['vram_res']
            st.success(f"**Results for {res['params_in_billions']}B Params**")
            st.markdown(f"- **Training (FP32):** `{res['FP32 (Training/Full)']}`\n- **Inference (FP16):** `{res['FP16 (Inference)']}`\n- **Quantized (Int8):** `{res['INT8 (Quantized)']}`")
163
+
164
# ================= TAB 5: MODEL X-RAY =================
# Dump the raw PyTorch module tree of a selected model for inspection.
with tab5:
    st.header("🔍 Model X-Ray")
    if 'models' in st.session_state:
        all_ids = st.session_state['models']['model_id'].tolist()
        xray_model = st.selectbox("Select Model to Inspect", all_ids)
        if st.button("Scan Layers"):
            # The X-ray always inspects the unquantized weights.
            loaded, _ = st.session_state['manager'].load_model(xray_model, quantization="None")
            if loaded:
                mod, _ = st.session_state['manager'].get_components(xray_model, quantization="None")
                structure = ModelDiagnostics.get_layer_structure(mod)
                st.text_area("Raw PyTorch Structure", value=structure, height=400)
    else:
        st.warning("Go to Discovery tab first.")
177
+
178
# ================= TAB 6: ABLATION LAB (NEW) =================
with tab6:
    # The entire ablation UI lives in the ablation_lab module; this file only
    # hosts the tab container.
    ab.render_ablation_dashboard()