Spaces:
Sleeping
Sleeping
| """ | |
| Model loading and inference for Logit Lens Explorer. | |
| Loads Qwen3-1.7B and provides inference with hidden state | |
| capture for logit lens visualization. | |
| Part of E02: Logit Lens Explorer. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Generator | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache | |
# Hugging Face model identifier loaded by load_model().
MODEL_ID = "Qwen/Qwen3-1.7B"

# Module-level singletons, populated lazily by load_model() on first call
# so the weights are only downloaded/initialized once per process.
_model = None
_tokenizer = None
_device = None
@dataclass
class LayerPrediction:
    """Top-k token predictions from a single transformer layer.

    Note: without the @dataclass decorator the keyword construction used in
    project_hidden_states() would raise TypeError, since a plain class body
    with only annotations defines no __init__.
    """
    layer_index: int  # 0 = embedding, 1-28 = transformer layers
    top_tokens: list[dict]  # [{"token": str, "probability": float}, ...]
@dataclass
class TokenData:
    """Data for a single generated token with per-layer logit lens predictions.

    Note: the @dataclass decorator is required — generate_with_logit_lens()
    constructs this with keyword arguments, which a bare annotated class
    would reject with TypeError.
    """
    token: str  # decoded token text
    token_id: int  # vocabulary ID of the token
    probability: float  # softmax probability of the greedily chosen token
    layer_predictions: list[LayerPrediction]  # len = 29 (embedding + 28 layers)
def load_model():
    """Return the (model, tokenizer) pair, loading them on first use.

    Later calls reuse the module-level singletons, so the weights are
    fetched and moved to the device exactly once per process. The device
    is chosen automatically: first CUDA GPU if available, else CPU.
    """
    global _model, _tokenizer, _device
    if _model is None:
        _device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {_device}")
        print(f"Loading model: {MODEL_ID}...")
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        # Half-precision weights + FlashAttention-2 kernels for fast inference.
        lm = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
        )
        _model = lm.to(_device).eval()
        print("Model loaded successfully")
    return _model, _tokenizer
def project_hidden_states(
    hidden_states: torch.Tensor,
    model,
    tokenizer,
    top_k: int = 20,
    final_logits: torch.Tensor | None = None,
) -> list[LayerPrediction]:
    """Batch-project hidden states through RMSNorm + lm_head.

    Stacks of per-layer hidden states are pushed through the model's final
    normalization and unembedding head in one batched operation, yielding a
    top-k token distribution per layer (the classic "logit lens").

    When ``final_logits`` is supplied, the last layer's row is overwritten
    with the model's native output logits so the visualization cannot drift
    from what greedy decoding actually saw.

    Args:
        hidden_states: Stacked hidden states, shape (n_layers, 1, hidden_dim)
            (any extra size-1 middle dims are squeezed away).
        model: The causal LM with ``.model.norm`` and ``.lm_head``.
        tokenizer: Tokenizer used to decode token IDs.
        top_k: Number of top predictions to keep per layer.
        final_logits: The model's actual output logits for the last position,
            shape (vocab_size,).

    Returns:
        One LayerPrediction per layer, in layer order.
    """
    # Collapse to (n_layers, hidden_dim); float32 keeps the RMSNorm stable.
    n_layers, hidden_dim = hidden_states.shape[0], hidden_states.shape[-1]
    flat = hidden_states.reshape(n_layers, hidden_dim).float()
    normed = model.model.norm(flat)

    # The linear head expects its own weight dtype (fp16 here).
    head_dtype = model.lm_head.weight.dtype
    logits = model.lm_head(normed.to(head_dtype))

    # Pin the final layer to the model's real output when available.
    if final_logits is not None:
        logits[-1] = final_logits

    # Softmax in float32 to avoid fp16 overflow, then one CPU transfer.
    probs = torch.softmax(logits.float(), dim=-1)
    top_probs, top_ids = torch.topk(probs, k=top_k, dim=-1)
    probs_rows = top_probs.cpu().tolist()
    ids_rows = top_ids.cpu().tolist()

    return [
        LayerPrediction(
            layer_index=layer,
            top_tokens=[
                {"token": tokenizer.decode([int(tid)]), "probability": p}
                for p, tid in zip(row_p, row_i)
            ],
        )
        for layer, (row_p, row_i) in enumerate(zip(probs_rows, ids_rows))
    ]
def generate_with_logit_lens(
    prompt: str,
    max_new_tokens: int = 512,
    top_k: int = 20,
) -> Generator[TokenData, None, None]:
    """Generate text token-by-token with per-layer logit lens predictions.

    Uses greedy decoding (argmax) for deterministic text generation, but
    records the natural softmax probabilities (temperature=1) for the logit
    lens visualization so layer predictions reflect the model's true
    confidence distribution.

    Args:
        prompt: User prompt text.
        max_new_tokens: Maximum tokens to generate.
        top_k: Number of top predictions per layer for logit lens.

    Yields:
        TokenData with token string, ID, probability, and per-layer predictions.
    """
    model, tokenizer = load_model()

    messages = [{"role": "user", "content": prompt}]
    prompt_full = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt_full, return_tensors="pt").to(_device)
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # Normalize EOS token IDs to a list for the membership test below
    # (config may hold an int, a list, or None).
    eos_token_id = model.config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    elif eos_token_id is None:
        eos_token_id = []

    generated_ids = input_ids.clone()
    past_key_values = DynamicCache()
    prompt_length = input_ids.shape[1]

    with torch.no_grad():
        for step in range(max_new_tokens):
            if step == 0:
                # Prefill: run the whole prompt through the model at once.
                step_input = generated_ids
                cache_position = torch.arange(prompt_length, device=_device)
            else:
                # Decode: feed only the newest token. Its position index is
                # len(sequence) - 1; the KV cache already covers the rest.
                # BUGFIX: the previous code passed the post-increment sequence
                # length itself, which shifted every generated token's RoPE
                # position up by one and left a gap after the prompt.
                step_input = generated_ids[:, -1:]
                cache_position = torch.tensor(
                    [generated_ids.shape[1] - 1], device=_device
                )

            outputs = model(
                input_ids=step_input,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_hidden_states=True,
                return_dict=True,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            # Greedy decoding; record the natural (temperature=1) probability.
            next_token_logits = outputs.logits[:, -1, :].float()
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token_id = torch.argmax(probs, dim=-1).item()
            next_token_prob = probs[0, next_token_id].item()

            if next_token_id in eos_token_id:
                break

            # outputs.hidden_states is a tuple of (n_layers + 1) tensors
            # (embedding output + each transformer layer), each of shape
            # (batch, seq_len, hidden_dim); keep only the last position.
            hidden_states = torch.stack([
                hs[:, -1:, :] for hs in outputs.hidden_states
            ])
            layer_predictions = project_hidden_states(
                hidden_states, model, tokenizer, top_k=top_k,
                final_logits=next_token_logits[0],
            )

            token_str = tokenizer.decode([next_token_id])
            yield TokenData(
                token=token_str,
                token_id=next_token_id,
                probability=next_token_prob,
                layer_predictions=layer_predictions,
            )

            # Append the new token and extend the attention mask for the
            # next decode step.
            next_token_tensor = torch.tensor([[next_token_id]], device=_device)
            generated_ids = torch.cat([generated_ids, next_token_tensor], dim=-1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), device=_device, dtype=attention_mask.dtype)],
                dim=-1,
            )