File size: 9,479 Bytes
8ae248a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab3f651
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import streamlit as st
import pandas as pd
import plotly.graph_objects as go
import torch
from backend import ModelResearcher, ModelManager
from benchmarks import BenchmarkSuite
from model_diagnostics import ModelDiagnostics
import ablation_lab as ab # Import the new module

# --- Styling & Config ---
# Page metadata, set before any other element is drawn. page_icon must be the
# actual emoji character (test tube), not its mis-decoded UTF-8 byte sequence.
st.set_page_config(page_title="DeepBench: AI Researcher Workbench", layout="wide", page_icon="🧪")

# Global dark-theme overrides, plus the `.metric-card` / `.metric-val` classes
# used for the per-benchmark result cards in the Battle Arena tab.
st.markdown("""
<style>
    .stApp { background-color: #0e1117; color: #FAFAFA; }
    h1, h2, h3 { color: #00d4ff; }
    .metric-card {
        background-color: #262730; border: 1px solid #41424C;
        border-radius: 8px; padding: 15px; margin-bottom: 10px;
        text-align: center;
    }
    .metric-val { font-size: 24px; font-weight: bold; }
    .stButton>button { width: 100%; border-radius: 5px; }
</style>
""", unsafe_allow_html=True)

# --- State Management ---
# Create the shared ModelManager only when this session does not already hold
# one, so every later rerun reuses the same instance (and its loaded models).
if "manager" not in st.session_state:
    st.session_state["manager"] = ModelManager(
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

# --- Sidebar ---
# Global controls. `task` and `arch` are read by the Discovery tab's model
# search; the device label comes from the session's ModelManager.
with st.sidebar:
    st.title("🧪 DeepBench")
    st.markdown("### Researcher Control Panel")
    task = st.selectbox("Domain", ["Language", "Vision"])
    arch = st.radio("Architecture", ["All", "Transformer", "RNN/RWKV"])
    st.markdown("---")
    st.info(f"Device: {st.session_state['manager'].device.upper()}")
    st.caption("v4.0 Full Suite | Ablation Lab Active")

# --- Tabs ---
# Labels are unpacked positionally into tab1..tab6, so the order here must match
# the "TAB n" sections below. "✂️ Ablation Lab" is the last tab.
tab_names = ["🔍 Discovery", "⚔️ Battle Arena", "💬 Playground", "💾 Hardware Forecast", "🩻 Model X-Ray", "✂️ Ablation Lab"]
tab1, tab2, tab3, tab4, tab5, tab6 = st.tabs(tab_names)

# ================= TAB 1: DISCOVERY =================
# Searches models for the sidebar's domain/architecture and stores the result in
# session_state['models'], which every other model-selecting tab reads.
with tab1:
    researcher = ModelResearcher()
    col_search, col_res = st.columns([1, 4])
    with col_search:
        if st.button("Fetch Models", use_container_width=True):
            st.session_state['models'] = researcher.search_models(task_domain=task, architecture_type=arch)
    with col_res:
        if 'models' in st.session_state:
            models_df = st.session_state['models']
            # Size the progress bar to the data: with a fixed 1M cap, every
            # model above 1M downloads rendered as an identical full bar.
            # 1M stays the floor so small result sets look as before.
            downloads_cap = 1_000_000
            if "downloads" in models_df:
                peak = pd.to_numeric(models_df["downloads"], errors="coerce").max()
                if pd.notna(peak):
                    downloads_cap = max(downloads_cap, int(peak))
            st.dataframe(models_df, column_config={"downloads": st.column_config.ProgressColumn("Downloads", format="%d", min_value=0, max_value=downloads_cap)}, use_container_width=True)

# ================= TAB 2: BATTLE ARENA =================
def _arena_quant_key(label):
    """Translate a quantization dropdown label into ModelManager's value.

    "8-bit (Int8)" -> "8-bit"; any other label (i.e. "None (FP16)") -> "None".
    """
    return "8-bit" if "8-bit" in label else "None"


def _arena_run_side(model_id, quant, benchmarks):
    """Benchmark one contender and draw a metric card per benchmark.

    Args:
        model_id: Selected model id, or the literal "None" for an empty slot.
        quant: Quantization value as produced by ``_arena_quant_key``.
        benchmarks: Benchmark names to run, in display order.

    Returns:
        Dict mapping benchmark name -> result dict from ``run_benchmark``.
        Empty when the slot is "None" or the model failed to load; a load
        failure is reported to the user with ``st.error`` instead of being
        silently ignored.
    """
    if model_id == "None":
        return {}
    manager = st.session_state['manager']
    succ, msg = manager.load_model(model_id, quantization=quant)
    if not succ:
        st.error(f"Failed to load {model_id}: {msg}")
        return {}
    mod, tok = manager.get_components(model_id, quantization=quant)
    suite = BenchmarkSuite(mod, tok, model_id=f"{model_id}_{quant}")
    results = {}
    for b in benchmarks:
        res = suite.run_benchmark(b, simulation_mode=True)
        results[b] = res
        st.markdown(f"""<div class='metric-card'><div style='color:#aaa;'>{b}</div><div class='metric-val'>{res['score']:.2f}</div><div>{res['rating']}</div></div>""", unsafe_allow_html=True)
    return results


def _arena_radar_value(res):
    """Place one benchmark result on the 0-100 radar axis (higher = better).

    Results whose unit is "%" are plotted as-is. Any other unit is inverted as
    ``100 - score``, presumably for lower-is-better metrics such as
    perplexity (confirm against BenchmarkSuite). The value is clamped to
    [0, 100] because the radial axis range is fixed; e.g. a score above 100
    would otherwise invert to a negative radius.
    """
    value = res['score'] if res['unit'] == "%" else 100 - res['score']
    return min(100, max(0, value))


with tab2:
    if 'models' in st.session_state:
        all_ids = st.session_state['models']['model_id'].tolist()
        select_options = ["None"] + all_ids
        c1, c2 = st.columns(2)

        with c1:
            st.subheader("Champion (Model A)")
            model_a = st.selectbox("Select Model A", select_options, index=1 if len(all_ids)>0 else 0)
            quant_a = st.selectbox("Quantization A", ["None (FP16)", "8-bit (Int8)"], key="q_a")
        with c2:
            st.subheader("Challenger (Model B)")
            model_b = st.selectbox("Select Model B", select_options, index=0)
            quant_b = st.selectbox("Quantization B", ["None (FP16)", "8-bit (Int8)"], key="q_b")

        bench_opts = ["Perplexity", "MMLU", "GSM8K", "ARC-C", "ARC-E", "HellaSwag", "PIQA"]
        selected_bench = st.multiselect("Benchmarks", bench_opts, default=["Perplexity", "MMLU"])

        if st.button("⚔️ Run Comparison"):
            # The narrow middle column is only a visual gutter between A and B.
            col_a, _spacer, col_b = st.columns([1, 0.1, 1])
            q_map_a = _arena_quant_key(quant_a)
            q_map_b = _arena_quant_key(quant_b)

            with col_a:
                results_a = _arena_run_side(model_a, q_map_a, selected_bench)
            with col_b:
                results_b = _arena_run_side(model_b, q_map_b, selected_bench)

            if results_a and results_b:
                st.markdown("### 🕸️ Comparison Map")
                # Look B's results up by A's keys so both traces share one
                # angular order instead of relying on dict iteration order.
                categories = [b for b in results_a if b in results_b]
                vals_a = [_arena_radar_value(results_a[c]) for c in categories]
                vals_b = [_arena_radar_value(results_b[c]) for c in categories]
                fig = go.Figure()
                fig.add_trace(go.Scatterpolar(r=vals_a, theta=categories, fill='toself', name=f"{model_a} ({q_map_a})", line_color="#00d4ff"))
                fig.add_trace(go.Scatterpolar(r=vals_b, theta=categories, fill='toself', name=f"{model_b} ({q_map_b})", line_color="#ff0055"))
                fig.update_layout(polar=dict(radialaxis=dict(visible=True, range=[0, 100])), paper_bgcolor="rgba(0,0,0,0)", font_color="white")
                st.plotly_chart(fig, use_container_width=True)
    else:
        st.warning("Go to Discovery tab first.")

# ================= TAB 3: PLAYGROUND =================
def _playground_generate(model_id, quant_label, prompt, show):
    """Load one generator and render its completion of ``prompt``.

    Args:
        model_id: Selected model id, or the literal "None" to skip this slot.
        quant_label: Quantization dropdown label ("None (FP16)" / "8-bit (Int8)").
        prompt: Text passed to ``ModelManager.generate_text``.
        show: Streamlit callable used to display the output (e.g. ``st.info``),
            so the two columns keep distinct styling.

    A failed load is reported with ``st.error`` rather than silently skipped.
    """
    if model_id == "None":
        return
    quant = "8-bit" if "8-bit" in quant_label else "None"
    manager = st.session_state['manager']
    succ, msg = manager.load_model(model_id, quantization=quant)
    if not succ:
        st.error(f"Failed to load {model_id}: {msg}")
        return
    show(manager.generate_text(model_id, quant, prompt))


with tab3:
    st.subheader("💬 Generation Playground")
    if 'models' in st.session_state:
        all_ids = st.session_state['models']['model_id'].tolist()
        select_options_play = ["None"] + all_ids
        pc1, pc2 = st.columns(2)
        with pc1:
            pm_a = st.selectbox("Generator A", select_options_play, index=1 if len(all_ids)>0 else 0, key="pm_a")
            pq_a = st.selectbox("Quant A", ["None (FP16)", "8-bit (Int8)"], key="pq_a")
        with pc2:
            pm_b = st.selectbox("Generator B", select_options_play, index=0, key="pm_b")
            pq_b = st.selectbox("Quant B", ["None (FP16)", "8-bit (Int8)"], key="pq_b")
        user_prompt = st.text_area("Prompt", value="Explain quantum computing like I'm 5.")
        if st.button("Generate Text"):
            if not user_prompt.strip():
                # Don't load models / call generation for an empty prompt.
                st.warning("Enter a prompt before generating.")
            else:
                c1, c2 = st.columns(2)
                with c1:
                    _playground_generate(pm_a, pq_a, user_prompt, st.info)
                with c2:
                    _playground_generate(pm_b, pq_b, user_prompt, st.success)
    else:
        st.warning("Please fetch models in Tab 1 first.")

# ================= TAB 4: HARDWARE FORECAST =================
# Estimates memory needs from a size string such as "7B"; the latest valid
# estimate is kept in session_state['vram_res'] and rendered in the right column.
with tab4:
    st.header("💾 Hardware Forecast")
    col1, col2 = st.columns(2)
    with col1:
        vram_input = st.text_input("Enter Model Size (e.g., 7B, 13B)", value="7B")
        if st.button("Calculate VRAM"):
            res = ModelDiagnostics.estimate_vram(vram_input)
            if res:
                st.session_state['vram_res'] = res
            else:
                # Drop the previous estimate; otherwise col2 keeps showing
                # results for an older input right next to this error.
                st.session_state.pop('vram_res', None)
                st.error("Invalid format.")
    with col2:
        if 'vram_res' in st.session_state:
            res = st.session_state['vram_res']
            st.success(f"**Results for {res['params_in_billions']}B Params**")
            st.markdown(f"- **Training (FP32):** `{res['FP32 (Training/Full)']}`\n- **Inference (FP16):** `{res['FP16 (Inference)']}`\n- **Quantized (Int8):** `{res['INT8 (Quantized)']}`")

# ================= TAB 5: MODEL X-RAY =================
# Loads the chosen model unquantized and dumps its layer structure as text.
with tab5:
    st.header("🔍 Model X-Ray")
    if 'models' in st.session_state:
        all_ids = st.session_state['models']['model_id'].tolist()
        xray_model = st.selectbox("Select Model to Inspect", all_ids)
        if st.button("Scan Layers"):
            if not xray_model:
                # Guard for an empty fetched list (presumably the selectbox
                # then yields no selection) instead of loading a None id.
                st.warning("No model selected to inspect.")
            else:
                succ, msg = st.session_state['manager'].load_model(xray_model, quantization="None")
                if succ:
                    mod, _ = st.session_state['manager'].get_components(xray_model, quantization="None")
                    structure = ModelDiagnostics.get_layer_structure(mod)
                    st.text_area("Raw PyTorch Structure", value=structure, height=400)
                else:
                    st.error(f"Failed to load {xray_model}: {msg}")
    else:
        st.warning("Go to Discovery tab first.")

# ================= TAB 6: ABLATION LAB =================
# This script only supplies the tab container; every widget inside it is drawn
# by ablation_lab.render_ablation_dashboard().
with tab6:
    ab.render_ablation_dashboard()