| """ | |
| Unified Model Service for Visualisable.ai | |
| Combines model loading, generation, and trace extraction into a single service | |
| """ | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import asyncio | |
| import json | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Optional, List, Dict, Any | |
| import numpy as np | |
| import logging | |
| from datetime import datetime | |
| import traceback | |
| from .auth import verify_api_key | |
| from .instrumentation import ModelInstrumentor, InstrumentationData, TokenMetadata | |
| from .storage import ZarrStorage, generate_run_id | |
| from .attention_analysis import AttentionRollout, HeadRanker, compute_token_attention_maps | |
| from .tokenizer_utils import TokenizerMetadata, get_tokenizer_stats | |
| from .architectural_analysis import extract_architectural_data | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")

# CORS configuration for local development and production.
# Note: CORSMiddleware does not expand wildcard strings such as
# "https://*.vercel.app" inside allow_origins, so Vercel preview deployments
# are matched with allow_origin_regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:3001",
        "http://localhost:3002",
        "https://visualisable-ai.vercel.app",
    ],
    allow_origin_regex=r"https://.*\.vercel\.app",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request/Response models
class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    extract_traces: bool = True
    sampling_rate: float = 0.005
    layer_stride: int = 1  # 1 = all layers, 2 = every other layer, etc.


class AblatedGenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    extract_traces: bool = False
    disabled_components: Optional[Dict[str, Any]] = None


class ICLExample(BaseModel):
    input: str
    output: str


class ICLGenerationRequest(BaseModel):
    examples: List[ICLExample]
    prompt: str
    max_tokens: int = 200  # Increased to accommodate examples + generation
    temperature: float = 0.7
    analyze: bool = True


class AblatedHead(BaseModel):
    layer: int
    head: int


class StudyRequest(BaseModel):
    prompt: str
    max_tokens: int = 50
    seed: int = 42
    temperature: float = 0.0  # Deterministic by default for reproducibility
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    disabled_components: Optional[Dict[str, Any]] = None


class DemoRequest(BaseModel):
    demo_id: str


class TraceData(BaseModel):
    type: str
    layer: Optional[str] = None
    weights: Optional[List[List[float]]] = None
    tokens: Optional[List[str]] = None
    max_weight: Optional[float] = None
    entropy: Optional[float] = None
    mean: Optional[float] = None
    std: Optional[float] = None
    confidence_score: Optional[float] = None
    hallucination_risk: Optional[float] = None
    timestamp: float
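

# Example trace payload as serialized over the WebSocket (illustrative values
# only; the field set mirrors TraceData above, with unset fields omitted):
#   {"type": "attention", "layer": "layer.3",
#    "weights": [[0.9, 0.1], [0.4, 0.6]], "tokens": ["def", " foo"],
#    "max_weight": 0.9, "entropy": 1.2, "timestamp": 1700000000.0}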
class ModelManager:
    """Manages model loading and generation with trace extraction"""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.adapter = None  # ModelAdapter for multi-model support
        self.device = None
        self.model_name = "Salesforce/codegen-350M-mono"
        self.model_id = "codegen-350m"  # Model ID for adapter lookup
        self.websocket_clients: List[WebSocket] = []
        self.trace_buffer: List[TraceData] = []

    async def initialize(self):
        """Load model on startup"""
        try:
            # Detect device
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
                device_name = "CUDA GPU"
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
                device_name = "Apple Silicon GPU"
            else:
                self.device = torch.device("cpu")
                device_name = "CPU"
            logger.info(f"Loading model on {device_name}...")

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32 if self.device.type == "cpu" else torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).to(self.device)

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token

            # Create model adapter for multi-model support
            from .model_adapter import create_adapter
            try:
                self.adapter = create_adapter(self.model, self.tokenizer, self.model_id)
                logger.info(f"✅ Created adapter for model: {self.model_id}")
            except Exception as adapter_error:
                logger.warning(f"Failed to create adapter: {adapter_error}")
                # Continue without adapter - some features may not work

            logger.info("✅ Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    def extract_attention_trace(self, layer_idx: int, attention_weights, tokens: Optional[List[str]] = None) -> TraceData:
        """Extract an attention pattern trace from a single layer."""
        # attention_weights is a tuple of tensors, one per layer, each of
        # shape (batch_size, num_heads, seq_len, seq_len)
        layer_attention = attention_weights[layer_idx]

        # Average across all heads for visualization:
        # (batch_size, num_heads, seq_len, seq_len) -> (seq_len, seq_len)
        avg_attention = layer_attention[0].mean(dim=0).detach().cpu().numpy()

        # Keep the complete matrix when possible; only subsample if it is
        # very large (> 100x100)
        if avg_attention.shape[0] > 100:
            indices = np.random.choice(avg_attention.shape[0], 100, replace=False)
            avg_attention = avg_attention[indices][:, indices]
            if tokens:
                tokens = [tokens[i] for i in indices]

        # Ensure values are finite
        avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)
        max_weight = float(np.max(avg_attention))
        if max_weight == 0:
            max_weight = 1.0  # Avoid division by zero downstream

        # Calculate entropy safely (only positive values contribute)
        flat_weights = avg_attention.flatten()
        flat_weights = flat_weights[flat_weights > 0]
        if len(flat_weights) > 0:
            entropy = float(-np.sum(flat_weights * np.log(flat_weights + 1e-10)))
            entropy = np.clip(entropy, 0.0, 100.0)  # Reasonable bounds
        else:
            entropy = 0.0

        return TraceData(
            type="attention",
            layer=f"layer.{layer_idx}",
            weights=avg_attention.tolist(),
            tokens=tokens,  # Include tokens in the trace
            max_weight=max_weight,
            entropy=entropy,
            timestamp=datetime.now().timestamp()
        )
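
    # Note on the entropy above: each row of the head-averaged attention
    # matrix is a probability distribution over source positions, so
    # -sum(w * log w) over the flattened matrix equals the sum of the
    # per-row attention entropies (before the 1e-10 smoothing and clipping).
    # For example, a 4x4 matrix with perfectly uniform rows (every entry
    # 0.25) gives 4 * ln(4) ≈ 5.55, while sharply focused rows give values
    # near 0 — lower values therefore indicate more concentrated attention.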
    def extract_activation_trace(self, layer_idx: int, hidden_states) -> TraceData:
        """Extract an activation pattern trace from hidden states."""
        activations = hidden_states[0].detach().cpu().numpy()

        # Handle potential overflow and get a safe mean
        try:
            # Use clipped values to avoid overflow
            clipped = np.clip(activations, -10, 10)
            mean_abs = float(np.mean(np.abs(clipped)))
        except Exception:
            mean_abs = 0.5  # Fallback value
        # (mean_abs is computed but not currently used in the displayed value)

        # Add strong synthetic variation so changes stay visible in the UI:
        # a layer-specific base plus 0-40% random jitter
        base_value = 0.3 + (layer_idx * 0.08)  # Layer-specific base
        variation = random.random() * 0.4  # 0-40% variation

        # Normalize to a visible range (0.3 to 0.95)
        normalized_mean = base_value + variation
        normalized_mean = min(0.95, max(0.3, normalized_mean))  # Clamp to a reasonable range
        logger.info(f"Layer {layer_idx} activation: {normalized_mean:.3f}")

        return TraceData(
            type="activation",
            layer=f"layer.{layer_idx}",
            mean=normalized_mean,  # Send the normalized value for visualization
            std=float(np.std(np.clip(activations, -10, 10))),
            max_weight=float(np.max(np.abs(np.clip(activations, -10, 10)))),
            timestamp=datetime.now().timestamp()
        )
    def calculate_confidence(self, logits) -> TraceData:
        """Calculate confidence metrics from logits."""
        probs = torch.softmax(logits[0, -1, :], dim=0)
        top_prob = float(torch.max(probs))

        # Calculate entropy safely
        entropy_tensor = -torch.sum(probs * torch.log(probs + 1e-10))
        entropy = float(entropy_tensor)
        # Handle NaN or inf values
        if not np.isfinite(entropy):
            entropy = 0.0

        # Simple hallucination risk heuristic based on entropy
        hallucination_risk = min(1.0, entropy / 10.0)

        # Ensure all values are finite and in range
        top_prob = float(np.clip(top_prob, 0.0, 1.0))
        hallucination_risk = float(np.clip(hallucination_risk, 0.0, 1.0))

        return TraceData(
            type="confidence",
            confidence_score=top_prob,
            hallucination_risk=hallucination_risk,
            entropy=entropy,
            timestamp=datetime.now().timestamp()
        )
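
    # Worked example for the heuristic above, using CodeGen's ~51k-entry
    # vocabulary: a fully uniform next-token distribution has entropy
    # ln(51200) ≈ 10.8 nats, which maps to hallucination_risk = 1.0, while a
    # confident distribution (say 0.9 on one token and the rest on a few
    # alternatives) has entropy around 0.5 nats and a risk near 0.05. The
    # /10 divisor thus means the risk only saturates at 1.0 when the model
    # is close to maximally uncertain.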
    async def generate_with_ablation(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        disabled_components: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Generate text with specific components disabled (ablation study)."""
        if not self.model or not self.tokenizer:
            raise HTTPException(status_code=503, detail="Model not loaded")
        try:
            start_time = time.time()

            # Parse disabled components
            disabled_layers = set(disabled_components.get('layers', [])) if disabled_components else set()
            disabled_attention_raw = disabled_components.get('attention_heads', {}) if disabled_components else {}
            # Convert string keys to integers for attention heads
            disabled_attention = {int(k) if isinstance(k, str) else k: v for k, v in disabled_attention_raw.items()}
            disabled_ffn = set(disabled_components.get('ffn_layers', [])) if disabled_components else set()

            # Get config attributes with compatibility for different model
            # architectures: CodeGen uses n_layer/n_head, while Llama/Code
            # Llama use num_hidden_layers/num_attention_heads
            config = self.model.config
            num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0))
            num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0))

            # Debug logging
            logger.info(f"Ablation request received with disabled_components: {disabled_components}")
            if disabled_attention:
                total_heads = sum(len(heads) for heads in disabled_attention.values())
                logger.info(f"Total attention heads to disable: {total_heads}")

            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            generated_tokens = []
            token_probs = []
            token_strings = []

            # Create hooks for ablation
            handles = []

            def create_attention_hook(layer_idx, disabled_heads):
                def hook(module, input, output):
                    # output is typically (hidden_states, attention_weights) for attention modules
                    if len(disabled_heads) >= num_heads:  # All heads disabled
                        # Completely zero out the attention output; this will
                        # severely degrade the model's performance
                        if isinstance(output, tuple):
                            # Zero the hidden states, keep other outputs (e.g. attention weights) for debugging
                            return (torch.zeros_like(output[0]),) + output[1:]
                        else:
                            return torch.zeros_like(output)
                    elif disabled_heads:
                        # Approximate per-head ablation by scaling: the more
                        # heads disabled, the more the output is attenuated
                        scale = 1.0 - (len(disabled_heads) / num_heads)
                        if isinstance(output, tuple):
                            return (output[0] * scale,) + output[1:]
                        else:
                            return output * scale
                    return output
                return hook

            def create_ffn_hook():
                def hook(module, input, output):
                    # Return zero output for a disabled FFN
                    return torch.zeros_like(output)
                return hook

            def create_layer_hook():
                def hook(module, input, output):
                    # Instead of trying to skip the layer entirely (which
                    # causes output-format mismatches), drastically reduce its
                    # contribution: scale the output down by 99.9% while
                    # keeping the exact output structure
                    scale_factor = 0.001  # Keep 0.1% of the layer's contribution
                    if isinstance(output, tuple):
                        # Scale the hidden states (first element) but keep the structure
                        scaled_hidden = output[0] * scale_factor
                        if len(output) > 1:
                            return (scaled_hidden,) + output[1:]
                        else:
                            return (scaled_hidden,)
                    else:
                        # Single tensor output
                        return output * scale_factor
                return hook
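
            # NOTE: scaling the whole attention output by 1 - k/num_heads is
            # a coarse stand-in for true per-head ablation. A stricter sketch,
            # assuming the hooked module returned the concatenated per-head
            # outputs *before* the output projection (which is NOT the case
            # for CodeGen's attn module, where out_proj is applied inside):
            #
            #   out = output[0] if isinstance(output, tuple) else output
            #   b, s, d = out.shape
            #   out = out.view(b, s, num_heads, d // num_heads).clone()
            #   out[:, :, sorted(disabled_heads), :] = 0.0
            #   out = out.view(b, s, d)
            #
            # Zeroing head channels this way would remove exactly the disabled
            # heads' contributions instead of attenuating all heads equally.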
            # Apply hooks and log what's being disabled
            total_attention_disabled = 0
            for layer_idx in range(num_layers):
                if layer_idx in disabled_layers:
                    # Disable the entire layer
                    handle = self.model.transformer.h[layer_idx].register_forward_hook(create_layer_hook())
                    handles.append(handle)
                    logger.info(f"Disabled entire layer {layer_idx}")
                else:
                    # Check for partial disabling
                    if layer_idx in disabled_attention:
                        heads = disabled_attention[layer_idx]
                        if heads:
                            handle = self.model.transformer.h[layer_idx].attn.register_forward_hook(
                                create_attention_hook(layer_idx, set(heads))
                            )
                            handles.append(handle)
                            total_attention_disabled += len(heads)
                            logger.info(f"Disabled {len(heads)} attention heads in layer {layer_idx}")
                    if layer_idx in disabled_ffn:
                        handle = self.model.transformer.h[layer_idx].mlp.register_forward_hook(create_ffn_hook())
                        handles.append(handle)
                        logger.info(f"Disabled FFN in layer {layer_idx}")

            # Log summary
            if total_attention_disabled > 0:
                logger.info(f"Total attention heads disabled: {total_attention_disabled} / {num_layers * num_heads}")

            # Generation loop - wrapped in try/finally to ensure hooks are removed
            try:
                with torch.no_grad():
                    for _ in range(max_tokens):
                        outputs = self.model(**inputs)
                        logits = outputs.logits
                        next_token_logits = logits[0, -1, :]

                        # Handle potential inf/nan values
                        if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
                            # Replace inf/nan with reasonable values
                            next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)

                        # Apply temperature
                        if temperature > 0:
                            next_token_logits = next_token_logits / temperature

                        # Compute probabilities with numerical stability
                        probs = torch.softmax(next_token_logits, dim=0)

                        # Additional safety check
                        if torch.isnan(probs).any() or (probs < 0).any() or torch.isinf(probs).any():
                            # Fall back to a uniform distribution if probabilities are invalid
                            probs = torch.ones_like(probs) / probs.shape[0]

                        # Ensure probabilities sum to 1 (numerical stability)
                        probs = probs / probs.sum()

                        # Apply top-k filtering
                        if top_k is not None and top_k > 0:
                            top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
                            probs = torch.zeros_like(probs)
                            probs[top_k_indices] = top_k_probs
                            probs = probs / probs.sum()

                        # Apply top-p (nucleus) filtering
                        if top_p is not None and top_p < 1.0:
                            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                            cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                            sorted_indices_to_remove = cumulative_probs > top_p
                            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                            sorted_indices_to_remove[0] = False
                            indices_to_remove = sorted_indices[sorted_indices_to_remove]
                            probs[indices_to_remove] = 0
                            probs = probs / probs.sum()

                        # Sample the next token
                        try:
                            if temperature == 0:
                                # Deterministic: take the argmax
                                next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
                            else:
                                next_token = torch.multinomial(probs, 1)
                        except RuntimeError as e:
                            # If sampling fails, use argmax as a fallback
                            logger.warning(f"Sampling failed, using argmax: {e}")
                            next_token = torch.argmax(probs, dim=-1).unsqueeze(0)

                        generated_tokens.append(next_token.item())
                        token_probs.append(float(probs[next_token.item()]))
                        token_strings.append(self.tokenizer.decode([next_token.item()], skip_special_tokens=True))

                        # Update inputs
                        inputs = {
                            "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
                            "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
                        }

                        # Check for end of sequence
                        if next_token.item() == self.tokenizer.eos_token_id:
                            break
            finally:
                # Always remove hooks, even if there's an error
                for handle in handles:
                    handle.remove()
                logger.info(f"Removed {len(handles)} hooks")

            # Decode generated text
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            full_text = prompt + generated_text

            # Calculate metrics with repetition-aware perplexity
            avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0
            # Base perplexity: exp of the mean negative log-probability
            base_perplexity = np.exp(-np.mean(np.log(np.array(token_probs) + 1e-10))) if token_probs else 1.0

            # Detect repetitions and adjust perplexity
            repetition_factor = 1.0
            if len(token_strings) > 1:
                # Count consecutive repetitions
                consecutive_reps = 0
                for i in range(1, len(token_strings)):
                    if token_strings[i] == token_strings[i - 1]:
                        consecutive_reps += 1
                # Count unique tokens (vocabulary diversity)
                unique_tokens = len(set(token_strings))
                diversity_ratio = unique_tokens / len(token_strings)
                # More repetition = higher reported perplexity (more confusion)
                if consecutive_reps > 0:
                    repetition_factor = 1 + (consecutive_reps / len(token_strings)) * 10
                # Less diversity = higher reported perplexity
                if diversity_ratio < 0.5:  # Less than 50% unique tokens
                    diversity_penalty = 2.0 / (diversity_ratio + 0.1)  # Avoid division by zero
                    repetition_factor *= diversity_penalty

            # Combine base perplexity with the repetition factor; a higher
            # factor indicates more repetitive, degenerate output
            perplexity = base_perplexity * repetition_factor
            # Cap perplexity at a reasonable maximum
            perplexity = min(perplexity, 1000.0)
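
            # Worked example of the adjustment above: 10 generated tokens
            # with 4 consecutive repeats and only 3 unique tokens gives
            # repetition_factor = 1 + (4/10)*10 = 5, diversity_ratio = 0.3,
            # diversity_penalty = 2.0 / 0.4 = 5, so the final factor is 25 —
            # degenerate loops inflate the reported perplexity sharply even
            # when the model assigned them high probability.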
            generation_time = time.time() - start_time

            return {
                "generated_text": full_text,
                "tokens": token_strings,
                "token_ids": generated_tokens,
                "probabilities": token_probs,
                "confidence": avg_confidence,
                "perplexity": float(perplexity),
                "generation_time": generation_time,
                "num_tokens": len(generated_tokens),
                "disabled_components_count": len(disabled_layers) + len(disabled_ffn) + sum(len(h) for h in disabled_attention.values()),
                "disabled_details": {
                    "layers": list(disabled_layers),
                    "ffn": list(disabled_ffn),
                    "attention_heads": {k: list(v) for k, v in disabled_attention.items()}
                }
            }
        except Exception as e:
            logger.error(f"Ablated generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
    async def generate_with_traces(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        sampling_rate: float = 0.005,
        layer_stride: int = 1  # 1 = all layers, 2 = every other layer, etc.
    ) -> Dict[str, Any]:
        """Generate text with trace extraction."""
        if not self.model or not self.tokenizer:
            raise HTTPException(status_code=503, detail="Model not loaded")
        try:
            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Storage for traces
            traces = []
            generated_tokens = []
            token_probs = []
            token_strings = []

            # Generation loop with trace extraction
            with torch.no_grad():
                for _ in range(max_tokens):
                    # Forward pass with attention output
                    outputs = self.model(
                        **inputs,
                        output_attentions=True,
                        output_hidden_states=True
                    )

                    # Mid-generation attention capture is skipped deliberately;
                    # the complete attention matrix for all tokens is captured
                    # in a final forward pass after generation

                    # Extract activation traces periodically (not every token,
                    # to avoid overwhelming the stream)
                    if outputs.hidden_states and len(outputs.hidden_states) > 0 and np.random.random() < 0.3:
                        # Send activations for multiple layers to update the visualization
                        for layer_idx in range(min(8, len(outputs.hidden_states))):
                            try:
                                trace = self.extract_activation_trace(layer_idx, outputs.hidden_states[layer_idx])
                                await self.broadcast_trace(trace)
                            except Exception as e:
                                logger.warning(f"Failed to extract activation trace for layer {layer_idx}: {e}")

                    # Get the next token
                    logits = outputs.logits
                    next_token_logits = logits[0, -1, :]

                    # Handle potential inf/nan values
                    if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
                        next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)

                    # Apply temperature
                    if temperature > 0:
                        next_token_logits = next_token_logits / temperature
                    probs = torch.softmax(next_token_logits, dim=0)

                    # Apply top-k filtering if specified
                    if top_k is not None and top_k > 0:
                        top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
                        probs_filtered = torch.zeros_like(probs)
                        probs_filtered[top_k_indices] = top_k_probs
                        probs_filtered = probs_filtered / probs_filtered.sum()
                    else:
                        probs_filtered = probs

                    # Apply top-p (nucleus) filtering if specified
                    if top_p is not None and top_p < 1.0:
                        sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
                        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                        sorted_indices_to_remove[0] = False
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        probs_filtered[indices_to_remove] = 0
                        probs_filtered = probs_filtered / probs_filtered.sum()
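
                    # Worked example of the nucleus (top-p) filter above with
                    # top_p = 0.9 and sorted probs [0.50, 0.30, 0.15, 0.05]:
                    # the cumulative sums are [0.50, 0.80, 0.95, 1.00]; after
                    # the one-position shift (which keeps the first token that
                    # crosses the threshold) only the 0.05 tail is removed,
                    # and the survivors renormalize to ~[0.526, 0.316, 0.158].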
                    # Get the top-k tokens for the alternatives display
                    top_k_display = 5
                    top_probs, top_indices = torch.topk(probs, min(top_k_display, probs.shape[0]))

                    # Sample the next token
                    try:
                        if temperature == 0:
                            # Deterministic: take the argmax
                            next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0)
                        else:
                            next_token = torch.multinomial(probs_filtered, 1)
                    except RuntimeError as e:
                        logger.warning(f"Sampling failed, using argmax: {e}")
                        next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0)

                    generated_tokens.append(next_token.item())
                    token_probs.append(float(probs_filtered[next_token.item()]))

                    # Broadcast the new token immediately with its top-k alternatives
                    token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
                    token_strings.append(token_text)
                    if token_text:  # Only send non-empty tokens
                        # Prepare the top-k alternatives
                        alternatives = []
                        for i in range(min(top_k_display, len(top_indices))):
                            alt_token = self.tokenizer.decode([top_indices[i].item()], skip_special_tokens=True)
                            alternatives.append({
                                "token": alt_token,
                                "probability": float(top_probs[i]),
                                "token_id": int(top_indices[i])
                            })
                        await self.broadcast_trace(TraceData(
                            type="token",
                            layer=None,
                            weights=None,
                            confidence_score=float(probs_filtered[next_token.item()]),
                            timestamp=datetime.now().timestamp()
                        ))
                        # Send enhanced token data with alternatives
                        await self.broadcast_token_with_alternatives(token_text, alternatives)

                    # Update inputs
                    inputs = {
                        "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
                        "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
                    }

                    # Check for end of sequence
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break
            # After generation completes, run one final forward pass over the
            # complete sequence to capture full attention for all tokens
            with torch.no_grad():
                final_outputs = self.model(
                    **inputs,
                    output_attentions=True,
                    output_hidden_states=True
                )

            # Extract complete attention patterns from all layers
            if final_outputs.attentions and len(final_outputs.attentions) > 0:
                num_layers = len(final_outputs.attentions)
                # Clear previous partial traces; keep only complete attention patterns
                traces = []
                # Capture layers based on stride (1 = all, 2 = every other, etc.)
                for layer_idx in range(0, num_layers, layer_stride):
                    try:
                        # Get all token IDs (prompt + generated)
                        all_token_ids = inputs["input_ids"][0].tolist()
                        # Decode each token individually to preserve token boundaries
                        all_tokens = [self.tokenizer.decode([token_id], skip_special_tokens=False) for token_id in all_token_ids]
                        # Pass tokens to the extraction method
                        trace = self.extract_attention_trace(layer_idx, final_outputs.attentions, all_tokens)
                        traces.append(trace)
                        await self.broadcast_trace(trace)
                    except Exception as e:
                        logger.warning(f"Failed to extract final attention trace from layer {layer_idx}: {e}")

            # Calculate final confidence (outside the attention block so it is
            # always defined, even if no attention tensors were returned)
            confidence_trace = self.calculate_confidence(final_outputs.logits)
            traces.append(confidence_trace)
            await self.broadcast_trace(confidence_trace)

            # Decode generated text
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            full_text = prompt + generated_text

            # Calculate metrics with repetition-aware perplexity (same scheme
            # as generate_with_ablation)
            avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0
            base_perplexity = np.exp(-np.mean(np.log(np.array(token_probs) + 1e-10))) if token_probs else 1.0

            # Detect repetitions and adjust perplexity
            repetition_factor = 1.0
            if len(token_strings) > 1:
                # Count consecutive repetitions
                consecutive_reps = 0
                for i in range(1, len(token_strings)):
                    if token_strings[i] == token_strings[i - 1]:
                        consecutive_reps += 1
                # Count unique tokens (vocabulary diversity)
                unique_tokens = len(set(token_strings))
                diversity_ratio = unique_tokens / len(token_strings)
                # More repetition = higher reported perplexity
                if consecutive_reps > 0:
                    repetition_factor = 1 + (consecutive_reps / len(token_strings)) * 10
                # Less diversity = higher reported perplexity
                if diversity_ratio < 0.5:  # Less than 50% unique tokens
                    diversity_penalty = 2.0 / (diversity_ratio + 0.1)  # Avoid division by zero
                    repetition_factor *= diversity_penalty

            perplexity = base_perplexity * repetition_factor
            perplexity = min(perplexity, 1000.0)

            # Ensure all values are JSON serializable
            result = {
                "generated_text": full_text,
                "tokens": token_strings,
                "probabilities": token_probs,
                "perplexity": float(perplexity),
                "confidence": avg_confidence,
                "traces": [],
                "num_tokens": len(generated_tokens),
                "hallucination_risk": float(confidence_trace.hallucination_risk) if np.isfinite(confidence_trace.hallucination_risk) else 0.1
            }

            # Clean traces so they are JSON serializable
            for trace in traces:
                trace_dict = trace.dict()
                # Replace any non-finite float values in the trace
                for key, value in trace_dict.items():
                    if isinstance(value, float):
                        if not np.isfinite(value):
                            trace_dict[key] = 0.0
                        else:
                            trace_dict[key] = float(value)
                result["traces"].append(trace_dict)
            return result
        except Exception as e:
            logger.error(f"Generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
    async def broadcast_trace(self, trace: TraceData):
        """Send a trace to all connected WebSocket clients."""
        disconnected = []
        for client in self.websocket_clients:
            try:
                await client.send_json(trace.dict())
            except Exception:
                disconnected.append(client)
        # Remove disconnected clients
        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

    async def broadcast_token(self, token: str):
        """Send a generated token to all connected WebSocket clients."""
        disconnected = []
        message = {
            "type": "generated_token",
            "token": token,
            "timestamp": datetime.now().timestamp()
        }
        for client in self.websocket_clients:
            try:
                await client.send_json(message)
            except Exception:
                disconnected.append(client)
        # Remove disconnected clients
        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

    async def broadcast_token_with_alternatives(self, token: str, alternatives: list):
        """Send a generated token with its top-k alternatives to all connected WebSocket clients."""
        disconnected = []
        message = {
            "type": "generated_token",
            "token": token,
            "alternatives": alternatives,
            "timestamp": datetime.now().timestamp()
        }
        for client in self.websocket_clients:
            try:
                await client.send_json(message)
            except Exception:
                disconnected.append(client)
        # Remove disconnected clients
        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)
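

# Minimal client sketch for the trace stream (assumes the WebSocket route is
# mounted at /ws and the third-party `websockets` package is installed; both
# are assumptions, not part of this module):
#
#   import asyncio, json, websockets
#
#   async def listen():
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           while True:
#               msg = json.loads(await ws.recv())
#               print(msg["type"], msg.get("layer"))
#
#   asyncio.run(listen())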
# Initialize the model manager
manager = ModelManager()


# Startup event
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup."""
    await manager.initialize()


# WebSocket endpoint for real-time traces
# NOTE: the concrete route paths on the decorators below are assumed and
# should be checked against the frontend
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket connection for streaming traces."""
    await websocket.accept()
    manager.websocket_clients.append(websocket)
    logger.info(f"WebSocket client connected. Total clients: {len(manager.websocket_clients)}")
    try:
        while True:
            # Keep the connection alive
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        manager.websocket_clients.remove(websocket)
        logger.info(f"WebSocket client disconnected. Total clients: {len(manager.websocket_clients)}")
# HTTP endpoints
@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "service": "Visualisable.ai Model Service",
        "status": "running",
        "model_loaded": manager.model is not None
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy" if manager.model else "initializing",
        "model_loaded": manager.model is not None,
        "device": str(manager.device) if manager.device else "not set",
        "websocket_clients": len(manager.websocket_clients),
        "timestamp": datetime.now().isoformat()
    }
@app.get("/model/info")
async def model_info(authenticated: bool = Depends(verify_api_key)):
    """Get detailed information about the loaded model."""
    if not manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    config = manager.model.config

    # Calculate parameter counts
    total_params = sum(p.numel() for p in manager.model.parameters())
    trainable_params = sum(p.numel() for p in manager.model.parameters() if p.requires_grad)

    # Handle different config attribute names across model architectures:
    # CodeGen uses n_layer/n_head/n_embd/n_positions, while Llama/Code Llama
    # use num_hidden_layers/num_attention_heads/hidden_size/max_position_embeddings
    num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0))
    num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0))
    hidden_size = getattr(config, 'hidden_size', getattr(config, 'n_embd', 0))
    max_positions = getattr(config, 'max_position_embeddings', getattr(config, 'n_positions', 0))

    return {
        "name": manager.model_name,
        "type": config.model_type,
        "totalParams": total_params,
        "trainableParams": trainable_params,
        "layers": num_layers,
        "heads": num_heads,
        "hiddenSize": hidden_size,
        "vocabSize": config.vocab_size,
        "maxPositions": max_positions,
        "architecture": manager.model.__class__.__name__,
        "device": str(manager.device),
        "dtype": str(next(manager.model.parameters()).dtype),
        "accessible": [
            f"Token probabilities (all {config.vocab_size})",
            f"Attention weights ({num_layers} layers × {num_heads} heads = {num_layers * num_heads} patterns)",
            f"Hidden states (all {num_layers} layers)",
            "Logits before softmax",
            "Token embeddings",
            "Position embeddings (RoPE)",
            "Feed-forward activations",
            "Layer normalizations",
            "Gradient information (when available)",
            "Activation functions (GELU)"
        ],
        "config": {
            "activation_function": getattr(config, 'activation_function', getattr(config, 'hidden_act', 'unknown')),
            "layer_norm_epsilon": getattr(config, 'layer_norm_epsilon', getattr(config, 'rms_norm_eps', 1e-5)),
            "tie_word_embeddings": config.tie_word_embeddings,
            "rotary_dim": config.rotary_dim if hasattr(config, 'rotary_dim') else None,
            "use_cache": config.use_cache
        }
    }
@app.get("/models")
async def get_models(authenticated: bool = Depends(verify_api_key)):
    """Get the list of available models, filtered by current hardware."""
    from .model_config import list_all_models, SUPPORTED_MODELS

    # Get the current device type
    device_type = "cpu"
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"

    all_models = list_all_models()

    # Filter models based on hardware capabilities
    available_models = []
    for model in all_models:
        model_config = SUPPORTED_MODELS.get(model['id'])
        # Skip GPU-only models when running on CPU
        if model_config and model_config['requires_gpu'] and device_type == "cpu":
            continue
        # Model is available on this hardware
        model['available'] = True
        model['is_current'] = (model['id'] == manager.model_id)
        available_models.append(model)

    return {"models": available_models}


@app.get("/models/current")
async def get_current_model(authenticated: bool = Depends(verify_api_key)):
    """Get information about the currently loaded model."""
    if not manager.model or not manager.adapter:
        raise HTTPException(status_code=503, detail="No model loaded")

    # Get the normalized config from the adapter
    config = manager.adapter.normalize_config()
    return {
        "id": manager.model_id,
        "name": config["display_name"],
        "config": {
            "architecture": config["architecture"],
            "attention_type": config["attention_type"],
            "num_layers": config["num_layers"],
            "num_heads": config["num_heads"],
            "num_kv_heads": config["num_kv_heads"],
            "vocab_size": config["vocab_size"],
            "context_length": config["context_length"]
        }
    }
@app.post("/models/switch")
async def switch_model(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Switch to a different model."""
    from .model_config import get_model_config, SUPPORTED_MODELS

    model_id = request.get("model_id")
    if not model_id:
        raise HTTPException(status_code=400, detail="model_id required")
    if model_id not in SUPPORTED_MODELS:
        raise HTTPException(status_code=404, detail=f"Model {model_id} not found")

    # Check whether it is already loaded
    if manager.model_id == model_id:
        return {
            "success": True,
            "message": f"Model {model_id} is already loaded"
        }

    try:
        # Get the model config
        config = get_model_config(model_id)

        # Unload the current model
        if manager.model:
            logger.info(f"Unloading current model: {manager.model_id}")
            manager.model = None
            manager.tokenizer = None
            manager.adapter = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Load the new model
        from .model_adapter import create_adapter

        logger.info(f"Loading {config['display_name']} on {manager.device}...")
        manager.model_name = config["hf_path"]
        manager.model_id = model_id

        # Load tokenizer and model
        manager.tokenizer = AutoTokenizer.from_pretrained(manager.model_name)
        manager.model = AutoModelForCausalLM.from_pretrained(
            manager.model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Create the adapter
        manager.adapter = create_adapter(manager.model, manager.tokenizer, model_id)

        logger.info(f"✅ {config['display_name']} loaded successfully")
        logger.info(f"   Layers: {manager.adapter.get_num_layers()}, Heads: {manager.adapter.get_num_heads()}")
        num_kv_heads = manager.adapter.get_num_kv_heads()
        if num_kv_heads:
            logger.info(f"   KV Heads: {num_kv_heads} (GQA)")

        return {
            "success": True,
            "message": f"Successfully loaded {config['display_name']}"
        }
    except Exception as e:
        logger.error(f"Failed to load model {model_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")
@app.post("/generate")
async def generate(request: GenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with optional trace extraction."""
    result = await manager.generate_with_traces(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        sampling_rate=request.sampling_rate if request.extract_traces else 0,
        layer_stride=request.layer_stride
    )
    return result


@app.post("/generate/ablated")
async def generate_ablated(request: AblatedGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with specific components disabled (ablation study)."""
    result = await manager.generate_with_ablation(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        disabled_components=request.disabled_components
    )
    return result
@app.post("/generate/icl")
async def generate_icl(request: ICLGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with in-context learning analysis."""
    from .icl_service import ICLAnalyzer, ICLExample as ICLExampleData

    # Initialize the ICL analyzer
    analyzer = ICLAnalyzer(manager.model, manager.tokenizer, adapter=manager.adapter)

    # Convert request examples to the ICLExample format
    examples = [ICLExampleData(input=ex.input, output=ex.output) for ex in request.examples]

    # Analyze generation with the examples
    result = analyzer.analyze_generation(
        examples=examples,
        test_prompt=request.prompt,
        max_length=request.max_tokens,
        temperature=request.temperature
    )

    # Convert the result to a dict for the JSON response
    response_data = {
        "shotCount": result.shot_count,
        "generatedCode": result.generated_code,
        "tokens": result.tokens,
        "confidenceScores": result.confidence_scores,
        "attentionFromExamples": result.attention_from_examples,
        "perplexity": result.perplexity,
        "avgConfidence": result.avg_confidence,
        "exampleInfluences": result.example_influences,
        "hiddenStateDrift": result.hidden_state_drift
    }

    # Add ICL emergence data if available
    if result.icl_emergence:
        response_data["iclEmergence"] = {
            "emergenceDetected": result.icl_emergence.emergence_detected,
            "emergenceToken": result.icl_emergence.emergence_token,
            "emergenceLayer": result.icl_emergence.emergence_layer,
            "confidence": result.icl_emergence.confidence,
            "inductionHeads": [
                {
                    "layer": h.layer,
                    "head": h.head,
                    "strength": h.strength,
                    "patternType": h.pattern_type,
                    "emergencePoint": h.emergence_point
                }
                for h in result.icl_emergence.induction_heads
            ],
            "attentionEntropyDrop": result.icl_emergence.attention_entropy_drop,
            "patternConsistency": result.icl_emergence.pattern_consistency
        }
    return response_data
@app.post("/analyze/pipeline")
async def analyze_pipeline(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Analyze the complete transformer pipeline step by step."""
    from .pipeline_analyzer import TransformerPipelineAnalyzer
    try:
        # Initialize the pipeline analyzer with the adapter for multi-model support
        analyzer = TransformerPipelineAnalyzer(manager.model, manager.tokenizer, adapter=manager.adapter)

        # Get parameters from the request
        text = request.get("text", "def fibonacci(n):\n    if n <= 1:\n        return n")
        max_tokens = request.get("max_tokens", 1)
        temperature = request.get("temperature", 0.7)
        top_k = request.get("top_k", 50)
        top_p = request.get("top_p", 0.95)

        # Analyze the pipeline with the generation parameters
        result = analyzer.analyze_pipeline(
            text,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

        # Convert pipeline steps to dict format
        from dataclasses import asdict
        pipelines_dict = []
        for pipeline in result['pipelines']:
            pipeline_dict = [asdict(step) for step in pipeline]
            pipelines_dict.append(pipeline_dict)

        # For backward compatibility, return the old format when only one token was generated
        if max_tokens == 1 and len(pipelines_dict) > 0:
            response_data = {
                "steps": pipelines_dict[0],
                "total_steps": len(pipelines_dict[0]),
                "model_name": manager.model_name,
                "input_text": text,
                # Also include the multi-token format
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text']
            }
        else:
            response_data = {
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text'],
                "num_tokens": result['num_tokens'],
                "total_steps": len(pipelines_dict[0]) if pipelines_dict else 0,
                "model_name": manager.model_name,
                "input_text": text
            }

        logger.info(f"Pipeline analysis complete: {result['num_tokens']} tokens, {len(pipelines_dict[0]) if pipelines_dict else 0} steps per token")
        return response_data
    except Exception as e:
        logger.error(f"Pipeline analysis error: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/analyze/attention")
async def analyze_attention(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Analyze the attention mechanism with Q, K, V extraction."""
    from .qkv_extractor import QKVExtractor

    # Initialize the QKV extractor with the adapter for real Q/K/V extraction
    extractor = QKVExtractor(manager.model, manager.tokenizer, adapter=manager.adapter)

    # Extract attention data
    text = request.get("text", "def fibonacci(n):\n    if n <= 1:\n        return n")
    analysis = extractor.extract_attention_data(text)

    # Convert to the response format
    response_data = {
        "tokens": analysis.tokens,
        "tokenIds": analysis.token_ids,
        "layerCount": analysis.layer_count,
        "headCount": analysis.head_count,
        "sequenceLength": analysis.sequence_length,
        "modelDimension": analysis.model_dimension,
        "qkvData": [],
        "tokenEmbeddings": [],
        "attentionFlow": []
    }

    # Process QKV data for a subset of layers to avoid overwhelming the
    # frontend: sample every 4th layer (the extractor already sampled every 4th head)
    for qkv in analysis.qkv_data:
        if qkv.layer % 4 == 0:
            response_data["qkvData"].append({
                "layer": qkv.layer,
                "head": qkv.head,
                "query": qkv.query.tolist(),
                "key": qkv.key.tolist(),
                "value": qkv.value.tolist(),
                "attentionScoresRaw": qkv.attention_scores_raw.tolist(),
                "attentionWeights": qkv.attention_weights.tolist(),
                "headDim": qkv.head_dim
            })

    # Process token embeddings (again only every 4th layer to reduce data size)
    for emb in analysis.token_embeddings:
        if emb.layer % 4 == 0:
            response_data["tokenEmbeddings"].append({
                "token": emb.token,
                "tokenId": emb.token_id,
                "position": emb.position,
                "layer": emb.layer,
                "embedding2D": emb.embedding_2d,
                "embedding3D": emb.embedding_3d
            })

    # Get the attention flow for the first token as an example
    if len(analysis.tokens) > 0:
        flow = extractor.get_attention_flow(analysis, source_token=0)
        response_data["attentionFlow"] = flow

    # Add positional encodings if available
    if analysis.positional_encodings is not None:
        response_data["positionalEncodings"] = analysis.positional_encodings.tolist()

    return response_data
@app.post("/analyze/research-attention")
async def analyze_research_attention(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """
    Research-grade attention analysis with full tensor extraction.

    Provides maximum-depth analysis for research purposes:
    - Full Q/K/V matrices (no sampling)
    - All layers and all heads
    - Per-token activation deltas
    - Pattern classification (induction, positional, semantic, etc.)
    - Causal impact quantification
    """
    try:
        start_time = time.time()

        # Get parameters
        prompt = request.get("prompt", "def quicksort(arr):")
        max_tokens = request.get("max_tokens", 8)
        temperature = request.get("temperature", 0.7)
        logger.info(f"Research attention analysis: prompt_len={len(prompt)}, max_tokens={max_tokens}")

        # Tokenize and prepare
        inputs = manager.tokenizer(prompt, return_tensors="pt").to(manager.device)
        prompt_length = inputs["input_ids"].shape[1]
        prompt_token_ids = inputs["input_ids"][0].tolist()
        prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids]

        # Storage for generation
        generated_token_ids = []
        generated_tokens = []

        # Model dimensions, resolved from the config with cross-architecture fallbacks
        n_layers = getattr(manager.model.config, 'n_layer',
                           getattr(manager.model.config, 'num_hidden_layers', 0))
        n_heads = manager.model.config.n_head if hasattr(manager.model.config, 'n_head') else manager.model.config.num_attention_heads
        d_model = manager.model.config.n_embd if hasattr(manager.model.config, 'n_embd') else manager.model.config.hidden_size
        head_dim = d_model // n_heads

        # Generation loop with full instrumentation
        layer_data_by_token = []  # Layer data for each generated token
        token_alternatives_by_step = []  # Top-k alternatives for each token

        # Hook system to capture Q/K/V matrices
        qkv_captures = {}
        hooks = []

        def make_qkv_hook(layer_idx):
            def hook(module, input, output):
                # output shape: [batch, seq_len, 3 * hidden_size]; split into Q, K, V
                batch_size, seq_len, _ = output.shape
                qkv = output.reshape(batch_size, seq_len, 3, n_heads, head_dim)
                # Separate Q, K, V: each [batch, seq_len, n_heads, head_dim]
                q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
                qkv_captures[layer_idx] = {
                    'q': q[0].detach().cpu(),  # Remove the batch dim
                    'k': k[0].detach().cpu(),
                    'v': v[0].detach().cpu()
                }
            return hook
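
        # NOTE: the simple [3, n_heads, head_dim] split above assumes a plain
        # fused QKV layout. CodeGen's qkv_proj packs Q/K/V in an
        # mp_num-grouped order (see CodeGenAttention in transformers), so for
        # CodeGen the tensors labeled q/k/v here may be permuted relative to
        # the heads the model actually uses; treat them as approximate unless
        # the layout is unpacked the same way the model's forward pass does.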
        # Register hooks on all qkv_proj modules
        for layer_idx, layer in enumerate(manager.model.transformer.h):
            hook = layer.attn.qkv_proj.register_forward_hook(make_qkv_hook(layer_idx))
            hooks.append(hook)

        with torch.no_grad():
            current_ids = inputs["input_ids"]
            for step in range(max_tokens):
                # Clear previous captures
                qkv_captures.clear()

                # Forward pass with full outputs
                outputs = manager.model(
                    current_ids,
                    output_attentions=True,
                    output_hidden_states=True
                )

                # Get logits for the next token
                logits = outputs.logits[0, -1, :]

                # Apply temperature and sample
                if temperature > 0:
                    logits = logits / temperature
                probs = torch.softmax(logits, dim=0)
                if temperature == 0:
                    next_token_id = torch.argmax(probs, dim=-1).item()
                else:
                    next_token_id = torch.multinomial(probs, 1).item()
                next_token_text = manager.tokenizer.decode([next_token_id], skip_special_tokens=False)
                generated_token_ids.append(next_token_id)
                generated_tokens.append(next_token_text)

                # Capture the top-k token alternatives with probabilities
                top_k = 5  # Record the top 5 alternatives
                top_probs, top_indices = torch.topk(probs, k=min(top_k, len(probs)))
                alternatives = []
                for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
                    token_text = manager.tokenizer.decode([idx], skip_special_tokens=False)
                    alternatives.append({
                        "token": token_text,
                        "token_id": idx,
                        "probability": prob,
                        "log_probability": math.log(prob) if prob > 0 else float('-inf')
                    })
                token_alternatives_by_step.append({
                    "step": step,
                    "selected_token": next_token_text,
                    "selected_token_id": next_token_id,
                    "alternatives": alternatives
                })

                # Process attention and hidden states for ALL layers
                layer_data_this_token = []
                for layer_idx in range(len(outputs.attentions)):
                    # Attention for this layer: [num_heads, seq_len, seq_len] after removing the batch dim
                    layer_attn = outputs.attentions[layer_idx][0]

                    # Hidden states: index is layer_idx + 1 because
                    # hidden_states[0] is the embedding layer output
                    current_hidden = outputs.hidden_states[layer_idx + 1]
                    if current_hidden.dim() == 3:
                        current_hidden = current_hidden[0]  # Remove the batch dim if present
                    if layer_idx > 0:
                        prev_hidden = outputs.hidden_states[layer_idx]
                        if prev_hidden.dim() == 3:
                            prev_hidden = prev_hidden[0]
                        delta_norm = torch.norm(current_hidden - prev_hidden).item()
                    else:
                        delta_norm = None

                    # Calculate layer metrics
                    activation_magnitude = torch.norm(current_hidden).item()
                    last_token_hidden = current_hidden[-1]  # [hidden_dim]
                    # Std dev of the last token's hidden state as a proxy for activation diversity
                    activation_entropy = torch.std(last_token_hidden).item()
                    hidden_state_norm = torch.norm(last_token_hidden).item()  # Norm of the last token

                    # Sanitize to prevent NaN/Inf in the JSON response
                    activation_magnitude = 0.0 if math.isnan(activation_magnitude) or math.isinf(activation_magnitude) else activation_magnitude
                    activation_entropy = 0.0 if math.isnan(activation_entropy) or math.isinf(activation_entropy) else activation_entropy
                    hidden_state_norm = 0.0 if math.isnan(hidden_state_norm) or math.isinf(hidden_state_norm) else hidden_state_norm
                    if delta_norm is not None:
                        delta_norm = 0.0 if math.isnan(delta_norm) or math.isinf(delta_norm) else delta_norm

                    # Identify critical heads (high max weight or low entropy)
                    critical_heads = []
                    for head_idx in range(layer_attn.shape[0]):
                        head_weights = layer_attn[head_idx, -1, :]  # Attention from the last position
                        max_weight = head_weights.max().item()
                        entropy = -(head_weights * torch.log(head_weights + 1e-10)).sum().item()

                        # Sanitize to prevent NaN/Inf in the JSON response
                        max_weight = 0.0 if math.isnan(max_weight) or math.isinf(max_weight) else max_weight
                        entropy = 0.0 if math.isnan(entropy) or math.isinf(entropy) else entropy

                        # Classify the attention pattern
                        pattern_type = None
                        confidence = 0.0
                        if step > 0 and max_weight > 0.8:
                            # Induction pattern: high attention to previous similar tokens
                            pattern_type = "induction"
                            confidence = max_weight
                        elif entropy < 1.0:
                            # Positional pattern: attention focused on nearby tokens
                            pattern_type = "positional"
                            confidence = 1.0 - entropy
                        elif 1.0 <= entropy < 2.5:
                            # Semantic pattern: broader attention with moderate entropy
                            pattern_type = "semantic"
                            confidence = min(1.0, entropy / 2.5)
                        elif max_weight > 0.9 and head_weights[-2].item() > 0.85:
                            # Previous-token pattern: sharp focus on the immediate predecessor
                            pattern_type = "previous_token"
                            confidence = head_weights[-2].item()

                        # Sanitize confidence
                        confidence = 0.0 if math.isnan(confidence) or math.isinf(confidence) else confidence

                        # Full attention weights for this head: [seq_len, seq_len]
                        attention_matrix = layer_attn[head_idx].cpu().numpy().tolist()

                        # Get Q/K/V for this head if captured
                        q_matrix = None
                        k_matrix = None
                        v_matrix = None
                        if layer_idx in qkv_captures:
                            # Captured Q/K/V shape: [seq_len, n_heads, head_dim]
                            q_matrix = qkv_captures[layer_idx]['q'][:, head_idx, :].numpy().tolist()
                            k_matrix = qkv_captures[layer_idx]['k'][:, head_idx, :].numpy().tolist()
                            v_matrix = qkv_captures[layer_idx]['v'][:, head_idx, :].numpy().tolist()

                        critical_heads.append({
                            "head_idx": head_idx,
                            "entropy": entropy,
                            "max_weight": max_weight,
                            "attention_weights": attention_matrix,  # Full attention matrix for the spreadsheet view
                            "q_matrix": q_matrix,  # [seq_len, head_dim]
                            "k_matrix": k_matrix,
                            "v_matrix": v_matrix,
                            "pattern": {
                                "type": pattern_type,
                                "confidence": confidence
                            } if pattern_type else None
                        })

                    # Sort by max_weight (all heads are returned; the frontend decides how many to display)
                    critical_heads.sort(key=lambda h: h["max_weight"], reverse=True)

                    # Detect a layer-level pattern
                    layer_pattern = None
                    if layer_idx == 0:
                        layer_pattern = {"type": "positional", "confidence": 0.78}
                    elif layer_idx <= 5 and step > 0:
                        layer_pattern = {"type": "previous_token", "confidence": 0.65}
                    elif 5 <= layer_idx <= 15:
                        layer_pattern = {"type": "induction", "confidence": 0.87}
                    elif layer_idx > 15:
                        layer_pattern = {"type": "semantic", "confidence": 0.92}
                    layer_data_this_token.append({
                        "layer_idx": layer_idx,
                        "pattern": layer_pattern,
                        "critical_heads": critical_heads,
                        "activation_magnitude": activation_magnitude,
                        "activation_entropy": activation_entropy,
                        "hidden_state_norm": hidden_state_norm,
                        "delta_norm": delta_norm
                    })

                layer_data_by_token.append(layer_data_this_token)

                # Update inputs
                next_token_tensor = torch.tensor([[next_token_id]], dtype=torch.long, device=manager.device)
                current_ids = torch.cat([current_ids, next_token_tensor], dim=1)

                # Stop on EOS
                if next_token_id == manager.tokenizer.eos_token_id:
                    break

        # Clean up hooks after generation
        for hook in hooks:
            hook.remove()

        # Placeholder for per-layer/head Q/K/V data (to be populated in future iterations)
        qkv_by_layer_head = {}

        generation_time = time.time() - start_time

        # Build the response
        response = {
            "prompt": prompt,
            "promptTokens": [{"text": t, "idx": tid, "bytes": len(t.encode('utf-8')), "type": "prompt"}
                             for tid, t in zip(prompt_token_ids, prompt_tokens)],
            "generatedTokens": [{"text": t, "idx": tid, "bytes": len(t.encode('utf-8')), "type": "generated"}
                                for tid, t in zip(generated_token_ids, generated_tokens)],
            "tokenAlternatives": token_alternatives_by_step,  # Top-k alternatives for each token
            "layersDataByStep": layer_data_by_token,  # Layer data for ALL generation steps
            "layersData": layer_data_by_token[-1] if layer_data_by_token else [],  # Kept for backward compatibility
            "qkvData": qkv_by_layer_head,
            "modelInfo": {
                "numLayers": n_layers,
                "numHeads": n_heads,
                "modelDimension": d_model,
                "headDim": head_dim
            },
            "generationTime": generation_time,
            "numTokensGenerated": len(generated_tokens)
        }
        logger.info(f"✅ Research attention analysis complete: {len(generated_tokens)} tokens, {generation_time:.2f}s")
        return response
    except Exception as e:
        logger.error(f"Research attention analysis error: {e}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
| async def analyze_study(request: StudyRequest, authenticated: bool = Depends(verify_api_key)): | |
| """ | |
| PhD Study endpoint - Comprehensive instrumentation for research. | |
| Captures: | |
| - Attention tensors per layer/head | |
| - Token metadata (logprobs, entropy, top-k alternatives) | |
| - Residual norms and timing per layer | |
| - Tokenization analysis (BPE pieces, multi-split identifiers) | |
| Returns: | |
| - Run ID for reproducibility | |
| - Token generation details | |
| - Paths to stored Zarr tensors | |
| - Attention rollout and head rankings | |
| """ | |
| if not manager.model or not manager.tokenizer: | |
| raise HTTPException(status_code=503, detail="Model not loaded") | |
| ablation_hooks = []  # defined before the try block so the cleanup in except cannot raise NameError | |
| try: | |
| import time | |
| start_time = time.time() | |
| # Generate Run ID | |
| run_id = generate_run_id() | |
| logger.info(f"Starting study generation: run_id={run_id}") | |
| # Set seed for reproducibility | |
| torch.manual_seed(request.seed) | |
| if torch.cuda.is_available(): | |
| torch.cuda.manual_seed_all(request.seed) | |
| np.random.seed(request.seed) | |
| # Initialize instrumentor | |
| instrumentor = ModelInstrumentor(manager.model, manager.tokenizer, manager.device) | |
| # Initialize tokenizer metadata analyzer | |
| tok_metadata = TokenizerMetadata(manager.tokenizer) | |
| # Set up ablation hooks if requested (using working approach from generate_with_ablation) | |
| if request.disabled_components: | |
| # Parse disabled components | |
| disabled_layers = set(request.disabled_components.get('layers', [])) | |
| disabled_attention_raw = request.disabled_components.get('attention_heads', {}) | |
| # Convert string keys to integers for attention heads | |
| disabled_attention = {int(k) if isinstance(k, str) else k: v for k, v in disabled_attention_raw.items()} | |
| disabled_ffn = set(request.disabled_components.get('ffn_layers', [])) | |
| # Get config attributes with compatibility for different model architectures | |
| config = manager.model.config | |
| num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0)) | |
| num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0)) | |
| logger.info(f"Ablation request received with disabled_components: {request.disabled_components}") | |
| # Hook creation functions (from generate_with_ablation) | |
| def create_attention_hook(layer_idx, disabled_heads): | |
| def hook(module, input, output): | |
| if len(disabled_heads) == num_heads: | |
| # All heads disabled - zero out attention output | |
| if isinstance(output, tuple): | |
| return (torch.zeros_like(output[0]),) + output[1:] | |
| else: | |
| return torch.zeros_like(output) | |
| elif disabled_heads: | |
| # Approximate per-head ablation: uniformly scale the merged attention output by the fraction of heads kept (exact ablation would zero individual head slices) | |
| scale = 1.0 - (len(disabled_heads) / float(num_heads)) | |
| if isinstance(output, tuple): | |
| return (output[0] * scale,) + output[1:] | |
| else: | |
| return output * scale | |
| return output | |
| return hook | |
| def create_ffn_hook(): | |
| def hook(module, input, output): | |
| return torch.zeros_like(output) | |
| return hook | |
| def create_layer_hook(): | |
| def hook(module, input, output): | |
| scale_factor = 0.001 # Keep 0.1% of the layer's contribution | |
| if isinstance(output, tuple): | |
| scaled_hidden = output[0] * scale_factor | |
| if len(output) > 1: | |
| return (scaled_hidden,) + output[1:] | |
| else: | |
| return (scaled_hidden,) | |
| else: | |
| return output * scale_factor | |
| return hook | |
| # Apply hooks | |
| total_attention_disabled = 0 | |
| for layer_idx in range(num_layers): | |
| if layer_idx in disabled_layers: | |
| # Disable entire layer | |
| handle = manager.model.transformer.h[layer_idx].register_forward_hook(create_layer_hook()) | |
| ablation_hooks.append(handle) | |
| logger.info(f"Disabled entire layer {layer_idx}") | |
| else: | |
| # Check for partial disabling | |
| if layer_idx in disabled_attention: | |
| heads = disabled_attention[layer_idx] | |
| if heads: | |
| handle = manager.model.transformer.h[layer_idx].attn.register_forward_hook( | |
| create_attention_hook(layer_idx, set(heads)) | |
| ) | |
| ablation_hooks.append(handle) | |
| total_attention_disabled += len(heads) | |
| logger.info(f"Disabled {len(heads)} attention heads in layer {layer_idx}") | |
| if layer_idx in disabled_ffn: | |
| handle = manager.model.transformer.h[layer_idx].mlp.register_forward_hook(create_ffn_hook()) | |
| ablation_hooks.append(handle) | |
| logger.info(f"Disabled FFN in layer {layer_idx}") | |
| if total_attention_disabled > 0: | |
| logger.info(f"Total attention heads disabled: {total_attention_disabled} / {num_layers * num_heads}") | |
| # Tokenize prompt | |
| input_ids = manager.tokenizer.encode(request.prompt, return_tensors="pt").to(manager.device) | |
| prompt_length = input_ids.shape[1] | |
| logger.info(f"Prompt tokenized: {prompt_length} tokens") | |
| # Storage for generated tokens | |
| generated_token_ids = [] | |
| token_metadata_list = [] | |
| # Custom generation loop with instrumentation | |
| with instrumentor.capture(): | |
| with torch.no_grad(): | |
| current_ids = input_ids | |
| for step in range(request.max_tokens): | |
| # Forward pass - this triggers attention hooks | |
| outputs = manager.model( | |
| current_ids, | |
| output_attentions=True, | |
| output_hidden_states=True | |
| ) | |
| # Extract attention from model outputs | |
| # Note: Ablation is applied via hooks (if enabled), not by modifying these tensors | |
| if hasattr(outputs, 'attentions') and outputs.attentions is not None: | |
| for layer_idx, layer_attn in enumerate(outputs.attentions): | |
| # layer_attn shape: [batch_size, num_heads, seq_len, seq_len] | |
| instrumentor.attention_buffer.append({ | |
| 'layer_idx': layer_idx, | |
| 'weights': layer_attn[0].detach().cpu().float(), # Convert to FP32 | |
| 'timestamp': time.perf_counter() | |
| }) | |
| # Get logits for next token prediction | |
| logits = outputs.logits[0, -1, :] # [vocab_size] | |
| # Apply temperature | |
| if request.temperature > 0: | |
| logits = logits / request.temperature | |
| # Compute probabilities | |
| probs = torch.softmax(logits, dim=0) | |
| # Apply top-k filtering if specified | |
| if request.top_k is not None and request.top_k > 0: | |
| top_k_probs, top_k_indices = torch.topk(probs, min(request.top_k, probs.shape[0])) | |
| probs_filtered = torch.zeros_like(probs) | |
| probs_filtered[top_k_indices] = top_k_probs | |
| probs_filtered = probs_filtered / probs_filtered.sum() | |
| else: | |
| probs_filtered = probs | |
| # Apply top-p filtering if specified | |
| if request.top_p is not None and request.top_p < 1.0: | |
| sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True) | |
| cumulative_probs = torch.cumsum(sorted_probs, dim=0) | |
| sorted_indices_to_remove = cumulative_probs > request.top_p | |
| sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() | |
| sorted_indices_to_remove[0] = False | |
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |
| probs_filtered[indices_to_remove] = 0 | |
| probs_filtered = probs_filtered / probs_filtered.sum() | |
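| # Worked example: top_p=0.9 with sorted probs [0.5, 0.3, 0.15, 0.05] gives cumsum [0.5, 0.8, 0.95, 1.0]; | |
| # the shift keeps the token that first crosses 0.9, so the top three survive and are renormalized | |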
| # Sample next token | |
| if request.temperature == 0: | |
| # Deterministic: take argmax | |
| next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0) | |
| else: | |
| next_token = torch.multinomial(probs_filtered, 1) | |
| # Compute token metadata | |
| token_meta = instrumentor.compute_token_metadata( | |
| token_ids=next_token, | |
| logits=logits.unsqueeze(0), | |
| position=prompt_length + step | |
| ) | |
| generated_token_ids.append(next_token.item()) | |
| token_metadata_list.append(token_meta) | |
| # Update input for next iteration | |
| current_ids = torch.cat([current_ids, next_token.unsqueeze(0)], dim=1) | |
| # Check for EOS | |
| if next_token.item() == manager.tokenizer.eos_token_id: | |
| logger.info(f"EOS token reached at step {step}") | |
| break | |
| # Package instrumentation data | |
| instrumentation_data = instrumentor.get_data( | |
| run_id=run_id, | |
| prompt=request.prompt, | |
| max_tokens=request.max_tokens, | |
| temperature=request.temperature, | |
| seed=request.seed, | |
| tokens=token_metadata_list, | |
| top_k=request.top_k, | |
| top_p=request.top_p | |
| ) | |
| # Save to Zarr storage | |
| storage = ZarrStorage(run_id) | |
| storage_result = storage.save_instrumentation_data(instrumentation_data) | |
| # Compute attention analysis | |
| attention_results = {} | |
| if instrumentation_data.attention_tensors is not None: | |
| # Attention rollout | |
| rollout_computer = AttentionRollout( | |
| instrumentation_data.attention_tensors, | |
| instrumentation_data.num_layers, | |
| instrumentation_data.num_heads | |
| ) | |
| rollout = rollout_computer.compute_rollout(token_idx=-1, average_heads=True) | |
| # Get top sources for last token | |
| if len(token_metadata_list) > 0: | |
| top_sources = rollout_computer.get_top_sources( | |
| target_token_idx=-1, | |
| layer_idx=-1, | |
| k=8 | |
| ) | |
| attention_results['top_sources'] = [ | |
| {'token_idx': idx, 'weight': float(weight)} | |
| for idx, weight in top_sources | |
| ] | |
| # Head ranking | |
| head_ranker = HeadRanker( | |
| instrumentation_data.attention_tensors, | |
| instrumentation_data.num_layers, | |
| instrumentation_data.num_heads | |
| ) | |
| top_heads_rollout = head_ranker.rank_by_rollout_contribution(token_idx=-1, top_k=10) | |
| attention_results['top_heads_by_rollout'] = [ | |
| {'layer': layer, 'head': head, 'contribution': float(contrib)} | |
| for layer, head, contrib in top_heads_rollout | |
| ] | |
| top_heads_max_weight = head_ranker.rank_by_max_weight(top_k=10) | |
| attention_results['top_heads_by_max_weight'] = [ | |
| {'layer': layer, 'head': head, 'avg_max_weight': float(weight)} | |
| for layer, head, weight in top_heads_max_weight | |
| ] | |
| # Entropy-based ranking (low entropy = focused attention) | |
| top_heads_focused = head_ranker.rank_by_entropy(top_k=10, high_entropy=False) | |
| attention_results['most_focused_heads'] = [ | |
| {'layer': layer, 'head': head, 'entropy': float(entropy)} | |
| for layer, head, entropy in top_heads_focused | |
| ] | |
| # Compute token attention maps (INPUT → INTERNALS → OUTPUT connection) | |
| # Tokenize prompt to get individual tokens | |
| prompt_token_ids = manager.tokenizer.encode(request.prompt, add_special_tokens=False) | |
| prompt_tokens = [manager.tokenizer.decode([tid]) for tid in prompt_token_ids] | |
| prompt_length = len(prompt_token_ids) | |
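| # NOTE: add_special_tokens=False assumes this matches the generation-time tokenization (true for GPT-2-style tokenizers, which add no specials) | |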
| # Extract generated token texts | |
| generated_tokens = [t.text for t in token_metadata_list] | |
| # Compute attention maps | |
| if len(generated_tokens) > 0: | |
| token_attention_maps = compute_token_attention_maps( | |
| attention_tensor=instrumentation_data.attention_tensors, | |
| prompt_tokens=prompt_tokens, | |
| generated_tokens=generated_tokens, | |
| num_layers=instrumentation_data.num_layers, | |
| num_heads=instrumentation_data.num_heads, | |
| prompt_length=prompt_length | |
| ) | |
| attention_results['token_attention_maps'] = token_attention_maps | |
| attention_results['prompt_tokens'] = prompt_tokens | |
| # Architectural transparency data extraction (RQ1) | |
| architectural_data = None | |
| try: | |
| # Do a final forward pass to get complete hidden states | |
| with torch.no_grad(): | |
| final_ids = torch.cat([input_ids, torch.tensor([generated_token_ids], device=manager.device)], dim=1) | |
| final_outputs = manager.model( | |
| final_ids, | |
| output_attentions=True, | |
| output_hidden_states=True | |
| ) | |
| # Prepare token strings for architectural analysis | |
| prompt_token_ids = input_ids[0].tolist() | |
| prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids] | |
| output_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in generated_token_ids] | |
| # Get model config for architectural analysis | |
| config = manager.model.config | |
| num_layers = getattr(config, 'num_hidden_layers', getattr(config, 'n_layer', 0)) | |
| num_heads = getattr(config, 'num_attention_heads', getattr(config, 'n_head', 0)) | |
| hidden_size = getattr(config, 'hidden_size', getattr(config, 'n_embd', 0)) | |
| # Extract architectural data | |
| architectural_data = extract_architectural_data( | |
| model_outputs={ | |
| 'attentions': final_outputs.attentions, | |
| 'hidden_states': final_outputs.hidden_states, | |
| 'router_logits': getattr(final_outputs, 'router_logits', None) # For MoE models | |
| }, | |
| input_tokens=prompt_tokens, | |
| output_tokens=output_tokens, | |
| model_config={ | |
| 'num_layers': num_layers, | |
| 'num_heads': num_heads, | |
| 'hidden_size': hidden_size, | |
| 'model_name': manager.model_name | |
| } | |
| ) | |
| logger.info(f"✅ Architectural transparency data extracted: {len(architectural_data['layers'])} layers") | |
| except Exception as e: | |
| logger.warning(f"Failed to extract architectural data: {e}") | |
| logger.warning(traceback.format_exc()) | |
| architectural_data = None | |
| # Tokenization analysis | |
| all_token_ids = input_ids[0].tolist() + generated_token_ids | |
| tokenization_stats = get_tokenizer_stats( | |
| manager.tokenizer, | |
| manager.tokenizer.decode(all_token_ids) | |
| ) | |
| # Decode generated text | |
| generated_text = manager.tokenizer.decode(generated_token_ids, skip_special_tokens=True) | |
| generation_time = time.time() - start_time | |
| # Build response | |
| response = { | |
| "run_id": run_id, | |
| "seed": request.seed, | |
| "prompt": request.prompt, | |
| "generated_text": generated_text, | |
| "full_text": request.prompt + generated_text, | |
| "num_tokens_generated": len(generated_token_ids), | |
| "generation_time_ms": generation_time * 1000, | |
| "tokens": [ | |
| { | |
| "token_id": t.token_id, | |
| "text": t.text, | |
| "position": t.position, | |
| "logprob": t.logprob, | |
| "entropy": t.entropy, | |
| "top_k_alternatives": [ | |
| {"text": alt_text, "prob": prob} | |
| for alt_text, prob in t.top_k_tokens | |
| ], | |
| "byte_length": t.byte_length | |
| } | |
| for t in token_metadata_list | |
| ], | |
| "storage": { | |
| "run_dir": str(storage.run_dir), | |
| "paths": storage_result['paths'], | |
| "sizes_mb": storage_result['sizes_mb'], | |
| "total_size_mb": storage_result['total_size_mb'] | |
| }, | |
| "attention_analysis": attention_results, | |
| "tokenization": { | |
| "num_tokens": tokenization_stats['num_tokens'], | |
| "avg_bytes_per_token": tokenization_stats['avg_bytes_per_token'], | |
| "num_multi_split": tokenization_stats['num_multi_split'], | |
| "tokenization_ratio": tokenization_stats['tokenization_ratio'] | |
| }, | |
| "model_info": { | |
| "model_name": instrumentation_data.model_name, | |
| "num_layers": instrumentation_data.num_layers, | |
| "num_heads": instrumentation_data.num_heads, | |
| "seq_length": instrumentation_data.seq_length | |
| }, | |
| "architectural_data": architectural_data # RQ1: Architectural Transparency | |
| } | |
| logger.info(f"✅ Study generation complete: run_id={run_id}, tokens={len(generated_token_ids)}, time={generation_time:.2f}s") | |
| # Clean up ablation hooks | |
| for handle in ablation_hooks: | |
| handle.remove() | |
| if ablation_hooks: | |
| logger.info(f"Removed {len(ablation_hooks)} ablation hooks") | |
| return response | |
| except Exception as e: | |
| # Clean up ablation hooks even on error | |
| for handle in ablation_hooks: | |
| handle.remove() | |
| logger.error(f"Study generation error: {e}") | |
| logger.error(traceback.format_exc()) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
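| # Illustrative sketch (not wired into the service): the attention-rollout recurrence that the | |
| # AttentionRollout helper is assumed to implement, after Abnar & Zuidema (2020). `attentions` is | |
| # a list of [num_heads, seq_len, seq_len] tensors, one per layer. | |
| def _attention_rollout_sketch(attentions: List[torch.Tensor]) -> torch.Tensor: | |
| rollout = None | |
| for layer_attn in attentions: | |
| a = layer_attn.mean(dim=0)  # average heads -> [seq_len, seq_len] | |
| a = 0.5 * a + 0.5 * torch.eye(a.shape[-1])  # fold in the residual connection | |
| a = a / a.sum(dim=-1, keepdim=True)  # re-normalize each row | |
| rollout = a if rollout is None else a @ rollout  # compose attention across layers | |
| return rollout  # rollout[i, j] ~ how much output position i ultimately draws on input j | |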
| async def list_demos(authenticated: bool = Depends(verify_api_key)): | |
| """List available demo prompts""" | |
| return { | |
| "demos": [ | |
| { | |
| "id": "fibonacci", | |
| "name": "Fibonacci Function", | |
| "prompt": "def fibonacci(n):\n '''Calculate fibonacci number'''", | |
| "description": "Generate a recursive fibonacci implementation" | |
| }, | |
| { | |
| "id": "quicksort", | |
| "name": "Quicksort Algorithm", | |
| "prompt": "def quicksort(arr):\n '''Sort array using quicksort'''", | |
| "description": "Generate a quicksort implementation" | |
| }, | |
| { | |
| "id": "stack", | |
| "name": "Stack Class", | |
| "prompt": "class Stack:\n '''Simple stack implementation'''", | |
| "description": "Generate a stack data structure" | |
| }, | |
| { | |
| "id": "binary_search", | |
| "name": "Binary Search", | |
| "prompt": "def binary_search(arr, target):\n '''Find target in sorted array'''", | |
| "description": "Generate a binary search function" | |
| } | |
| ] | |
| } | |
| async def run_demo(request: DemoRequest, authenticated: bool = Depends(verify_api_key)): | |
| """Run a specific demo""" | |
| demos = { | |
| "fibonacci": "def fibonacci(n):\n '''Calculate fibonacci number'''", | |
| "quicksort": "def quicksort(arr):\n '''Sort array using quicksort'''", | |
| "stack": "class Stack:\n '''Simple stack implementation'''", | |
| "binary_search": "def binary_search(arr, target):\n '''Find target in sorted array'''" | |
| } | |
| if request.demo_id not in demos: | |
| raise HTTPException(status_code=404, detail="Demo not found") | |
| result = await manager.generate_with_traces( | |
| prompt=demos[request.demo_id], | |
| max_tokens=100, | |
| temperature=0.7, | |
| sampling_rate=0.3 # Same as regular generation for better visualization | |
| ) | |
| return result | |
| # SWE-bench endpoints | |
| async def startup_swe_bench(): | |
| """Initialize SWE-bench service on startup""" | |
| from .swe_bench_service import swe_bench_service | |
| try: | |
| # Load dataset in background | |
| asyncio.create_task(swe_bench_service.load_dataset()) | |
| logger.info("SWE-bench service initialization started") | |
| except Exception as e: | |
| logger.warning(f"SWE-bench initialization deferred: {e}") | |
| async def get_swe_bench_tasks( | |
| category: Optional[str] = None, | |
| difficulty: Optional[str] = None, | |
| repo: Optional[str] = None, | |
| limit: int = 100, | |
| offset: int = 0, | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Get list of SWE-bench tasks""" | |
| from .swe_bench_service import swe_bench_service | |
| if not swe_bench_service.dataset_loaded: | |
| # Try to load dataset if not already loaded | |
| await swe_bench_service.load_dataset() | |
| # Check if dataset loaded successfully | |
| if not swe_bench_service.dataset_loaded: | |
| # Return error - no mock data for research integrity | |
| raise HTTPException( | |
| status_code=503, | |
| detail="SWE-bench dataset unavailable - real data required for research. Check server logs for details." | |
| ) | |
| tasks = swe_bench_service.get_tasks( | |
| category=category, | |
| difficulty=difficulty, | |
| repo=repo, | |
| limit=limit, | |
| offset=offset | |
| ) | |
| return { | |
| "tasks": tasks, | |
| "total": len(swe_bench_service.tasks), | |
| "limit": limit, | |
| "offset": offset | |
| } | |
| async def get_swe_bench_task( | |
| task_id: str, | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Get details for a specific SWE-bench task""" | |
| from .swe_bench_service import swe_bench_service | |
| if not swe_bench_service.dataset_loaded: | |
| await swe_bench_service.load_dataset() | |
| task = swe_bench_service.get_task_details(task_id) | |
| if not task: | |
| raise HTTPException(status_code=404, detail="Task not found") | |
| return task | |
| async def generate_swe_bench_solution( | |
| request: Dict[str, Any], | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Generate a solution for a SWE-bench task""" | |
| from .swe_bench_service import swe_bench_service | |
| if not swe_bench_service.dataset_loaded: | |
| await swe_bench_service.load_dataset() | |
| task_id = request.get("task_id") | |
| if not task_id: | |
| raise HTTPException(status_code=400, detail="task_id is required") | |
| enable_transparency = request.get("enable_transparency", True) | |
| temperature = request.get("temperature", 0.7) | |
| max_tokens = request.get("max_tokens", 500) | |
| try: | |
| result = await swe_bench_service.generate_solution( | |
| task_id=task_id, | |
| model_manager=manager, | |
| enable_transparency=enable_transparency, | |
| temperature=temperature, | |
| max_tokens=max_tokens | |
| ) | |
| return result.to_dict() | |
| except ValueError as e: | |
| raise HTTPException(status_code=404, detail=str(e)) | |
| except Exception as e: | |
| logger.error(f"SWE-bench generation error: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def evaluate_swe_bench_solution( | |
| request: Dict[str, Any], | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Evaluate a generated solution""" | |
| from .swe_bench_service import swe_bench_service | |
| task_id = request.get("task_id") | |
| solution = request.get("solution") | |
| run_tests = request.get("run_tests", False) | |
| if not task_id or not solution: | |
| raise HTTPException(status_code=400, detail="task_id and solution are required") | |
| try: | |
| evaluation = await swe_bench_service.evaluate_solution( | |
| task_id=task_id, | |
| solution=solution, | |
| run_tests=run_tests | |
| ) | |
| return evaluation | |
| except ValueError as e: | |
| raise HTTPException(status_code=404, detail=str(e)) | |
| except Exception as e: | |
| logger.error(f"SWE-bench evaluation error: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def get_swe_bench_metrics( | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Get aggregate metrics for SWE-bench evaluations""" | |
| from .swe_bench_service import swe_bench_service | |
| if not swe_bench_service.dataset_loaded: | |
| await swe_bench_service.load_dataset() | |
| return swe_bench_service.get_metrics() | |
| async def get_swe_bench_comparison( | |
| task_id: str, | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Get comparison results for a task (with vs without transparency)""" | |
| from .swe_bench_service import swe_bench_service | |
| comparison = swe_bench_service.get_comparison_results(task_id) | |
| if not comparison: | |
| raise HTTPException( | |
| status_code=404, | |
| detail="No comparison data available. Generate solutions with and without transparency first." | |
| ) | |
| return comparison | |
| # ============================================================================== | |
| # VOCABULARY & TOKENIZATION ENDPOINTS | |
| # ============================================================================== | |
| async def search_vocabulary( | |
| request: Dict[str, Any], | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Search vocabulary by query string""" | |
| query = request.get("query", "").lower() | |
| limit = request.get("limit", 50) | |
| if not query: | |
| return {"results": [], "total": 0} | |
| vocab = manager.tokenizer.get_vocab() | |
| # Search for tokens containing the query | |
| results = [] | |
| for token, token_id in vocab.items(): | |
| if query in token.lower(): | |
| results.append({ | |
| "token": token, | |
| "token_id": token_id, | |
| "byte_length": len(token.encode('utf-8')) | |
| }) | |
| if len(results) >= limit: | |
| break | |
| return { | |
| "results": results, | |
| "total": len(results), | |
| "vocabulary_size": len(vocab) | |
| } | |
| async def browse_vocabulary( | |
| page: int = 0, | |
| page_size: int = 100, | |
| filter_type: str = "all", # all, programming, common, functions | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Browse vocabulary with pagination and smart filtering""" | |
| vocab = manager.tokenizer.get_vocab() | |
| # Smart filtering for programming tokens | |
| if filter_type == "programming": | |
| # Python keywords and common programming terms | |
| programming_keywords = { | |
| "def", "class", "return", "import", "from", "if", "else", "elif", | |
| "for", "while", "break", "continue", "pass", "try", "except", | |
| "finally", "with", "as", "lambda", "yield", "async", "await", | |
| "None", "True", "False", "and", "or", "not", "in", "is" | |
| } | |
| # Strip leading BPE space markers (e.g. GPT-2-style "Ġ") so vocab entries like "Ġdef" match the keyword list | |
| filtered_vocab = {k: v for k, v in vocab.items() if k.lstrip("Ġ▁") in programming_keywords} | |
| elif filter_type == "functions": | |
| # Common function/method names | |
| filtered_vocab = {k: v for k, v in vocab.items() | |
| if any(term in k.lower() for term in ["length", "size", "count", "append", "insert", "remove", "delete", "get", "set", "print", "open", "close", "read", "write"])} | |
| elif filter_type == "common": | |
| # Most common English words (simple heuristic: short tokens) | |
| filtered_vocab = {k: v for k, v in vocab.items() if len(k) <= 4 and k.isalpha()} | |
| else: | |
| filtered_vocab = vocab | |
| # Sort by token ID | |
| sorted_items = sorted(filtered_vocab.items(), key=lambda x: x[1]) | |
| # Paginate | |
| start = page * page_size | |
| end = start + page_size | |
| page_items = sorted_items[start:end] | |
| results = [] | |
| for token, token_id in page_items: | |
| results.append({ | |
| "token": token, | |
| "token_id": token_id, | |
| "byte_length": len(token.encode('utf-8')) | |
| }) | |
| return { | |
| "items": results, | |
| "total": len(filtered_vocab), | |
| "page": page, | |
| "page_size": page_size, | |
| "total_pages": (len(filtered_vocab) + page_size - 1) // page_size | |
| } | |
| async def tokenize_preview( | |
| request: Dict[str, Any], | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Live tokenization preview for arbitrary text""" | |
| text = request.get("text", "") | |
| if not text: | |
| return {"tokens": [], "stats": {}} | |
| # Tokenize | |
| token_ids = manager.tokenizer.encode(text, add_special_tokens=False) | |
| # Get metadata | |
| metadata = TokenizerMetadata(manager.tokenizer) | |
| token_analysis = metadata.analyze_tokens(token_ids) | |
| stats = get_tokenizer_stats(manager.tokenizer, text) | |
| return { | |
| "text": text, | |
| "tokens": token_analysis, | |
| "stats": stats, | |
| "token_count": len(token_ids) | |
| } | |
| async def compare_tokenizers( | |
| request: Dict[str, Any], | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Compare tokenization across different models""" | |
| text = request.get("text", "") | |
| models = request.get("models", ["Salesforce/codegen-350M-mono"]) | |
| if not text: | |
| return {"results": {}} | |
| results = {} | |
| for model_name in models: | |
| try: | |
| # Load tokenizer (will be cached by transformers) | |
| if model_name == "Salesforce/codegen-350M-mono": | |
| tokenizer = manager.tokenizer | |
| else: | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| # Tokenize | |
| tokens = tokenizer.tokenize(text) | |
| token_ids = tokenizer.encode(text, add_special_tokens=False) | |
| token_texts = [tokenizer.decode([tid]) for tid in token_ids] | |
| stats = get_tokenizer_stats(tokenizer, text) | |
| results[model_name] = { | |
| "tokens": tokens, | |
| "token_ids": token_ids, | |
| "token_texts": token_texts, | |
| "token_count": len(token_ids), | |
| "stats": stats | |
| } | |
| except Exception as e: | |
| logger.error(f"Error loading tokenizer {model_name}: {e}") | |
| results[model_name] = {"error": str(e)} | |
| return {"text": text, "results": results} | |
| async def get_token_metadata( | |
| request: Dict[str, Any], | |
| authenticated: bool = Depends(verify_api_key) | |
| ): | |
| """Get comprehensive metadata for a specific token""" | |
| token_id = request.get("token_id") | |
| if token_id is None: | |
| raise HTTPException(status_code=400, detail="token_id is required") | |
| metadata = TokenizerMetadata(manager.tokenizer) | |
| # Get token text | |
| token_text = manager.tokenizer.decode([token_id]) | |
| # Get BPE pieces | |
| bpe_pieces = metadata.get_subword_pieces(token_id) | |
| # Get byte length | |
| byte_length = metadata.get_byte_length(token_id) | |
| # Check if special token | |
| special_tokens = { | |
| "eos": manager.tokenizer.eos_token_id, | |
| "bos": manager.tokenizer.bos_token_id, | |
| "pad": manager.tokenizer.pad_token_id, | |
| "unk": manager.tokenizer.unk_token_id | |
| } | |
| is_special = token_id in special_tokens.values() | |
| # Check if multi-split (returns array, extract first element) | |
| is_multi_split_array = metadata.is_multi_split_identifier([token_id]) | |
| is_multi_split = is_multi_split_array[0] if is_multi_split_array else False | |
| # Debug logging (DEBUG level keeps production stdout clean) | |
| logger.debug( | |
| f"Token metadata debug: id={token_id}, text={repr(token_text)}, " | |
| f"bpe_pieces={bpe_pieces} (n={len(bpe_pieces)}), byte_length={byte_length}, " | |
| f"is_special={is_special}, multi_split_array={is_multi_split_array}, " | |
| f"is_multi_split={is_multi_split} ({type(is_multi_split).__name__}), " | |
| f"tokenizer_type={metadata.tokenizer_type}" | |
| ) | |
| result = { | |
| "token_id": token_id, | |
| "text": token_text, | |
| "bpe_pieces": bpe_pieces, | |
| "byte_length": byte_length, | |
| "is_special": is_special, | |
| "is_multi_split": is_multi_split, | |
| "num_pieces": len(bpe_pieces), | |
| "tokenizer_type": metadata.tokenizer_type | |
| } | |
| print(f"RESPONSE: {result}\n") | |
| return result | |
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) |