# NeuroScope / extraction.py
# Alogotron's picture
# Upload extraction.py with huggingface_hub
# 50aba66 verified
"""
NeuroScope — Activation Extraction Pipeline
Loads Qwen3-4B and extracts hidden states + attention patterns for visualization.
Includes a demo mode that generates realistic synthetic data for GPU-free UI testing.
Architecture reference (Qwen3-4B):
- 36 hidden layers, 32 attention heads (GQA with 8 KV heads)
- 2560 hidden dim, 80 head dim
- RoPE positional encoding, SwiGLU MLP
Usage:
from extraction import ActivationExtractor, ExtractionResult
result = ActivationExtractor.generate_demo_data("Hello world")
# or: extractor = ActivationExtractor(); extractor.load_model(); result = extractor.extract("Hello")
"""
import time
import numpy as np
from dataclasses import dataclass
from typing import Optional
# ---------------------------------------------------------------------------
# Qwen3-4B architecture defaults (overridden at runtime when model loads)
# ---------------------------------------------------------------------------
# These values size the synthetic demo data and seed ActivationExtractor's
# attributes; load_model() replaces the per-instance copies with whatever the
# loaded model's config reports.
# Number of transformer layers (hidden-state stacks carry one extra entry
# for the embedding output).
DEFAULT_NUM_LAYERS = 36
# Number of query attention heads.
DEFAULT_NUM_HEADS = 32
# Number of key/value heads (grouped-query attention). Not referenced by the
# code shown in this file.
DEFAULT_NUM_KV_HEADS = 8
# Width of each hidden-state vector.
DEFAULT_HIDDEN_DIM = 2560
# Derived as hidden_dim // num_heads. Not referenced by the code shown in this
# file.
# NOTE(review): Qwen3 configs declare head_dim explicitly (believed to be 128
# for Qwen3-4B), so this derived value may not match the real model — confirm
# against the model's config.json before relying on it.
DEFAULT_HEAD_DIM = DEFAULT_HIDDEN_DIM // DEFAULT_NUM_HEADS # 80
@dataclass
class ExtractionResult:
    """Container for the activations captured from one forward pass.

    Instances are produced both by real model inference and by the
    synthetic demo generators; ``is_demo`` tells the two apart.

    Attributes:
        tokens: Decoded token strings, one per sequence position.
        hidden_states: Array shaped (num_layers + 1, seq_len, hidden_dim);
            the extra leading entry is the embedding-layer output.
        attentions: Array shaped (num_layers, num_heads, seq_len, seq_len).
        num_layers: Number of transformer layers described by the arrays.
        num_heads: Number of attention heads per layer.
        hidden_dim: Width of each hidden-state vector.
        inference_time: Elapsed wall-clock time in seconds.
        is_demo: True when the data is synthetic rather than model output.
    """

    tokens: list[str]
    hidden_states: np.ndarray
    attentions: np.ndarray
    num_layers: int
    num_heads: int
    hidden_dim: int
    inference_time: float
    is_demo: bool = False
class ActivationExtractor:
    """Manages Qwen3-4B loading, inference, and activation capture.

    Real inference requires ``load_model()`` to be called first. The
    ``generate_demo_*`` static methods need neither a model nor a GPU and
    produce synthetic data with the default architecture dimensions.
    """

    # Tokenizer marker characters mapped to their display form: "Ġ" and "▁"
    # become a space; "Ċ" and "ĉ" become the two-character escapes \n and \t
    # so they stay visible in the UI.
    _TOKEN_DISPLAY_MAP = str.maketrans(
        {"Ġ": " ", "▁": " ", "Ċ": "\\n", "ĉ": "\\t"}
    )

    def __init__(self):
        """Create an extractor with no model loaded and default dimensions."""
        self.model = None
        self.tokenizer = None
        self.device = None
        self.num_layers = DEFAULT_NUM_LAYERS
        self.num_heads = DEFAULT_NUM_HEADS
        self.hidden_dim = DEFAULT_HIDDEN_DIM
        self.model_loaded = False

    def load_model(
        self,
        model_name: str = "Qwen/Qwen3-4B",
        quantize: bool = False,
    ) -> str:
        """Load model with optional 4-bit quantization for VRAM efficiency.

        Args:
            model_name: HuggingFace model identifier.
            quantize: If True, use bitsandbytes 4-bit NF4 quantization (~3 GB VRAM).

        Returns:
            Status string with detected architecture info.
        """
        # Imported lazily so that demo mode works without torch/transformers.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        load_kwargs: dict = {
            "dtype": torch.bfloat16,
            "device_map": "auto",
            "trust_remote_code": True,
            # Eager attention is requested so attention weights can be
            # returned by the forward pass.
            "attn_implementation": "eager",
        }
        if quantize:
            from transformers import BitsAndBytesConfig

            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
            )
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        self.model.eval()

        # Auto-detect architecture from model config
        cfg = self.model.config
        self.num_layers = cfg.num_hidden_layers
        self.num_heads = cfg.num_attention_heads
        self.hidden_dim = cfg.hidden_size
        self.device = next(self.model.parameters()).device
        self.model_loaded = True
        return (
            f"✅ Loaded {model_name}: {self.num_layers} layers, "
            f"{self.num_heads} heads, {self.hidden_dim} hidden dim, "
            f"device={self.device}"
        )

    def _collect_activations(self, token_ids, outputs):
        """Convert one forward pass into display tokens and NumPy arrays.

        Args:
            token_ids: Token ids of the (single) sequence that was run.
            outputs: Model output exposing ``hidden_states`` and ``attentions``.

        Returns:
            Tuple ``(tokens, hidden_states, attentions)`` where the arrays
            stack batch element 0 of every layer's output as float32.
        """
        tokens = [
            self._clean_token(tok)
            for tok in self.tokenizer.convert_ids_to_tokens(token_ids)
        ]
        # Stack hidden states → (num_layers+1, seq_len, hidden_dim)
        hidden_states = np.stack(
            [hs[0].float().cpu().numpy() for hs in outputs.hidden_states]
        )
        # Stack attentions → (num_layers, num_heads, seq_len, seq_len)
        attentions = np.stack(
            [attn[0].float().cpu().numpy() for attn in outputs.attentions]
        )
        return tokens, hidden_states, attentions

    def extract(self, prompt: str) -> ExtractionResult:
        """Run forward pass and extract all hidden states + attention weights.

        Uses HuggingFace native output_attentions / output_hidden_states for
        simplicity and broad model compatibility.

        Args:
            prompt: Text to tokenize and run through the model.

        Returns:
            ExtractionResult for the prompt; ``inference_time`` covers
            tokenization plus the forward pass.

        Raises:
            RuntimeError: If no model has been loaded.
        """
        # Guard before importing torch so the caller gets the intended error
        # even in an environment without torch installed.
        if not self.model_loaded:
            raise RuntimeError(
                "Model not loaded. Call load_model() first or use generate_demo_data()."
            )
        import torch

        t0 = time.time()
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(
                **inputs,
                output_attentions=True,
                output_hidden_states=True,
            )
        inference_time = time.time() - t0

        tokens, hidden_states, attentions = self._collect_activations(
            inputs.input_ids[0].tolist(), outputs
        )
        return ExtractionResult(
            tokens=tokens,
            hidden_states=hidden_states,
            attentions=attentions,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            hidden_dim=self.hidden_dim,
            inference_time=inference_time,
            is_demo=False,
        )

    def generate_streaming(
        self,
        prompt: str,
        max_new_tokens: int = 32,
    ):
        """Greedy-decode token by token, yielding an ExtractionResult per step.

        This is a Python generator. Every step re-runs the model on the whole
        current sequence and yields the activations of that sequence, i.e. the
        sequence *before* the token predicted in that step is appended.
        Consequently the first yield covers the prompt alone, and the token
        predicted in the final step never appears in a yielded result.
        Generation stops early, without yielding, when the predicted token
        equals the tokenizer's EOS id.

        Args:
            prompt: Input text to continue generating from.
            max_new_tokens: Maximum number of decoding steps (and yields).

        Yields:
            ExtractionResult for the growing sequence; ``inference_time`` is
            the elapsed time since generation started.

        Raises:
            RuntimeError: On first iteration, if no model has been loaded.
        """
        # Guard before importing torch so the caller gets the intended error
        # even in an environment without torch installed.
        if not self.model_loaded:
            raise RuntimeError(
                "Model not loaded. Call load_model() first."
            )
        import torch

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = inputs.input_ids
        t0 = time.time()
        for _ in range(max_new_tokens):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    output_attentions=True,
                    output_hidden_states=True,
                )
            # Greedy decode next token, shaped (1, 1) for concatenation below
            next_token_id = outputs.logits[0, -1].argmax(dim=-1).unsqueeze(0).unsqueeze(0)
            # Check for EOS
            if next_token_id.item() == self.tokenizer.eos_token_id:
                break

            # Build result for current sequence
            tokens, hidden_states, attentions = self._collect_activations(
                input_ids[0].tolist(), outputs
            )
            yield ExtractionResult(
                tokens=tokens,
                hidden_states=hidden_states,
                attentions=attentions,
                num_layers=self.num_layers,
                num_heads=self.num_heads,
                hidden_dim=self.hidden_dim,
                inference_time=time.time() - t0,
                is_demo=False,
            )
            # Extend sequence for next iteration
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    @staticmethod
    def _causal_mask(seq_len):
        """Return a (seq_len, seq_len) float32 mask with -1e9 above the diagonal."""
        return np.triu(np.full((seq_len, seq_len), -1e9, dtype=np.float32), k=1)

    @staticmethod
    def _causal_softmax(raw, mask):
        """Row-wise softmax of ``raw + mask``.

        The row maximum is subtracted before exponentiation for numerical
        stability, and 1e-8 is added to the denominator.
        """
        logits = raw + mask
        logits -= logits.max(axis=-1, keepdims=True)
        weights = np.exp(logits)
        return weights / (weights.sum(axis=-1, keepdims=True) + 1e-8)

    @staticmethod
    def _specialize_head(raw, head_type):
        """Scale ``raw`` scores in place to imitate one attention-head style.

        ``head_type`` values 0-4 select a pattern; any other value leaves
        ``raw`` untouched (uniform / mixed head).
        """
        seq_len = raw.shape[0]
        idx = np.arange(seq_len)
        if head_type == 0:
            # Local window attention (current token and up to 3 before it)
            for i in range(seq_len):
                raw[i, max(0, i - 3) : i + 1] *= 4.0
        elif head_type == 1:
            # BOS / sink attention
            raw[:, 0] *= 6.0
        elif head_type == 2:
            # Previous-token (induction-style): the sub-diagonal
            raw[idx[1:], idx[:-1]] *= 5.0
        elif head_type == 3:
            # Copy / identity (diagonal)
            raw[idx, idx] *= 5.0
        elif head_type == 4:
            # Long-range (attend to early tokens)
            raw[:, : min(3, seq_len)] *= 3.0

    @staticmethod
    def generate_demo_streaming(
        prompt: str = "The quick brown fox jumps over the lazy dog",
        max_new_tokens: int = 12,
    ):
        """Yield demo ExtractionResults simulating token-by-token generation.

        Starts from ``generate_demo_data(prompt)`` and appends tokens from a
        fixed 12-entry list, so at most 12 results are yielded regardless of
        ``max_new_tokens``. Hidden states of earlier tokens are kept; the
        attention tensor is regenerated from scratch at every step. After
        each yield the generator sleeps 0.3 s to imitate generation latency.

        Args:
            prompt: Text used to build the initial synthetic sequence.
            max_new_tokens: Upper bound on the number of simulated tokens.

        Yields:
            ExtractionResult (``is_demo=True``) for the growing sequence.
        """
        base = ActivationExtractor.generate_demo_data(prompt)
        # Fixed seed keeps the simulated continuation reproducible.
        rng = np.random.RandomState(99)
        gen_tokens = ["and", "then", "it", "ran", "across", "the",
                      "field", "into", "the", "forest", ".", "<eos>"]
        gen_tokens = gen_tokens[:max_new_tokens]

        all_tokens = list(base.tokens)
        # One (n_layers+1, hidden_dim) slab per token, so new tokens can be appended.
        all_hs = list(base.hidden_states.transpose(1, 0, 2))
        feature_idx = np.arange(64)
        t0 = time.time()
        for tok in gen_tokens:
            all_tokens.append(tok)
            seq_len = len(all_tokens)

            # Generate a new hidden state column for this token
            new_hs = np.zeros((base.num_layers + 1, base.hidden_dim), dtype=np.float32)
            for layer in range(base.num_layers + 1):
                base_mag = 5.0 + layer * 0.8
                noise = rng.randn(base.hidden_dim).astype(np.float32) * (1.0 + layer * 0.1)
                noise[:64] += base_mag * np.sin(
                    feature_idx * (seq_len) / 12.0
                ).astype(np.float32)
                new_hs[layer] = noise
            all_hs.append(new_hs)

            # Stack hidden states for current sequence
            hs_array = np.stack(all_hs, axis=1)  # (n_layers+1, seq_len, hidden_dim)

            # Rebuild attention matrices at new seq_len; the mask only
            # depends on seq_len, so it is built once per step.
            mask = ActivationExtractor._causal_mask(seq_len)
            attn_array = np.zeros(
                (base.num_layers, base.num_heads, seq_len, seq_len),
                dtype=np.float32,
            )
            for layer in range(base.num_layers):
                for head in range(base.num_heads):
                    raw = np.tril(rng.exponential(0.5, (seq_len, seq_len)).astype(np.float32))
                    attn_array[layer, head] = ActivationExtractor._causal_softmax(raw, mask)

            yield ExtractionResult(
                tokens=list(all_tokens),
                hidden_states=hs_array,
                attentions=attn_array,
                num_layers=base.num_layers,
                num_heads=base.num_heads,
                hidden_dim=base.hidden_dim,
                inference_time=time.time() - t0,
                is_demo=True,
            )
            time.sleep(0.3)  # Simulate generation delay

    # -------------------------------------------------------------------
    # Demo data generation (no GPU required)
    # -------------------------------------------------------------------
    @staticmethod
    def generate_demo_data(
        prompt: str = "The quick brown fox jumps over the lazy dog",
    ) -> ExtractionResult:
        """Generate realistic synthetic data matching Qwen3-4B dimensions.

        Produces structured patterns that look plausible in all four
        visualization views:
        - Attention: causal masks with head-specific specialization
        - Magnitude: increasing L2 norms through depth
        - Token-layer grid: per-token evolution with semantic clustering
        - Scatter: separable token clusters in PCA space

        The random generator is seeded, so the same prompt always yields the
        same arrays.

        Args:
            prompt: Text split on whitespace (with "," and "." separated out)
                to form the fake token sequence; a BOS marker is prepended.

        Returns:
            ExtractionResult with ``is_demo=True``.
        """
        t0 = time.time()
        rng = np.random.RandomState(42)

        # Simulate tokenization (split on whitespace, add BOS)
        raw_tokens = prompt.replace(",", " ,").replace(".", " .").split()
        tokens = ["<|im_start|>"] + raw_tokens
        seq_len = len(tokens)
        num_layers = DEFAULT_NUM_LAYERS
        num_heads = DEFAULT_NUM_HEADS
        hidden_dim = DEFAULT_HIDDEN_DIM

        # -- Hidden states with realistic depth-dependent structure ----------
        hidden_states = np.zeros(
            (num_layers + 1, seq_len, hidden_dim), dtype=np.float32
        )
        # Position-dependent sinusoid (simulates positional features); row t
        # holds sin(arange(64) * (t + 1) / 12). It is layer-independent, so
        # it is computed once.
        position_wave = np.sin(
            np.outer(np.arange(1, seq_len + 1), np.arange(64)) / 12.0
        ).astype(np.float32)
        # Rows whose word is longer than 3 characters; BOS (row 0) never counts.
        content_rows = np.array(
            [False] + [len(word) > 3 for word in raw_tokens], dtype=bool
        )
        for layer in range(num_layers + 1):
            # Base magnitude grows through layers (empirical LLM pattern)
            base_mag = 5.0 + layer * 0.8
            noise_scale = 1.0 + layer * 0.1
            hs = rng.randn(seq_len, hidden_dim).astype(np.float32) * noise_scale
            hs[:, :64] += base_mag * position_wave
            # Layer-specific feature band activation
            band_start = (layer * 70) % hidden_dim
            band_end = min(band_start + 70, hidden_dim)
            hs[:, band_start:band_end] += base_mag * 0.5
            # Content words get stronger activations in middle layers
            if 10 <= layer <= 28:
                hs[content_rows, :256] *= 1.3
            hidden_states[layer] = hs

        # -- Attention patterns with head specialization --------------------
        attentions = np.zeros(
            (num_layers, num_heads, seq_len, seq_len), dtype=np.float32
        )
        mask = ActivationExtractor._causal_mask(seq_len)
        for layer in range(num_layers):
            for head in range(num_heads):
                raw = np.tril(
                    rng.exponential(1.0, (seq_len, seq_len)).astype(np.float32)
                )
                # Head-type specialization (observed in real LLMs); type 5 is
                # left as-is (uniform / mixed).
                ActivationExtractor._specialize_head(raw, head % 6)
                # Causal softmax
                attentions[layer, head] = ActivationExtractor._causal_softmax(raw, mask)

        inference_time = time.time() - t0
        return ExtractionResult(
            tokens=tokens,
            hidden_states=hidden_states,
            attentions=attentions,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            inference_time=inference_time,
            is_demo=True,
        )

    @staticmethod
    def _clean_token(tok: str) -> str:
        """Clean BPE artifacts from token string for display.

        Replaces "Ġ" and "▁" with a space, and "Ċ" / "ĉ" with the literal
        two-character sequences ``\\n`` / ``\\t``.
        """
        return tok.translate(ActivationExtractor._TOKEN_DISPLAY_MAP)