import streamlit as st
import torch
import torch.nn as nn
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import networkx as nx
import copy
from backend import ModelManager
class AblationEngine:
    """
    Handles the 'Virtual Surgery' of models using PyTorch forward hooks.

    Instead of editing model code or weights, a hook intercepts a module's
    output during inference and rewrites it. Removing the hooks
    (``clear_hooks``) restores the model's original behavior.
    """

    # Ablation modes accepted by ``register_ablation``.
    SUPPORTED_ABLATIONS = ("zero_out", "add_noise", "freeze_mean")

    def __init__(self, model_manager):
        # Kept for callers; not used by the methods of this class.
        self.manager = model_manager
        # Handles returned by ``register_forward_hook``, kept so they can be removed.
        self.active_hooks = []
        # Not written by this class; presumably filled by UI code — verify.
        self.ablation_log = []

    def clear_hooks(self):
        """Removes all active ablations (restores model to baseline)."""
        for handle in self.active_hooks:
            handle.remove()
        self.active_hooks = []

    def register_ablation(self, model, layer_name, ablation_type="zero_out", noise_level=0.1):
        """
        Inject a forward hook into ``layer_name`` that rewrites its output.

        Args:
            model: The ``nn.Module`` that contains the target layer.
            layer_name: Dotted module name as reported by ``model.named_modules()``.
            ablation_type: One of ``SUPPORTED_ABLATIONS``:
                * ``"zero_out"``: replace the output with zeros (structural ablation).
                * ``"add_noise"``: add Gaussian noise scaled by ``noise_level``.
                * ``"freeze_mean"``: replace every entry along dim 0 with the
                  mean over dim 0. With a single sample this is a no-op.
            noise_level: Multiplier (std-dev) for the noise used by ``"add_noise"``.

        Returns:
            A human-readable summary, e.g. ``"Ablated encoder.mlp (zero_out)"``.

        Raises:
            ValueError: If ``ablation_type`` is not supported. No hook is registered.
            KeyError: If ``layer_name`` is not a submodule of ``model``.
            TypeError: At forward time, if the hooked module's output is neither a
                tensor nor a tuple whose first element is a tensor.
        """
        if ablation_type not in self.SUPPORTED_ABLATIONS:
            raise ValueError(
                f"Unsupported ablation_type {ablation_type!r}; "
                f"expected one of {self.SUPPORTED_ABLATIONS}"
            )

        modules = dict(model.named_modules())
        if layer_name not in modules:
            raise KeyError(f"Layer {layer_name!r} not found in model")
        target_module = modules[layer_name]

        def ablate(tensor):
            if ablation_type == "zero_out":
                # zeros_like rather than `* 0.0`: inf*0 and nan*0 are nan,
                # which would leak non-finite values instead of killing the signal.
                return torch.zeros_like(tensor)
            if ablation_type == "add_noise":
                # Robustness test: inject Gaussian noise.
                noise = torch.randn_like(tensor) * noise_level
                return tensor + noise
            # freeze_mean (information bottleneck). expand_as yields a
            # stride-0 view; materialize it so downstream in-place ops are safe.
            return torch.mean(tensor, dim=0, keepdim=True).expand_as(tensor).contiguous()

        def hook_fn(module, inputs, output):
            # A non-None return value from a forward hook replaces the module's output.
            if isinstance(output, torch.Tensor):
                return ablate(output)
            if isinstance(output, tuple) and output and isinstance(output[0], torch.Tensor):
                # Many blocks return (hidden_states, *extras); ablate the first
                # element and pass the extras through. Returns a plain tuple.
                return (ablate(output[0]),) + tuple(output[1:])
            raise TypeError(
                f"Cannot ablate output of type {type(output).__name__} from {layer_name!r}"
            )

        handle = target_module.register_forward_hook(hook_fn)
        self.active_hooks.append(handle)
        return f"Ablated {layer_name} ({ablation_type})"
class ArchitectureVisualizer:
    """
    Builds a Netron-style interactive graph of the model layers using NetworkX + Plotly.

    The graph is a simple chain: selected modules are linked in the order
    ``model.named_modules()`` yields them, between synthetic "Input" and
    "Output" nodes. It is a visual heuristic, not the model's real dataflow.
    """

    # Substrings that mark a module's dotted name as a block worth drawing.
    DEFAULT_KEYWORDS = ("layer", "block", "attn", "mlp")

    # (substring, color) pairs checked in order against the lower-cased node name;
    # nodes matching none of them (e.g. Input/Output) are drawn white.
    _NODE_COLORS = (
        ("attn", "#FF0055"),   # Attention
        ("mlp", "#00CC96"),    # MLP
        ("layer", "#AB63FA"),  # Blocks
    )

    @staticmethod
    def build_layer_graph(model, *, keywords=None, max_depth=None):
        """
        Build a chain-shaped ``nx.DiGraph`` of the model's named blocks.

        Args:
            model: An ``nn.Module``; only ``named_modules()`` and each selected
                module's ``parameters()`` are used.
            keywords: Substrings; a module is kept when its full dotted name
                contains any of them. Defaults to ``DEFAULT_KEYWORDS``.
            max_depth: If given, skip modules whose dotted name contains more
                than this many dots (``"a.b"`` has depth 1). ``None`` keeps
                every matching module, which can produce thousands of nodes
                for large models.

        Returns:
            A DiGraph with an "Input" and an "Output" node (``type`` set to
            their own name) and one node per selected module carrying
            ``type`` (class name) and ``params`` (parameter count of the
            module including its submodules).
        """
        if keywords is None:
            keywords = ArchitectureVisualizer.DEFAULT_KEYWORDS

        graph = nx.DiGraph()
        graph.add_node("Input", type="Input")
        prev_node = "Input"

        for name, module in model.named_modules():
            if not any(k in name for k in keywords):
                continue
            if max_depth is not None and name.count(".") > max_depth:
                continue
            # Heuristic: connect blocks sequentially in traversal order.
            graph.add_node(
                name,
                type=module.__class__.__name__,
                params=sum(p.numel() for p in module.parameters()),
            )
            graph.add_edge(prev_node, name)
            prev_node = name

        graph.add_node("Output", type="Output")
        graph.add_edge(prev_node, "Output")
        return graph

    @staticmethod
    def plot_interactive_graph(G):
        """
        Render ``G`` as a Plotly figure with per-node hover labels.

        Positions come from ``nx.spring_layout`` with a fixed seed, so the same
        graph yields the same layout on every call.

        Args:
            G: A NetworkX graph, typically from ``build_layer_graph``. Node
                attributes ``type`` and ``params`` are shown on hover when present.

        Returns:
            A ``go.Figure`` with transparent background, hidden axes, and nodes
            colored by name (attention / MLP / layer / other).
        """
        pos = nx.spring_layout(G, seed=42, k=0.5)

        edge_x, edge_y = [], []
        for src, dst in G.edges():
            x0, y0 = pos[src]
            x1, y1 = pos[dst]
            # The trailing None breaks the polyline so each edge is its own segment.
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='none', mode='lines'
        )

        node_x, node_y, node_text, node_color = [], [], [], []
        for node, info in G.nodes(data=True):
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)
            # Plotly hover text is HTML, so <br> is the line break.
            node_text.append(
                f"{node}<br>{info.get('type', 'Unknown')}<br>Params: {info.get('params', 'N/A')}"
            )
            lowered = str(node).lower()
            node_color.append(next(
                (color for key, color in ArchitectureVisualizer._NODE_COLORS if key in lowered),
                "#FFFFFF",
            ))

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers',
            hoverinfo='text',
            text=node_text,
            marker=dict(showscale=False, color=node_color, size=15, line_width=2)
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                showlegend=False,
                hovermode='closest',
                margin=dict(b=0, l=0, r=0, t=0),
                paper_bgcolor='rgba(0,0,0,0)',
                plot_bgcolor='rgba(0,0,0,0)',
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            ),
        )
        return fig
def render_ablation_dashboard():
# --- Custom CSS for the Dashboard Feel ---
st.markdown("""
""", unsafe_allow_html=True)
st.markdown('