# logit-lens / model.py
# Author: ryandt — "Fixed percentile sampling" (commit 3235b82)
"""
Model loading and inference for Logit Lens Explorer.
Loads Qwen3-1.7B and provides inference with hidden state
capture for logit lens visualization.
Part of E02: Logit Lens Explorer.
"""
from dataclasses import dataclass
from typing import Generator
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
# Hugging Face Hub identifier of the checkpoint loaded by load_model().
MODEL_ID = "Qwen/Qwen3-1.7B"

# Process-wide singletons, populated lazily by load_model() on first call.
_model = None       # AutoModelForCausalLM once loaded
_tokenizer = None   # AutoTokenizer once loaded
_device = None      # torch.device chosen at load time (cuda:0 if available, else cpu)
@dataclass
class LayerPrediction:
    """Top-k token predictions obtained by projecting one layer's hidden state
    through the model's final norm and unembedding head (the "logit lens")."""
    layer_index: int  # 0 = embedding output, 1-28 = transformer layers
    top_tokens: list[dict]  # [{"token": str, "probability": float}, ...], length top_k
@dataclass
class TokenData:
    """Data for a single generated token with per-layer logit lens predictions."""
    token: str          # decoded token string
    token_id: int       # vocabulary id of the token
    probability: float  # natural softmax probability of this token at the final layer
    layer_predictions: list[LayerPrediction]  # len = 29 (embedding + 28 layers)
def load_model():
    """Load the Qwen model and tokenizer. Uses a cached module-level singleton.

    Returns:
        Tuple of (model, tokenizer). Repeat calls return the cached pair
        without reloading.
    """
    global _model, _tokenizer, _device
    if _model is not None:
        return _model, _tokenizer

    _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {_device}")
    print(f"Loading model: {MODEL_ID}...")

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # flash_attention_2 requires a CUDA device and the flash-attn package;
    # requesting it unconditionally makes loading fail on CPU-only hosts.
    # Use it on CUDA (falling back if flash-attn is absent) and SDPA on CPU.
    # float16 is likewise a GPU choice; CPU kernels are safer in float32.
    if _device.type == "cuda":
        attn_impl, dtype = "flash_attention_2", torch.float16
    else:
        attn_impl, dtype = "sdpa", torch.float32
    try:
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            attn_implementation=attn_impl,
            torch_dtype=dtype,
        )
    except (ImportError, ValueError):
        # flash-attn not installed or unsupported for this build — retry
        # with the library's default attention implementation.
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=dtype,
        )
    _model = _model.to(_device).eval()
    print("Model loaded successfully")
    return _model, _tokenizer
def project_hidden_states(
    hidden_states: torch.Tensor,
    model,
    tokenizer,
    top_k: int = 20,
    final_logits: torch.Tensor | None = None,
) -> list[LayerPrediction]:
    """Batch-project hidden states through RMSNorm + lm_head.

    Takes stacked hidden states from all layers and projects them through
    the model's final normalization and unembedding head in a single
    batched operation.

    For the final layer, uses the model's actual output logits (if provided)
    instead of re-projecting, to avoid numerical precision drift between
    the projection and the native forward pass.

    Args:
        hidden_states: Stacked hidden states, shape (n_layers, 1, hidden_dim)
            (any size-1 middle dims are flattened away).
        model: The causal LM model with .model.norm and .lm_head.
        tokenizer: Tokenizer for decoding token IDs.
        top_k: Number of top predictions per layer.
        final_logits: The model's actual output logits for the last position,
            shape (vocab_size,). Used for the final layer to guarantee
            consistency with greedy decoding.

    Returns:
        List of LayerPrediction, one per layer.
    """
    # Reshape to (n_layers, hidden_dim), upcast to float32 for the norm.
    n_layers = hidden_states.shape[0]
    hidden_dim = hidden_states.shape[-1]
    hs = hidden_states.reshape(n_layers, hidden_dim).float()

    # Apply final RMSNorm (float32 for numerical stability).
    normed = model.model.norm(hs)

    # Cast back to model weight dtype for the lm_head linear projection,
    # then upcast the logits to float32 immediately: the final-layer
    # replacement below must not be rounded through fp16, or it could flip
    # near-ties and break exact consistency with greedy decoding.
    logits = model.lm_head(normed.to(model.lm_head.weight.dtype)).float()

    # Replace final-layer logits with the model's actual output if provided.
    if final_logits is not None:
        logits[-1] = final_logits.float()

    # Softmax in float32 to avoid overflow.
    probs = torch.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)

    # Move to CPU once for all layers.
    top_probs_cpu = top_probs.cpu().tolist()
    top_indices_cpu = top_indices.cpu().tolist()

    predictions = []
    for layer_idx, (layer_probs, layer_ids) in enumerate(
        zip(top_probs_cpu, top_indices_cpu)
    ):
        top_tokens = [
            {"token": tokenizer.decode([int(idx)]), "probability": prob}
            for prob, idx in zip(layer_probs, layer_ids)
        ]
        predictions.append(LayerPrediction(
            layer_index=layer_idx,
            top_tokens=top_tokens,
        ))
    return predictions
def generate_with_logit_lens(
    prompt: str,
    max_new_tokens: int = 512,
    top_k: int = 20,
) -> Generator[TokenData, None, None]:
    """Generate text token-by-token with per-layer logit lens predictions.

    Uses greedy decoding (argmax) for deterministic text generation, but
    records the natural softmax probabilities (temperature=1) for the logit
    lens visualization so layer predictions reflect the model's true
    confidence distribution. Generation runs with a manually managed KV
    cache: the prompt is prefilled once, then one token is decoded per step.

    Args:
        prompt: User prompt text.
        max_new_tokens: Maximum tokens to generate.
        top_k: Number of top predictions per layer for logit lens.

    Yields:
        TokenData with token string, ID, probability, and per-layer
        predictions. Stops (without yielding) when an EOS token is produced.
    """
    model, tokenizer = load_model()

    messages = [{"role": "user", "content": prompt}]
    prompt_full = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt_full, return_tensors="pt").to(_device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Normalize EOS token id(s) to a list for the membership test below.
    eos_token_id = model.config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    elif eos_token_id is None:
        eos_token_id = []

    generated_ids = input_ids.clone()
    past_key_values = DynamicCache()

    with torch.no_grad():
        for step in range(max_new_tokens):
            if step == 0:
                # Prefill: run the whole prompt at positions 0..len-1.
                step_inputs = generated_ids
                cache_position = torch.arange(
                    generated_ids.shape[1], device=_device
                )
            else:
                # Decode: feed only the newest token. Its position index is
                # generated_ids.shape[1] - 1 — the cache already holds every
                # earlier position. (Using the new sequence *length* here,
                # as the original did via a post-incremented counter, is off
                # by one and skips a rotary position for every new token.)
                step_inputs = generated_ids[:, -1:]
                cache_position = torch.tensor(
                    [generated_ids.shape[1] - 1], device=_device
                )
            outputs = model(
                input_ids=step_inputs,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_hidden_states=True,
                return_dict=True,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            # Greedy decoding; record the natural (temperature=1) probability.
            next_token_logits = outputs.logits[:, -1, :].float()
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.argmax(probs, dim=-1).item()
            next_token_prob = probs[0, next_token_id].item()

            if next_token_id in eos_token_id:
                break

            # Logit lens: stack the last-position hidden state of each layer.
            # outputs.hidden_states is a tuple of (n_layers + 1) tensors
            # (embedding output + one per transformer layer), each of shape
            # (batch, seq_len, hidden_dim); the stack is therefore
            # (n_layers + 1, batch, 1, hidden_dim) and is flattened inside
            # project_hidden_states.
            hidden_states = torch.stack([
                hs[:, -1:, :] for hs in outputs.hidden_states
            ])
            layer_predictions = project_hidden_states(
                hidden_states, model, tokenizer, top_k=top_k,
                final_logits=next_token_logits[0],
            )

            yield TokenData(
                token=tokenizer.decode([next_token_id]),
                token_id=next_token_id,
                probability=next_token_prob,
                layer_predictions=layer_predictions,
            )

            # Extend the running sequence and attention mask for the next step.
            next_token_tensor = torch.tensor([[next_token_id]], device=_device)
            generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
            attention_mask = torch.cat(
                [
                    attention_mask,
                    torch.ones((1, 1), device=_device, dtype=attention_mask.dtype),
                ],
                dim=-1,
            )