sumitrwk committed on
Commit
364d233
·
verified ·
1 Parent(s): 11f2119

Delete ablation_lab.py

Browse files
Files changed (1) hide show
  1. ablation_lab.py +0 -286
ablation_lab.py DELETED
@@ -1,286 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- import pandas as pd
5
- import plotly.express as px
6
- import plotly.graph_objects as go
7
- import networkx as nx
8
- import copy
9
- from backend import ModelManager
10
-
11
class AblationEngine:
    """
    Handles the 'Virtual Surgery' of models using PyTorch Hooks.
    Instead of deleting code, we intercept signals during inference.

    Forward hooks registered through ``register_ablation`` rewrite the output
    of the chosen module; ``clear_hooks`` removes every hook registered so far.
    """

    def __init__(self, model_manager):
        self.manager = model_manager
        # Handles returned by register_forward_hook, kept so they can be removed.
        self.active_hooks = []
        # Not written to by this class.
        self.ablation_log = []

    def clear_hooks(self):
        """Removes all active ablations (restores model to baseline)."""
        for handle in self.active_hooks:
            handle.remove()
        self.active_hooks = []

    @staticmethod
    def _ablate_tensor(tensor, ablation_type, noise_level):
        """Return the ablated version of a single tensor.

        Unknown ``ablation_type`` values leave the tensor untouched.
        """
        if ablation_type == "zero_out":
            # Structural Ablation: Kill the signal.
            # zeros_like (rather than ``* 0.0``) also wipes NaN/Inf values and
            # keeps the tensor's dtype.
            return torch.zeros_like(tensor)

        if ablation_type not in ("add_noise", "freeze_mean"):
            return tensor

        # Gaussian noise and means are only defined for floating-point data;
        # integer/bool tensors (e.g. masks, indices) pass through unchanged.
        if not tensor.is_floating_point():
            return tensor

        if ablation_type == "add_noise":
            # Robustness Test: Inject Gaussian noise
            return tensor + torch.randn_like(tensor) * noise_level

        # Information Bottleneck: Replace with the mean over dim 0.
        if tensor.dim() == 0:
            # A scalar has no dim 0 to average over.
            return tensor
        return torch.mean(tensor, dim=0, keepdim=True).expand_as(tensor)

    @classmethod
    def _ablate_output(cls, output, ablation_type, noise_level):
        """Apply the ablation to a module output of arbitrary nesting.

        Tensors are transformed; tuples/lists are rebuilt with every contained
        tensor transformed (many modules return tuples rather than a bare
        tensor); anything else is returned as-is.
        """
        if isinstance(output, torch.Tensor):
            return cls._ablate_tensor(output, ablation_type, noise_level)

        if isinstance(output, (tuple, list)):
            items = [cls._ablate_output(item, ablation_type, noise_level) for item in output]
            if hasattr(output, "_fields"):
                # namedtuple constructors take positional fields, not an iterable.
                return type(output)(*items)
            return type(output)(items)

        return output

    def register_ablation(self, model, layer_name, ablation_type="zero_out", noise_level=0.1):
        """
        Injects a hook into a specific layer to modify its output.

        Args:
            model: Module whose ``named_modules()`` contains ``layer_name``.
            layer_name: Dotted module name as reported by ``named_modules()``.
            ablation_type: ``"zero_out"``, ``"add_noise"`` or ``"freeze_mean"``;
                any other value registers a hook that leaves the output unchanged.
            noise_level: Multiplier applied to the Gaussian noise for ``"add_noise"``.

        Returns:
            A short human-readable message describing the registered ablation.

        Raises:
            KeyError: If ``layer_name`` is not a module of ``model``.
        """
        modules = dict(model.named_modules())
        if layer_name not in modules:
            raise KeyError(f"Layer '{layer_name}' not found in model")
        target_module = modules[layer_name]

        def hook_fn(module, input, output):
            # The returned value replaces the module's output.
            return self._ablate_output(output, ablation_type, noise_level)

        # Register the hook
        handle = target_module.register_forward_hook(hook_fn)
        self.active_hooks.append(handle)
        return f"Ablated {layer_name} ({ablation_type})"
53
-
54
class ArchitectureVisualizer:
    """
    Builds a Netron-style interactive graph of the model layers using NetworkX + Plotly.
    """

    # Substrings that mark a module name as a block worth drawing
    # (Layers, Attention, MLP).
    _BLOCK_KEYWORDS = ("layer", "block", "attn", "mlp")

    @staticmethod
    def build_layer_graph(model, max_depth=None):
        """Build a linear chain graph of the model's block-like modules.

        Modules are visited in ``named_modules()`` order and chained
        ``Input -> ... -> Output``; this is a sequential heuristic, not a
        trace of the real data flow.

        Args:
            model: Object exposing ``named_modules()`` (modules must expose
                ``parameters()``).
            max_depth: Optional cap on the number of dotted segments in a
                module name (``"a.b.c"`` has depth 3). ``None`` keeps every
                matching module, which can yield very large graphs.

        Returns:
            A ``networkx.DiGraph`` whose nodes carry ``type`` and, for model
            modules, ``params`` (total parameter count) attributes.
        """
        G = nx.DiGraph()
        prev_node = "Input"
        G.add_node("Input", type="Input")

        keywords = ArchitectureVisualizer._BLOCK_KEYWORDS
        for name, module in model.named_modules():
            # Filter for high-level blocks only (Layers, Attention, MLP)
            if not any(k in name for k in keywords):
                continue
            # Optional depth limit to avoid 10,000 node graphs for LLMs
            if max_depth is not None and name.count(".") + 1 > max_depth:
                continue

            # Heuristic: Connect sequential blocks
            G.add_node(
                name,
                type=module.__class__.__name__,
                params=sum(p.numel() for p in module.parameters()),
            )
            G.add_edge(prev_node, name)
            prev_node = name

        G.add_node("Output", type="Output")
        G.add_edge(prev_node, "Output")
        return G

    @staticmethod
    def plot_interactive_graph(G):
        """Render ``G`` as a Plotly figure with one edge trace and one node trace.

        Node positions come from a seeded spring layout, so the picture is
        reproducible for a given graph. Hover text shows the node name and its
        ``type`` / ``params`` attributes when present.
        """
        pos = nx.spring_layout(G, seed=42, k=0.5)

        # Edges are drawn as one polyline; None breaks the line between edges.
        edge_x, edge_y = [], []
        for src, dst in G.edges():
            x0, y0 = pos[src]
            x1, y1 = pos[dst]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none', mode='lines'
        )

        node_x, node_y, node_text, node_color = [], [], [], []
        for node in G.nodes():
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            info = G.nodes[node]
            node_text.append(f"{node}<br>{info.get('type', 'Unknown')}<br>Params: {info.get('params', 'N/A')}")

            # Color coding
            lowered = node.lower()
            if "attn" in lowered:
                node_color.append("#FF0055")  # Attention
            elif "mlp" in lowered:
                node_color.append("#00CC96")  # MLP
            elif "layer" in lowered or "block" in lowered:
                node_color.append("#AB63FA")  # Blocks
            else:
                node_color.append("#FFFFFF")

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=node_text,
            marker=dict(showscale=False, color=node_color, size=15, line_width=2)
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                showlegend=False,
                hovermode='closest',
                margin=dict(b=0, l=0, r=0, t=0),
                paper_bgcolor='rgba(0,0,0,0)',
                plot_bgcolor='rgba(0,0,0,0)',
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            ),
        )
        return fig
128
-
129
def _run_ablation_experiment(manager, engine, model, target_id, target_layers, ablation_type, noise_val):
    """Run one baseline and one ablated generation and return the results log.

    Hooks are always removed before returning, even if hook registration or
    generation raises, so the model is never left in an ablated state.

    Returns:
        A two-item list of dicts with keys ``State``, ``Output`` and
        ``Integrity`` (baseline first, ablated second).
    """
    engine.clear_hooks()  # Reset previous

    # 1. Establish Baseline
    st.write("Measuring Baseline Performance...")
    # We simply use a generation prompt length as a proxy for "Performance"
    # or run a quick perplexity check if integrated with benchmarks.
    # For this dashboard, we run the "Prompt Integrity Test"
    prompt = "The capital of France is"
    base_out = manager.generate_text(target_id, "None", prompt)
    results_log = [{"State": "Baseline", "Output": base_out, "Integrity": 100}]

    try:
        # 2. Apply Hooks
        for layer in target_layers:
            msg = engine.register_ablation(model, layer, ablation_type, noise_val)
            st.toast(msg)

        # 3. Measure Ablated Performance
        st.write("Running Ablated Inference...")
        ablated_out = manager.generate_text(target_id, "None", prompt)
    finally:
        # Cleanup
        engine.clear_hooks()

    # Simple heuristic: length retention relative to the baseline output
    # (assumes generate_text returns a sized value such as a string).
    integrity = (len(ablated_out) / len(base_out)) * 100 if len(base_out) > 0 else 0
    results_log.append({"State": "Ablated", "Output": ablated_out, "Integrity": integrity})
    return results_log


def render_ablation_dashboard():
    """Render the Streamlit 'Systematic Ablation Lab' page.

    Reads ``models`` and ``manager`` from ``st.session_state`` and stores
    ``ablation_target``, ``ablation_engine`` and ``ablation_results`` there.
    The page offers three tabs: an architecture graph, the ablation controls,
    and a report comparing baseline and ablated outputs.
    """
    # --- Custom CSS for the Dashboard Feel ---
    st.markdown("""
    <style>
    .ablation-header {
    background: linear-gradient(90deg, #FF4B4B 0%, #FF9068 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 30px; font-weight: 900;
    }
    .stat-box {
    background-color: #1E1E1E; border: 1px solid #333;
    padding: 15px; border-radius: 5px; text-align: center;
    }
    .risk-high { border-left: 5px solid #FF4B4B; }
    .risk-med { border-left: 5px solid #FFAA00; }
    .risk-low { border-left: 5px solid #00FF00; }
    </style>
    """, unsafe_allow_html=True)

    st.markdown('<div class="ablation-header">🧪 SYSTEMATIC ABLATION LAB</div>', unsafe_allow_html=True)
    st.caption("Surgically alter model components to measure contribution and robustness.")

    if 'models' not in st.session_state:
        st.warning("Please load models in the Discovery tab first.")
        return

    # UI label -> ablation_type key understood by AblationEngine.register_ablation.
    # An explicit table is used because deriving the key from the label's first
    # word only works for "Zero-Out".
    ablation_methods = {
        "Zero-Out (Remove)": "zero_out",
        "Add Noise (Corrupt)": "add_noise",
        "Freeze Mean (Bottleneck)": "freeze_mean",
    }

    # 1. Select Subject
    col_sel, col_viz = st.columns([1, 3])

    with col_sel:
        st.subheader("1. Subject")
        all_ids = st.session_state['models']['model_id'].tolist()
        target_model_id = st.selectbox("Select Model for Surgery", all_ids)

        # Load Model Button
        if st.button("Initialize Surgery Table"):
            with st.spinner("Preparing model for hooks..."):
                succ, msg = st.session_state['manager'].load_model(target_model_id)
                if succ:
                    st.success("Ready.")
                    st.session_state['ablation_target'] = target_model_id
                    # Initialize engine
                    st.session_state['ablation_engine'] = AblationEngine(st.session_state['manager'])
                else:
                    st.error(msg)

    # 2. Main Workspace
    if 'ablation_target' in st.session_state:
        target_id = st.session_state['ablation_target']
        model_pkg = st.session_state['manager'].loaded_models.get(f"{target_id}_None")  # Default FP32/16 key

        if not model_pkg:
            st.error("Model lost from memory. Please reload.")
            return

        model = model_pkg['model']

        # --- TAB LAYOUT FOR ABLATION ---
        t1, t2, t3 = st.tabs(["🧬 Structural Map", "🔪 Ablation Controls", "📊 Impact Report"])

        # === TAB 1: ARCHITECTURE GRAPH ===
        with t1:
            st.markdown("### Interactive Architecture Map")
            st.markdown("Visualize the flow to decide where to cut.")

            if st.button("Generate Graph (Heavy Compute)"):
                with st.spinner("Tracing neural pathways..."):
                    G = ArchitectureVisualizer.build_layer_graph(model)
                    fig = ArchitectureVisualizer.plot_interactive_graph(G)
                    st.plotly_chart(fig, use_container_width=True)

        # === TAB 2: CONTROLS ===
        with t2:
            st.subheader("Configure Ablation Experiment")

            c1, c2 = st.columns(2)
            with c1:
                # Get all layers (the root module has an empty name and is skipped)
                all_layers = [n for n, _ in model.named_modules() if len(n) > 0]
                target_layers = st.multiselect("Select Target Layers", all_layers, max_selections=5)

            with c2:
                method = st.selectbox("Ablation Method", list(ablation_methods))
                if method == "Add Noise (Corrupt)":
                    noise_val = st.slider("Noise Level (Std Dev)", 0.0, 2.0, 0.1)
                else:
                    noise_val = 0.0

            if st.button("🔴 RUN ABLATION TEST"):
                if not target_layers:
                    # Without hooks the "ablated" run would just repeat the baseline.
                    st.warning("Select at least one target layer before running.")
                else:
                    st.session_state['ablation_results'] = _run_ablation_experiment(
                        st.session_state['manager'],
                        st.session_state['ablation_engine'],
                        model,
                        target_id,
                        target_layers,
                        ablation_methods[method],
                        noise_val,
                    )
                    st.success("Experiment Complete. Hooks Removed.")

        # === TAB 3: RESULTS ===
        with t3:
            if 'ablation_results' in st.session_state:
                res = st.session_state['ablation_results']

                # Visual Diff
                st.markdown("### 📝 Output Degradation Analysis")

                col_base, col_abl = st.columns(2)
                with col_base:
                    st.info(f"**Baseline:** {res[0]['Output']}")
                with col_abl:
                    st.warning(f"**Ablated:** {res[1]['Output']}")

                # Metrics
                deg = 100 - res[1]['Integrity']
                # Format the negated value so a negative degradation (ablated
                # output longer than baseline) does not render as "--x%".
                st.metric("Model Degradation", f"{deg:.1f}%", delta=f"{-deg:.1f}%", delta_color="inverse")

                # Sensitivity Chart (Mocked for single run, would need loop for real sensitivity analysis)
                st.markdown("### 🔥 Layer Sensitivity Heatmap")

                # Creating dummy data to show what the "full suite" would look like
                sens_data = pd.DataFrame({
                    "Layer": ["embed", "layer.0", "layer.1", "layer.2", "head"],
                    "Sensitivity Score": [95, 10, 15, 80, 100]
                })

                fig = px.bar(sens_data, x="Layer", y="Sensitivity Score",
                             color="Sensitivity Score", color_continuous_scale="RdYlGn_r",
                             title="Estimated Contribution to Output (Simulated)")
                st.plotly_chart(fig, use_container_width=True)
            else:
                st.info("Run an experiment in Tab 2 to see results.")