"""
Model loading and inference for Logit Lens Explorer.

Loads Qwen3-1.7B and provides inference with hidden state capture for
logit lens visualization.

Part of E02: Logit Lens Explorer.
"""

import importlib.util
from dataclasses import dataclass
from typing import Generator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

MODEL_ID = "Qwen/Qwen3-1.7B"

_model = None
_tokenizer = None
_device = None


@dataclass
class LayerPrediction:
    """Top-k token predictions from a single transformer layer."""

    layer_index: int  # 0 = embedding, 1-28 = transformer layers
    top_tokens: list[dict]  # [{"token": str, "probability": float}, ...]


@dataclass
class TokenData:
    """Data for a single generated token with per-layer logit lens predictions."""

    token: str
    token_id: int
    probability: float
    layer_predictions: list[LayerPrediction]  # len = 29 (embedding + 28 layers)


def _select_attention_and_dtype(device: torch.device) -> tuple[str, torch.dtype]:
    """Pick an attention backend and weight dtype usable on ``device``.

    FlashAttention-2 needs a CUDA device and the optional ``flash_attn``
    package (see the transformers attention-backend docs), so it is only
    requested when both are present. Otherwise PyTorch SDPA is used, with
    float32 weights on CPU where float16 support is limited.
    """
    if device.type == "cuda":
        if importlib.util.find_spec("flash_attn") is not None:
            return "flash_attention_2", torch.float16
        return "sdpa", torch.float16
    return "sdpa", torch.float32


def load_model():
    """Load the Qwen model and tokenizer. Uses cached singleton.

    The first call selects the device (``cuda:0`` when CUDA is available,
    otherwise CPU), loads tokenizer and model, moves the model to the device
    in eval mode, and caches everything in module globals. Later calls return
    the cached objects.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    global _model, _tokenizer, _device
    if _model is not None:
        return _model, _tokenizer

    _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {_device}")

    attn_implementation, dtype = _select_attention_and_dtype(_device)

    print(f"Loading model: {MODEL_ID}...")
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        attn_implementation=attn_implementation,
        torch_dtype=dtype,
    ).to(_device).eval()
    print("Model loaded successfully")
    return _model, _tokenizer


@torch.no_grad()
def project_hidden_states(
    hidden_states: torch.Tensor,
    model,
    tokenizer,
    top_k: int = 20,
    final_logits: torch.Tensor | None = None,
) -> list[LayerPrediction]:
    """Batch-project hidden states through RMSNorm + lm_head.

    Takes stacked hidden states from all layers and projects them through the
    model's final normalization and unembedding head in a single batched
    operation. For the final layer, uses the model's actual output logits (if
    provided) instead of re-projecting, to avoid numerical precision drift
    between the projection and the native forward pass.

    NOTE(review): in Hugging Face Llama/Qwen-style models the last entry of
    ``outputs.hidden_states`` appears to already have the final norm applied,
    so re-projecting it would normalize it twice; passing ``final_logits``
    sidesteps that. Confirm against the installed transformers version.

    Args:
        hidden_states: Stacked per-layer hidden states for one position,
            shape (n_layers, ..., hidden_dim) where every middle dimension has
            size 1 (generate_with_logit_lens passes (n_layers, 1, 1, hidden_dim)).
        model: The causal LM model with .model.norm and .lm_head.
        tokenizer: Tokenizer for decoding token IDs.
        top_k: Number of top predictions per layer.
        final_logits: The model's actual output logits for the last position,
            shape (vocab_size,). Used for the final layer to guarantee
            consistency with greedy decoding.

    Returns:
        List of LayerPrediction, one per layer, each holding the ``top_k``
        individually decoded tokens and their float32 softmax probabilities.
    """
    # Collapse the size-1 middle dims to (n_layers, hidden_dim); float32 for the norm
    n_layers = hidden_states.shape[0]
    hidden_dim = hidden_states.shape[-1]
    hs = hidden_states.reshape(n_layers, hidden_dim).float()

    # Apply final RMSNorm (float32 for numerical stability)
    normed = model.model.norm(hs)
    # lm_head runs in its weight dtype; upcast the result so the final-layer
    # replacement below is stored without being rounded to half precision.
    logits = model.lm_head(normed.to(model.lm_head.weight.dtype)).float()

    # Replace final layer logits with the model's actual output if provided
    if final_logits is not None:
        logits[-1] = final_logits.to(device=logits.device, dtype=logits.dtype)

    # Softmax in float32 to avoid overflow
    probs = torch.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, k=top_k, dim=-1)

    # Move to CPU once for all layers
    top_probs_cpu = top_probs.cpu().tolist()
    top_indices_cpu = top_indices.cpu().tolist()

    return [
        LayerPrediction(
            layer_index=layer_idx,
            top_tokens=[
                {"token": tokenizer.decode([int(idx)]), "probability": prob}
                for prob, idx in zip(layer_probs, layer_indices)
            ],
        )
        for layer_idx, (layer_probs, layer_indices) in enumerate(
            zip(top_probs_cpu, top_indices_cpu)
        )
    ]


def _eos_token_ids(model, tokenizer) -> set[int]:
    """Collect every token id that should stop generation.

    Merges the EOS ids declared by the model config, the model's generation
    config (if any) and the tokenizer; each source may give an int, a list of
    ints, or None.
    """
    generation_config = getattr(model, "generation_config", None)
    candidates = (
        model.config.eos_token_id,
        getattr(generation_config, "eos_token_id", None),
        tokenizer.eos_token_id,
    )
    ids: set[int] = set()
    for value in candidates:
        if value is None:
            continue
        if isinstance(value, int):
            ids.add(value)
        else:
            ids.update(int(v) for v in value)
    return ids


def generate_with_logit_lens(
    prompt: str,
    max_new_tokens: int = 512,
    top_k: int = 20,
) -> Generator[TokenData, None, None]:
    """Generate text token-by-token with per-layer logit lens predictions.

    Uses greedy decoding (argmax) for deterministic text generation, but
    records the natural softmax probabilities (temperature=1) for the logit
    lens visualization so layer predictions reflect the model's true
    confidence distribution.

    The prompt is wrapped as a single user turn via the tokenizer's chat
    template. Generation stops after ``max_new_tokens`` tokens, or when the
    model emits any EOS id from the model config, generation config or
    tokenizer; the EOS token itself is not yielded.

    Args:
        prompt: User prompt text.
        max_new_tokens: Maximum tokens to generate.
        top_k: Number of top predictions per layer for logit lens.

    Yields:
        TokenData with token string, ID, probability, and per-layer predictions.
    """
    model, tokenizer = load_model()

    messages = [{"role": "user", "content": prompt}]
    prompt_full = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt_full, return_tensors="pt").to(_device)

    eos_token_ids = _eos_token_ids(model, tokenizer)

    # Step 0 feeds the whole prompt; each later step feeds only the newest
    # token, with earlier keys/values served from the cache.
    step_input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    past_key_values = DynamicCache()
    # Absolute sequence positions of the tokens fed at the current step.
    cache_position = torch.arange(step_input_ids.shape[1], device=_device)

    for _ in range(max_new_tokens):
        # Grad mode is disabled only while computing; the context is exited
        # before the yield so the caller's grad mode is never altered between
        # iterations (grad mode is thread-local global state).
        with torch.no_grad():
            outputs = model(
                input_ids=step_input_ids,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_hidden_states=True,
                return_dict=True,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            # Greedy decoding with natural probability recording
            next_token_logits = outputs.logits[:, -1, :].float()
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.argmax(probs, dim=-1).item()
            next_token_prob = probs[0, next_token_id].item()

            if next_token_id in eos_token_ids:
                break

            # outputs.hidden_states holds (n_layers + 1) tensors (embedding
            # output + one per layer), each (batch, seq_len, hidden_dim).
            # Keep the last position only -> (n_layers + 1, batch, 1, hidden_dim).
            hidden_states = torch.stack(
                [hs[:, -1:, :] for hs in outputs.hidden_states]
            )
            layer_predictions = project_hidden_states(
                hidden_states,
                model,
                tokenizer,
                top_k=top_k,
                final_logits=next_token_logits[0],
            )

        yield TokenData(
            token=tokenizer.decode([next_token_id]),
            token_id=next_token_id,
            probability=next_token_prob,
            layer_predictions=layer_predictions,
        )

        # Next step feeds just the new token. It occupies the position directly
        # after the last one fed (the prompt's last index + 1 on the first
        # increment), matching how transformers' generate() advances
        # cache_position; RoPE position ids are derived from it.
        step_input_ids = torch.tensor([[next_token_id]], device=_device)
        cache_position = cache_position[-1:] + 1
        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((1, 1))], dim=-1
        )