File size: 12,062 Bytes
1971e8d 11f2119 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 |
import streamlit as st
import torch
import torch.nn as nn
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import networkx as nx
import copy
from backend import ModelManager
class AblationEngine:
    """
    Handles the 'Virtual Surgery' of models using PyTorch Hooks.
    Instead of deleting code, we intercept signals during inference.

    Each ``register_ablation`` call attaches one forward hook to a named
    sub-module; ``clear_hooks`` detaches every hook attached so far.
    """

    # Ablation types understood by ``register_ablation``.
    SUPPORTED_ABLATIONS = ("zero_out", "add_noise", "freeze_mean")

    def __init__(self, model_manager):
        self.manager = model_manager  # stored for callers; not used by the methods below
        self.active_hooks = []  # handles returned by register_forward_hook
        self.ablation_log = []  # NOTE(review): not written anywhere in this class

    def clear_hooks(self):
        """Remove all active ablations (restores model to baseline)."""
        for handle in self.active_hooks:
            handle.remove()
        self.active_hooks = []

    @staticmethod
    def _ablate_tensor(tensor, ablation_type, noise_level):
        """Apply one (already validated) ablation type to a single tensor."""
        if ablation_type == "zero_out":
            # Structural Ablation: kill the signal. zeros_like (unlike x * 0.0)
            # also yields 0 where the activation is inf/NaN.
            return torch.zeros_like(tensor)
        if ablation_type == "add_noise":
            # Robustness Test: inject Gaussian noise with std = noise_level
            return tensor + torch.randn_like(tensor) * noise_level
        # "freeze_mean" - Information Bottleneck: replace every row along dim 0
        # with the dim-0 mean (a no-op when dim 0 has size 1).
        return torch.mean(tensor, dim=0, keepdim=True).expand_as(tensor)

    def register_ablation(self, model, layer_name, ablation_type="zero_out", noise_level=0.1):
        """
        Inject a forward hook into one named sub-module to modify its output.

        Args:
            model: ``torch.nn.Module`` containing the target layer.
            layer_name: Dotted name as yielded by ``model.named_modules()``.
            ablation_type: One of ``SUPPORTED_ABLATIONS``:
                "zero_out" replaces the output with zeros, "add_noise" adds
                Gaussian noise with std ``noise_level``, "freeze_mean" replaces
                the output with its mean over dim 0.
            noise_level: Noise standard deviation; only used by "add_noise".

        Returns:
            A human-readable status message.

        Raises:
            KeyError: ``layer_name`` is not a sub-module of ``model``.
            ValueError: ``ablation_type`` is not supported. Checked before any
                hook is attached, so an unknown type can no longer silently
                leave the model unablated.
        """
        if ablation_type not in self.SUPPORTED_ABLATIONS:
            raise ValueError(
                f"Unknown ablation_type {ablation_type!r}; "
                f"expected one of {self.SUPPORTED_ABLATIONS}"
            )
        target_module = dict(model.named_modules())[layer_name]

        def hook_fn(module, inputs, output):
            if isinstance(output, torch.Tensor):
                return self._ablate_tensor(output, ablation_type, noise_level)
            if isinstance(output, tuple) and output and isinstance(output[0], torch.Tensor):
                # Some modules return a tuple whose first element is the main
                # activation (assumed here - confirm for your model family):
                # ablate that element and pass the remaining entries through.
                ablated = self._ablate_tensor(output[0], ablation_type, noise_level)
                return (ablated,) + tuple(output[1:])
            raise TypeError(
                f"Cannot ablate output of type {type(output).__name__} "
                f"from layer {layer_name!r}"
            )

        # Register the hook; a non-None return value from the hook replaces the output.
        handle = target_module.register_forward_hook(hook_fn)
        self.active_hooks.append(handle)
        return f"Ablated {layer_name} ({ablation_type})"
class ArchitectureVisualizer:
    """
    Builds a Netron-style interactive graph of the model layers using NetworkX + Plotly.
    """

    # Name substrings that mark a module as a high-level block (Layers, Attention, MLP).
    BLOCK_KEYWORDS = ("layer", "block", "attn", "mlp")

    @staticmethod
    def build_layer_graph(model, keywords=None):
        """
        Build a chain-shaped ``nx.DiGraph``: Input -> matched modules -> Output.

        Modules are visited in ``model.named_modules()`` order and connected one
        after another (a visual heuristic, not the true data flow). A module is
        drawn when its dotted name contains one of ``keywords``. A leaf module
        (one with no children) is skipped when its parent is already drawn. This
        keeps individual projections/norms inside a drawn block from flooding
        the graph on large models.

        Args:
            model: ``torch.nn.Module`` to map.
            keywords: Name substrings to match; defaults to ``BLOCK_KEYWORDS``.

        Returns:
            The graph. Module nodes carry ``type`` (class name) and ``params``
            (total parameter count of the module) attributes.
        """
        if keywords is None:
            keywords = ArchitectureVisualizer.BLOCK_KEYWORDS
        graph = nx.DiGraph()
        graph.add_node("Input", type="Input")
        prev_node = "Input"
        for name, module in model.named_modules():
            if not any(k in name for k in keywords):
                continue
            # Depth limiting: a leaf whose enclosing module is already a node
            # adds detail, not structure. Relies on named_modules() yielding
            # parents before their children.
            parent_name = name.rpartition(".")[0]
            is_leaf = next(module.children(), None) is None
            if is_leaf and parent_name in graph:
                continue
            graph.add_node(
                name,
                type=module.__class__.__name__,
                params=sum(p.numel() for p in module.parameters()),
            )
            graph.add_edge(prev_node, name)  # Heuristic: connect sequential blocks
            prev_node = name
        graph.add_node("Output", type="Output")
        graph.add_edge(prev_node, "Output")
        return graph

    @staticmethod
    def _node_color(node):
        """Colour code a node by its (lower-cased) name."""
        lowered = node.lower()
        if "attn" in lowered:
            return "#FF0055"  # Attention
        if "mlp" in lowered:
            return "#00CC96"  # MLP
        if "layer" in lowered:
            return "#AB63FA"  # Blocks
        return "#FFFFFF"

    @staticmethod
    def plot_interactive_graph(G):
        """
        Render ``G`` as a transparent-background Plotly figure.

        Uses a spring layout with a fixed seed so reruns give the same picture.
        Hover text shows the node name, its ``type`` and ``params`` attributes
        ("Unknown"/"N/A" when absent).
        """
        pos = nx.spring_layout(G, seed=42, k=0.5)

        edge_x, edge_y = [], []
        for src, dst in G.edges():
            (x0, y0), (x1, y1) = pos[src], pos[dst]
            # The trailing None breaks the polyline so each edge is drawn separately.
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])
        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none', mode='lines'
        )

        node_x, node_y, node_text, node_color = [], [], [], []
        for node in G.nodes():
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            info = G.nodes[node]
            node_text.append(f"{node}<br>{info.get('type', 'Unknown')}<br>Params: {info.get('params', 'N/A')}")
            node_color.append(ArchitectureVisualizer._node_color(node))
        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=node_text,
            marker=dict(showscale=False, color=node_color, size=15, line_width=2)
        )

        fig = go.Figure(data=[edge_trace, node_trace],
                        layout=go.Layout(
                            showlegend=False,
                            hovermode='closest',
                            margin=dict(b=0, l=0, r=0, t=0),
                            paper_bgcolor='rgba(0,0,0,0)',
                            plot_bgcolor='rgba(0,0,0,0)',
                            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                        )
        return fig
# NOTE(review): several emoji in the UI strings below look mis-encoded (UTF-8
# bytes rendered through a Greek code page, e.g. "π§ͺ"). The literals are kept
# verbatim here; restore them from an intact copy of the source.

# Custom CSS for the Dashboard Feel.
_DASHBOARD_CSS = """
<style>
.ablation-header {
background: linear-gradient(90deg, #FF4B4B 0%, #FF9068 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 30px; font-weight: 900;
}
.stat-box {
background-color: #1E1E1E; border: 1px solid #333;
padding: 15px; border-radius: 5px; text-align: center;
}
.risk-high { border-left: 5px solid #FF4B4B; }
.risk-med { border-left: 5px solid #FFAA00; }
.risk-low { border-left: 5px solid #00FF00; }
</style>
"""

# "Ablation Method" selectbox label -> AblationEngine ablation type. An explicit
# map is used because deriving the type from the label text ("Add Noise" ->
# "add") yields names the engine does not recognise.
_ABLATION_METHODS = {
    "Zero-Out (Remove)": "zero_out",
    "Add Noise (Corrupt)": "add_noise",
    "Freeze Mean (Bottleneck)": "freeze_mean",
}

# Prompt for the "Prompt Integrity Test", used for baseline and ablated runs.
_INTEGRITY_PROMPT = "The capital of France is"


def _integrity_score(baseline, ablated):
    """Return len(ablated) as a percentage of len(baseline), clamped to at most 100 (0 if baseline is empty)."""
    base_len = len(baseline or "")
    if base_len == 0:
        return 0
    # Clamp: a longer (e.g. repetitive) ablated output must not read as >100% integrity.
    return min(100.0, len(ablated or "") / base_len * 100)


def _run_ablation_experiment(manager, engine, model, target_id, target_layers, ablation_type, noise_level):
    """Generate once without and once with the ablation hooks; return the two result rows."""
    engine.clear_hooks()  # Reset anything a previous run left behind
    st.write("Measuring Baseline Performance...")
    # Output length is used as a cheap proxy for "performance".
    base_out = manager.generate_text(target_id, "None", _INTEGRITY_PROMPT)
    try:
        for layer in target_layers:
            st.toast(engine.register_ablation(model, layer, ablation_type, noise_level))
        st.write("Running Ablated Inference...")
        ablated_out = manager.generate_text(target_id, "None", _INTEGRITY_PROMPT)
    finally:
        # Always detach the hooks, so a failed run cannot leave the loaded model ablated.
        engine.clear_hooks()
    return [
        {"State": "Baseline", "Output": base_out, "Integrity": 100},
        {"State": "Ablated", "Output": ablated_out, "Integrity": _integrity_score(base_out, ablated_out)},
    ]


def render_ablation_dashboard():
    """
    Render the Streamlit "Systematic Ablation Lab" page.

    Reads ``models`` and ``manager`` from ``st.session_state``. It stores
    ``ablation_target``, ``ablation_engine`` and ``ablation_results`` there
    across reruns.
    """
    st.markdown(_DASHBOARD_CSS, unsafe_allow_html=True)
    st.markdown('<div class="ablation-header">π§ͺ SYSTEMATIC ABLATION LAB</div>', unsafe_allow_html=True)
    st.caption("Surgically alter model components to measure contribution and robustness.")
    if 'models' not in st.session_state:
        st.warning("Please load models in the Discovery tab first.")
        return

    # 1. Select Subject
    col_sel, col_viz = st.columns([1, 3])  # NOTE(review): col_viz is never used
    with col_sel:
        st.subheader("1. Subject")
        all_ids = st.session_state['models']['model_id'].tolist()
        target_model_id = st.selectbox("Select Model for Surgery", all_ids)
        if st.button("Initialize Surgery Table"):
            with st.spinner("Preparing model for hooks..."):
                succ, msg = st.session_state['manager'].load_model(target_model_id)
            if succ:
                st.success("Ready.")
                st.session_state['ablation_target'] = target_model_id
                st.session_state['ablation_engine'] = AblationEngine(st.session_state['manager'])
            else:
                st.error(msg)

    # 2. Main Workspace
    if 'ablation_target' not in st.session_state:
        return
    target_id = st.session_state['ablation_target']
    manager = st.session_state['manager']
    # NOTE(review): assumes the manager stores the default-precision model
    # under f"{model_id}_None" - confirm against ModelManager.load_model.
    model_pkg = manager.loaded_models.get(f"{target_id}_None")
    if not model_pkg:
        st.error("Model lost from memory. Please reload.")
        return
    model = model_pkg['model']

    t1, t2, t3 = st.tabs(["𧬠Structural Map", "πͺ Ablation Controls", "π Impact Report"])

    # === TAB 1: ARCHITECTURE GRAPH ===
    with t1:
        st.markdown("### Interactive Architecture Map")
        st.markdown("Visualize the flow to decide where to cut.")
        if st.button("Generate Graph (Heavy Compute)"):
            with st.spinner("Tracing neural pathways..."):
                graph = ArchitectureVisualizer.build_layer_graph(model)
                fig = ArchitectureVisualizer.plot_interactive_graph(graph)
            st.plotly_chart(fig, use_container_width=True)

    # === TAB 2: CONTROLS ===
    with t2:
        st.subheader("Configure Ablation Experiment")
        c1, c2 = st.columns(2)
        with c1:
            all_layers = [n for n, _ in model.named_modules() if n]  # skip the unnamed root
            target_layers = st.multiselect("Select Target Layers", all_layers, max_selections=5)
        with c2:
            method = st.selectbox("Ablation Method", list(_ABLATION_METHODS))
            if method == "Add Noise (Corrupt)":
                noise_val = st.slider("Noise Level (Std Dev)", 0.0, 2.0, 0.1)
            else:
                noise_val = 0.0
        if st.button("π΄ RUN ABLATION TEST"):
            if not target_layers:
                st.warning("Select at least one target layer.")
            else:
                engine = st.session_state.get('ablation_engine')
                if engine is None:
                    engine = AblationEngine(manager)
                    st.session_state['ablation_engine'] = engine
                st.session_state['ablation_results'] = _run_ablation_experiment(
                    manager, engine, model, target_id, target_layers,
                    _ABLATION_METHODS[method], noise_val,
                )
                st.success("Experiment Complete. Hooks Removed.")

    # === TAB 3: RESULTS ===
    with t3:
        if 'ablation_results' in st.session_state:
            res = st.session_state['ablation_results']
            st.markdown("### π Output Degradation Analysis")
            col_base, col_abl = st.columns(2)
            with col_base:
                st.info(f"**Baseline:** {res[0]['Output']}")
            with col_abl:
                st.warning(f"**Ablated:** {res[1]['Output']}")
            deg = max(0.0, 100 - res[1]['Integrity'])
            # A negative delta with delta_color="normal" renders red: degradation is bad.
            st.metric("Model Degradation", f"{deg:.1f}%", delta=f"-{deg:.1f}%", delta_color="normal")

            # NOTE: hard-coded placeholder data; a real sensitivity analysis
            # would loop the experiment over layers.
            st.markdown("### π₯ Layer Sensitivity Heatmap")
            sens_data = pd.DataFrame({
                "Layer": ["embed", "layer.0", "layer.1", "layer.2", "head"],
                "Sensitivity Score": [95, 10, 15, 80, 100]
            })
            fig = px.bar(sens_data, x="Layer", y="Sensitivity Score",
                         color="Sensitivity Score", color_continuous_scale="RdYlGn_r",
                         title="Estimated Contribution to Output (Simulated)")
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.info("Run an experiment in Tab 2 to see results.")