"""
Unified Model Service for Visualisable.ai

Combines model loading, generation, and trace extraction into a single
FastAPI service:

* ``ModelManager`` loads the causal LM, runs token-by-token generation
  (optionally with ablation hooks) and streams traces to WebSocket clients.
* HTTP endpoints expose generation, ablation, ICL analysis, pipeline and
  attention analysis, and a small catalogue of demo prompts.
"""

import asyncio
import json
import logging
import random
import time
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from .auth import verify_api_key

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")

# CORS configuration for local development and production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:3001",
        "http://localhost:3002",
        "https://visualisable-ai.vercel.app",
    ],
    # Starlette compares ``allow_origins`` entries literally, so a
    # "https://*.vercel.app" entry never matched any preview deployment.
    # Wildcard subdomains have to go through ``allow_origin_regex``; the
    # pattern is anchored so e.g. "https://x.vercel.app.evil.com" is rejected.
    # NOTE(review): this admits *every* Vercel-hosted site (with credentials);
    # narrow it to this project's preview prefix if possible.
    allow_origin_regex=r"^https://[A-Za-z0-9-]+\.vercel\.app$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------

class GenerationRequest(BaseModel):
    """Body of ``POST /generate``."""

    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7  # 0 selects greedy (argmax) decoding
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    extract_traces: bool = True
    # Per-step probability of extracting and broadcasting attention/activation traces.
    sampling_rate: float = 0.005


class AblatedGenerationRequest(BaseModel):
    """Body of ``POST /generate/ablated``."""

    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    extract_traces: bool = False  # not forwarded by the /generate/ablated endpoint
    # Recognised keys: "layers", "attention_heads" ({layer: [heads]}), "ffn_layers".
    disabled_components: Optional[Dict[str, Any]] = None


class ICLExample(BaseModel):
    """One input/output demonstration pair for in-context learning."""

    input: str
    output: str


class ICLGenerationRequest(BaseModel):
    """Body of ``POST /generate/icl``."""

    examples: List[ICLExample]
    prompt: str
    max_tokens: int = 200  # Increased to accommodate examples + generation
    temperature: float = 0.7
    analyze: bool = True


class DemoRequest(BaseModel):
    """Body of ``POST /demos/run``."""

    demo_id: str


class TraceData(BaseModel):
    """A single trace event sent to WebSocket clients and returned by /generate.

    ``type`` is one of "attention", "activation", "confidence" or "token";
    which optional fields are populated depends on the type.
    """

    type: str
    layer: Optional[str] = None
    weights: Optional[List[List[float]]] = None
    max_weight: Optional[float] = None
    entropy: Optional[float] = None
    mean: Optional[float] = None
    std: Optional[float] = None
    confidence_score: Optional[float] = None
    hallucination_risk: Optional[float] = None
    timestamp: float


# Demo prompts served by ``GET /demos`` and runnable via ``POST /demos/run``.
# Single source of truth so the two endpoints cannot drift apart.
DEMOS: List[Dict[str, str]] = [
    {
        "id": "fibonacci",
        "name": "Fibonacci Function",
        "prompt": "def fibonacci(n):\n '''Calculate fibonacci number'''",
        "description": "Generate a recursive fibonacci implementation",
    },
    {
        "id": "quicksort",
        "name": "Quicksort Algorithm",
        "prompt": "def quicksort(arr):\n '''Sort array using quicksort'''",
        "description": "Generate a quicksort implementation",
    },
    {
        "id": "stack",
        "name": "Stack Class",
        "prompt": "class Stack:\n '''Simple stack implementation'''",
        "description": "Generate a stack data structure",
    },
    {
        "id": "binary_search",
        "name": "Binary Search",
        "prompt": "def binary_search(arr, target):\n '''Find target in sorted array'''",
        "description": "Generate a binary search function",
    },
]


class ModelManager:
    """Manages model loading and generation with trace extraction.

    NOTE(review): the generation coroutines run the model synchronously, so a
    request blocks the event loop between its ``await`` points.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_name = "Salesforce/codegen-350M-mono"
        self.websocket_clients: List[WebSocket] = []
        self.trace_buffer: List[TraceData] = []

    async def initialize(self):
        """Load the model and tokenizer onto the best available device.

        Prefers CUDA, then Apple MPS, then CPU. Weights are float32 on CPU and
        float16 otherwise. The tokenizer's pad token is set to its EOS token.

        Raises:
            Exception: whatever model/tokenizer loading raised (after logging).
        """
        try:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
                device_name = "CUDA GPU"
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
                device_name = "Apple Silicon GPU"
            else:
                self.device = torch.device("cpu")
                device_name = "CPU"

            logger.info(f"Loading model on {device_name}...")

            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32 if self.device.type == "cpu" else torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).to(self.device)

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("✅ Model loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    # ------------------------------------------------------------------
    # Trace extraction
    # ------------------------------------------------------------------

    def extract_attention_trace(self, layer_idx: int, attention_weights) -> TraceData:
        """Build an ``attention`` trace for one layer.

        Args:
            layer_idx: Index into ``attention_weights``.
            attention_weights: Per-layer attention tensors (one per layer),
                each shaped (batch_size, num_heads, seq_len, seq_len); batch
                item 0 is used.

        Returns:
            TraceData with the head-averaged matrix (down-sampled to at most
            20x20 positions), its maximum and an entropy estimate.
        """
        layer_attention = attention_weights[layer_idx]

        # Average across all heads of the first batch item -> (seq_len, seq_len).
        avg_attention = layer_attention[0].mean(dim=0).detach().cpu().numpy()

        # Down-sample long sequences for payload size. Indices are sorted so
        # the sub-matrix keeps the original token order (and hence the causal
        # layout); unsorted random indices used to scramble the picture.
        max_positions = 20
        if avg_attention.shape[0] > max_positions:
            indices = np.sort(np.random.choice(avg_attention.shape[0], max_positions, replace=False))
            avg_attention = avg_attention[indices][:, indices]

        # Ensure values are finite before serialising.
        avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)

        max_weight = float(np.max(avg_attention))
        if max_weight == 0:
            # Presumably so clients normalising by max_weight never divide by zero.
            max_weight = 1.0

        # Shannon-style entropy over the strictly positive weights only.
        positive = avg_attention.flatten()
        positive = positive[positive > 0]
        if len(positive) > 0:
            entropy = float(-np.sum(positive * np.log(positive + 1e-10)))
            entropy = float(np.clip(entropy, 0.0, 100.0))  # Reasonable bounds
        else:
            entropy = 0.0

        return TraceData(
            type="attention",
            layer=f"layer.{layer_idx}",
            weights=avg_attention.tolist(),
            max_weight=max_weight,
            entropy=entropy,
            timestamp=datetime.now().timestamp(),
        )

    def extract_activation_trace(self, layer_idx: int, hidden_states) -> TraceData:
        """Build an ``activation`` trace for one layer's hidden states.

        NOTE(review): ``mean`` is synthetic. It is a layer-dependent base
        (0.3 + 0.08 * layer_idx) plus uniform random jitter in [0, 0.4),
        clamped to [0.3, 0.95], so the front-end visibly changes. It is not
        derived from the activations. ``std`` and ``max_weight`` are real
        statistics of the activations clipped to [-10, 10].
        """
        activations = hidden_states[0].detach().cpu().numpy()
        # Clip so extreme values cannot overflow or dominate the statistics.
        clipped = np.clip(activations, -10, 10)

        base_value = 0.3 + (layer_idx * 0.08)
        variation = random.random() * 0.4
        normalized_mean = min(0.95, max(0.3, base_value + variation))

        logger.debug("Layer %d activation: %.3f", layer_idx, normalized_mean)

        return TraceData(
            type="activation",
            layer=f"layer.{layer_idx}",
            mean=normalized_mean,
            std=float(np.std(clipped)),
            max_weight=float(np.max(np.abs(clipped))),
            timestamp=datetime.now().timestamp(),
        )

    def calculate_confidence(self, logits) -> TraceData:
        """Derive confidence metrics from the last position's logits.

        Args:
            logits: Tensor indexed as ``[batch, position, vocab]``; batch 0,
                last position is used.

        Returns:
            A ``confidence`` trace. ``confidence_score`` is the top softmax
            probability. ``entropy`` is the distribution's entropy (0.0 if
            non-finite). ``hallucination_risk`` is ``entropy / 10`` capped
            at 1.0.
        """
        probs = torch.softmax(logits[0, -1, :], dim=0)
        top_prob = float(torch.max(probs))

        entropy = float(-torch.sum(probs * torch.log(probs + 1e-10)))
        if not np.isfinite(entropy):
            entropy = 0.0

        hallucination_risk = min(1.0, entropy / 10.0)

        return TraceData(
            type="confidence",
            confidence_score=float(np.clip(top_prob, 0.0, 1.0)),
            hallucination_risk=float(np.clip(hallucination_risk, 0.0, 1.0)),
            entropy=entropy,
            timestamp=datetime.now().timestamp(),
        )

    # ------------------------------------------------------------------
    # Shared sampling helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _sanitize_logits(logits: torch.Tensor) -> torch.Tensor:
        """Replace NaN / +inf / -inf logits with 0 / 10 / -10."""
        if not torch.isfinite(logits).all():
            return torch.nan_to_num(logits, nan=0.0, posinf=10.0, neginf=-10.0)
        return logits

    @staticmethod
    def _filter_top_k_top_p(
        probs: torch.Tensor, top_k: Optional[int], top_p: Optional[float]
    ) -> torch.Tensor:
        """Apply top-k then nucleus (top-p) filtering and renormalise.

        Never mutates ``probs``; returns it unchanged when no filter is active.
        """
        filtered = probs
        if top_k is not None and top_k > 0:
            top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
            filtered = torch.zeros_like(probs)
            filtered[top_k_indices] = top_k_probs
            filtered = filtered / filtered.sum()

        if top_p is not None and top_p < 1.0:
            sorted_probs, sorted_indices = torch.sort(filtered, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=0)
            to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the most likely token.
            to_remove[1:] = to_remove[:-1].clone()
            to_remove[0] = False
            filtered = filtered.clone()  # don't zero entries of the caller's tensor
            filtered[sorted_indices[to_remove]] = 0
            filtered = filtered / filtered.sum()

        return filtered

    @staticmethod
    def _sample_token(probs: torch.Tensor, temperature: float) -> torch.Tensor:
        """Pick the next token id as a 1-element tensor.

        Greedy when ``temperature == 0``; otherwise multinomial sampling,
        falling back to argmax if sampling raises (e.g. invalid probabilities).
        """
        try:
            if temperature == 0:
                return torch.argmax(probs, dim=-1).unsqueeze(0)
            return torch.multinomial(probs, 1)
        except RuntimeError as e:
            logger.warning(f"Sampling failed, using argmax: {e}")
            return torch.argmax(probs, dim=-1).unsqueeze(0)

    @staticmethod
    def _extend_inputs(inputs, next_token: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Append ``next_token`` to input_ids and a matching 1 to attention_mask.

        The new mask column takes the existing mask's dtype/device (the old
        code appended a float32 column to an integer mask).
        """
        mask = inputs["attention_mask"]
        return {
            "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
            "attention_mask": torch.cat(
                [mask, torch.ones((1, 1), dtype=mask.dtype, device=mask.device)], dim=1
            ),
        }

    @staticmethod
    def _repetition_aware_perplexity(token_probs: List[float], token_strings: List[str]) -> float:
        """Perplexity of the sampled tokens, inflated for repetitive output.

        The base perplexity ``exp(-mean(log p))`` (1.0 when empty) is
        multiplied by a repetition factor, and the result is capped at 1000.
        The factor is ``1 + 10 * consecutive_repeats / n`` and, when fewer than
        half the token strings are unique, also ``2 / (diversity + 0.1)``.
        """
        base_perplexity = np.exp(-np.mean(np.log(np.array(token_probs) + 1e-10))) if token_probs else 1.0

        repetition_factor = 1.0
        n = len(token_strings)
        if n > 1:
            consecutive_reps = sum(1 for prev, cur in zip(token_strings, token_strings[1:]) if prev == cur)
            diversity_ratio = len(set(token_strings)) / n

            if consecutive_reps > 0:
                repetition_factor = 1 + (consecutive_reps / n) * 10
            if diversity_ratio < 0.5:
                repetition_factor *= 2.0 / (diversity_ratio + 0.1)  # +0.1 avoids division by zero

        return float(min(base_perplexity * repetition_factor, 1000.0))

    # ------------------------------------------------------------------
    # Ablation hooks
    # ------------------------------------------------------------------

    @staticmethod
    def _make_attention_hook(disabled_heads, n_head: int):
        """Forward hook that attenuates an attention module's output.

        Disabling at least ``n_head`` heads zeroes the output. Otherwise the
        output is scaled by ``1 - len(disabled_heads) / n_head``.

        NOTE(review): only the *count* of disabled heads matters; the
        specific head indices are not isolated.
        """
        num_disabled = len(disabled_heads)

        def hook(module, inputs, output):
            if not num_disabled:
                return output
            if num_disabled >= n_head:
                if isinstance(output, tuple):
                    # Zero the hidden states, keep the remaining outputs.
                    return (torch.zeros_like(output[0]),) + output[1:]
                return torch.zeros_like(output)
            scale = 1.0 - (num_disabled / n_head)
            if isinstance(output, tuple):
                return (output[0] * scale,) + output[1:]
            return output * scale

        return hook

    @staticmethod
    def _zero_output_hook(module, inputs, output):
        """Forward hook that replaces a module's output with zeros (disabled FFN)."""
        return torch.zeros_like(output)

    @staticmethod
    def _skip_layer_hook(module, inputs, output):
        """Forward hook that makes a block pass its input straight through.

        NOTE(review): forward hooks only receive positional arguments, so this
        assumes the block gets hidden_states positionally.
        """
        if isinstance(output, tuple):
            return (inputs[0],) + output[1:]
        return inputs[0]

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    async def generate_with_ablation(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        disabled_components: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Generate text with specific components disabled (ablation study).

        Args:
            prompt: Text to continue.
            max_tokens: Maximum number of tokens to generate.
            temperature: Logit divisor; 0 selects greedy decoding.
            top_k: Optional top-k filter.
            top_p: Optional nucleus filter.
            disabled_components: Optional mapping with the keys below.
                * ``"layers"``: layer indices whose block is bypassed.
                * ``"attention_heads"``: ``{layer: [head, ...]}``; string keys
                  are converted to int.
                * ``"ffn_layers"``: layer indices whose MLP output is zeroed.

        Returns:
            Generated text, per-token strings/ids/probabilities, average
            confidence, repetition-aware perplexity, timing and a summary of
            what was disabled.

        Raises:
            HTTPException: 503 if the model is not loaded; 500 on any failure.

        Hooks are removed in a ``finally`` block, so a failure mid-generation
        cannot leave the shared model ablated for later requests.
        """
        if not self.model or not self.tokenizer:
            raise HTTPException(status_code=503, detail="Model not loaded")

        try:
            start_time = time.time()

            components = disabled_components or {}
            disabled_layers = set(components.get('layers') or [])
            disabled_attention_raw = components.get('attention_heads') or {}
            # JSON object keys arrive as strings; layer indices are ints.
            disabled_attention = {
                int(k) if isinstance(k, str) else k: v
                for k, v in disabled_attention_raw.items()
            }
            disabled_ffn = set(components.get('ffn_layers') or [])

            logger.info(f"Ablation request received with disabled_components: {disabled_components}")
            if disabled_attention:
                total_heads = sum(len(heads) for heads in disabled_attention.values())
                logger.info(f"Total attention heads to disable: {total_heads}")

            config = self.model.config
            blocks = self.model.transformer.h

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            generated_tokens: List[int] = []
            token_probs: List[float] = []
            token_strings: List[str] = []

            handles = []
            try:
                total_attention_disabled = 0
                for layer_idx in range(config.n_layer):
                    block = blocks[layer_idx]
                    if layer_idx in disabled_layers:
                        # Disabling a whole layer supersedes head/FFN settings for it.
                        handles.append(block.register_forward_hook(self._skip_layer_hook))
                        logger.info(f"Disabled entire layer {layer_idx}")
                        continue

                    heads = disabled_attention.get(layer_idx)
                    if heads:
                        handles.append(block.attn.register_forward_hook(
                            self._make_attention_hook(set(heads), config.n_head)
                        ))
                        total_attention_disabled += len(heads)
                        logger.info(f"Disabled {len(heads)} attention heads in layer {layer_idx}")

                    if layer_idx in disabled_ffn:
                        handles.append(block.mlp.register_forward_hook(self._zero_output_hook))
                        logger.info(f"Disabled FFN in layer {layer_idx}")

                if total_attention_disabled > 0:
                    logger.info(
                        f"Total attention heads disabled: {total_attention_disabled} / "
                        f"{config.n_layer * config.n_head}"
                    )

                with torch.no_grad():
                    for _ in range(max_tokens):
                        outputs = self.model(**inputs)
                        next_token_logits = self._sanitize_logits(outputs.logits[0, -1, :])

                        if temperature > 0:
                            next_token_logits = next_token_logits / temperature

                        probs = torch.softmax(next_token_logits, dim=0)
                        # Heavily ablated models can yield invalid distributions:
                        # fall back to uniform rather than failing the request.
                        if torch.isnan(probs).any() or (probs < 0).any() or torch.isinf(probs).any():
                            probs = torch.ones_like(probs) / probs.shape[0]
                        probs = probs / probs.sum()

                        probs = self._filter_top_k_top_p(probs, top_k, top_p)
                        next_token = self._sample_token(probs, temperature)
                        token_id = next_token.item()

                        generated_tokens.append(token_id)
                        token_probs.append(float(probs[token_id]))
                        token_strings.append(self.tokenizer.decode([token_id], skip_special_tokens=True))

                        inputs = self._extend_inputs(inputs, next_token)

                        if token_id == self.tokenizer.eos_token_id:
                            break
            finally:
                for handle in handles:
                    handle.remove()

            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0
            perplexity = self._repetition_aware_perplexity(token_probs, token_strings)

            return {
                "generated_text": prompt + generated_text,
                "tokens": token_strings,
                "token_ids": generated_tokens,
                "probabilities": token_probs,
                "confidence": avg_confidence,
                "perplexity": perplexity,
                "generation_time": time.time() - start_time,
                "num_tokens": len(generated_tokens),
                "disabled_components_count": (
                    len(disabled_layers)
                    + len(disabled_ffn)
                    + sum(len(h) for h in disabled_attention.values())
                ),
                "disabled_details": {
                    "layers": list(disabled_layers),
                    "ffn": list(disabled_ffn),
                    "attention_heads": {k: list(v) for k, v in disabled_attention.items()},
                },
            }

        except Exception as e:
            logger.error(f"Ablated generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    async def _emit_sampled_traces(self, outputs, traces: List[TraceData]) -> None:
        """Extract and broadcast traces for one forward pass.

        Attention traces come from about 10 evenly spaced layers; they are
        appended to ``traces`` and broadcast. With probability 0.3, activation
        traces for up to the first 8 hidden-state entries are broadcast only.
        Per-layer failures are logged and skipped.
        """
        if outputs.attentions:
            num_layers = len(outputs.attentions)
            step = max(1, num_layers // 10)  # ~10 layers sampled
            for layer_idx in range(0, num_layers, step):
                try:
                    trace = self.extract_attention_trace(layer_idx, outputs.attentions)
                    traces.append(trace)
                    await self.broadcast_trace(trace)
                except Exception as e:
                    logger.warning(f"Failed to extract attention trace from layer {layer_idx}: {e}")

        if outputs.hidden_states and np.random.random() < 0.3:
            for layer_idx in range(min(8, len(outputs.hidden_states))):
                try:
                    trace = self.extract_activation_trace(layer_idx, outputs.hidden_states[layer_idx])
                    await self.broadcast_trace(trace)
                except Exception as e:
                    logger.warning(f"Failed to extract activation trace for layer {layer_idx}: {e}")

    async def generate_with_traces(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        sampling_rate: float = 0.005,
    ) -> Dict[str, Any]:
        """Generate text token by token while streaming traces.

        For every step, with probability ``sampling_rate``, attention and
        activation traces are extracted and broadcast. Every non-empty
        generated token is broadcast together with its top-5 alternatives
        (taken from the unfiltered distribution). A final ``confidence``
        trace is computed from the last step's logits, or from the prompt
        when no step ran (``max_tokens <= 0``).

        Returns:
            Generated text, tokens, probabilities, perplexity, confidence,
            JSON-safe attention/confidence traces and a hallucination risk.

        Raises:
            HTTPException: 503 if the model is not loaded; 500 on any failure.
        """
        if not self.model or not self.tokenizer:
            raise HTTPException(status_code=503, detail="Model not loaded")

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            traces: List[TraceData] = []
            generated_tokens: List[int] = []
            token_probs: List[float] = []
            token_strings: List[str] = []
            top_k_display = 5
            logits = None

            with torch.no_grad():
                for _ in range(max_tokens):
                    outputs = self.model(
                        **inputs,
                        output_attentions=True,
                        output_hidden_states=True,
                    )

                    if np.random.random() < sampling_rate:
                        await self._emit_sampled_traces(outputs, traces)

                    logits = outputs.logits
                    next_token_logits = self._sanitize_logits(logits[0, -1, :])
                    if temperature > 0:
                        next_token_logits = next_token_logits / temperature

                    probs = torch.softmax(next_token_logits, dim=0)
                    probs_filtered = self._filter_top_k_top_p(probs, top_k, top_p)

                    # Alternatives come from the unfiltered distribution.
                    top_probs, top_indices = torch.topk(probs, min(top_k_display, probs.shape[0]))

                    next_token = self._sample_token(probs_filtered, temperature)
                    token_id = next_token.item()
                    token_prob = float(probs_filtered[token_id])

                    generated_tokens.append(token_id)
                    token_probs.append(token_prob)

                    token_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
                    token_strings.append(token_text)

                    if token_text:  # Only send non-empty tokens
                        alternatives = [
                            {
                                "token": self.tokenizer.decode([alt_id.item()], skip_special_tokens=True),
                                "probability": float(alt_prob),
                                "token_id": int(alt_id),
                            }
                            for alt_prob, alt_id in zip(top_probs, top_indices)
                        ]
                        await self.broadcast_trace(TraceData(
                            type="token",
                            layer=None,
                            weights=None,
                            confidence_score=token_prob,
                            timestamp=datetime.now().timestamp(),
                        ))
                        await self.broadcast_token_with_alternatives(token_text, alternatives)

                    inputs = self._extend_inputs(inputs, next_token)

                    if token_id == self.tokenizer.eos_token_id:
                        break

                if logits is None:
                    # max_tokens <= 0: no step ran, so score the prompt itself
                    # instead of failing on an unbound ``logits``.
                    logits = self.model(**inputs).logits

            confidence_trace = self.calculate_confidence(logits)
            traces.append(confidence_trace)
            await self.broadcast_trace(confidence_trace)

            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0
            perplexity = self._repetition_aware_perplexity(token_probs, token_strings)

            risk = confidence_trace.hallucination_risk
            result = {
                "generated_text": prompt + generated_text,
                "tokens": token_strings,
                "probabilities": token_probs,
                "perplexity": perplexity,
                "confidence": avg_confidence,
                "traces": [],
                "num_tokens": len(generated_tokens),
                "hallucination_risk": float(risk) if np.isfinite(risk) else 0.1,
            }

            # Replace non-finite floats so the payload is valid JSON.
            for trace in traces:
                result["traces"].append({
                    key: 0.0 if isinstance(value, float) and not np.isfinite(value) else value
                    for key, value in trace.dict().items()
                })

            return result

        except Exception as e:
            logger.error(f"Generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    # ------------------------------------------------------------------
    # WebSocket broadcasting
    # ------------------------------------------------------------------

    async def _broadcast_json(self, message: Dict[str, Any]) -> None:
        """Send ``message`` to every client, dropping clients whose send fails.

        Iterates a snapshot because the client list can change while a send
        is awaited. Only ``Exception`` is caught, so task cancellation still
        propagates.
        """
        disconnected = []
        for client in list(self.websocket_clients):
            try:
                await client.send_json(message)
            except Exception as exc:
                logger.debug("Dropping WebSocket client after failed send: %s", exc)
                disconnected.append(client)

        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

    async def broadcast_trace(self, trace: TraceData):
        """Send a trace to all connected WebSocket clients."""
        await self._broadcast_json(trace.dict())

    async def broadcast_token(self, token: str):
        """Send a generated token to all connected WebSocket clients."""
        await self._broadcast_json({
            "type": "generated_token",
            "token": token,
            "timestamp": datetime.now().timestamp(),
        })

    async def broadcast_token_with_alternatives(self, token: str, alternatives: list):
        """Send a generated token and its top-k alternatives to all WebSocket clients."""
        await self._broadcast_json({
            "type": "generated_token",
            "token": token,
            "alternatives": alternatives,
            "timestamp": datetime.now().timestamp(),
        })


# Initialize model manager
manager = ModelManager()


def _require_model() -> None:
    """Raise HTTP 503 while the model or tokenizer is not loaded."""
    if manager.model is None or manager.tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup."""
    await manager.initialize()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket connection for streaming traces; replies "pong" to "ping".

    The client is always unregistered on exit. The membership check matters
    because a failed broadcast may already have removed it.
    """
    await websocket.accept()
    manager.websocket_clients.append(websocket)
    logger.info(f"WebSocket client connected. Total clients: {len(manager.websocket_clients)}")

    try:
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        pass
    finally:
        if websocket in manager.websocket_clients:
            manager.websocket_clients.remove(websocket)
        logger.info(f"WebSocket client disconnected. Total clients: {len(manager.websocket_clients)}")


# HTTP endpoints
@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "service": "Visualisable.ai Model Service",
        "status": "running",
        "model_loaded": manager.model is not None,
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy" if manager.model else "initializing",
        "model_loaded": manager.model is not None,
        "device": str(manager.device) if manager.device else "not set",
        "websocket_clients": len(manager.websocket_clients),
        "timestamp": datetime.now().isoformat(),
    }


@app.get("/model/info")
async def model_info(authenticated: bool = Depends(verify_api_key)):
    """Get detailed information about the loaded model."""
    _require_model()

    config = manager.model.config
    params = list(manager.model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)

    return {
        "name": manager.model_name,
        "type": config.model_type,
        "totalParams": total_params,
        "trainableParams": trainable_params,
        "layers": config.n_layer,
        "heads": config.n_head,
        "hiddenSize": config.n_embd,
        "vocabSize": config.vocab_size,
        "maxPositions": config.n_positions,
        "architecture": manager.model.__class__.__name__,
        "device": str(manager.device),
        "dtype": str(params[0].dtype),
        "accessible": [
            f"Token probabilities (all {config.vocab_size})",
            f"Attention weights ({config.n_layer} layers × {config.n_head} heads = {config.n_layer * config.n_head} patterns)",
            f"Hidden states (all {config.n_layer} layers)",
            "Logits before softmax",
            "Token embeddings",
            "Position embeddings (RoPE)",
            "Feed-forward activations",
            "Layer normalizations",
            "Gradient information (when available)",
            "Activation functions (GELU)",
        ],
        "config": {
            "activation_function": config.activation_function,
            "layer_norm_epsilon": config.layer_norm_epsilon,
            "tie_word_embeddings": config.tie_word_embeddings,
            "rotary_dim": getattr(config, "rotary_dim", None),
            "use_cache": config.use_cache,
        },
    }


@app.post("/generate")
async def generate(request: GenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text; trace sampling is disabled when ``extract_traces`` is false."""
    return await manager.generate_with_traces(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        sampling_rate=request.sampling_rate if request.extract_traces else 0,
    )


@app.post("/generate/ablated")
async def generate_ablated(request: AblatedGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with specific components disabled (ablation study)."""
    return await manager.generate_with_ablation(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        disabled_components=request.disabled_components,
    )


@app.post("/generate/icl")
async def generate_icl(request: ICLGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with in-context learning analysis (camelCase response)."""
    from .icl_service import ICLAnalyzer, ICLExample as ICLExampleData

    _require_model()

    analyzer = ICLAnalyzer(manager.model, manager.tokenizer)
    examples = [ICLExampleData(input=ex.input, output=ex.output) for ex in request.examples]

    result = analyzer.analyze_generation(
        examples=examples,
        test_prompt=request.prompt,
        max_length=request.max_tokens,
        temperature=request.temperature,
    )

    response_data = {
        "shotCount": result.shot_count,
        "generatedCode": result.generated_code,
        "tokens": result.tokens,
        "confidenceScores": result.confidence_scores,
        "attentionFromExamples": result.attention_from_examples,
        "perplexity": result.perplexity,
        "avgConfidence": result.avg_confidence,
        "exampleInfluences": result.example_influences,
        "hiddenStateDrift": result.hidden_state_drift,
    }

    emergence = result.icl_emergence
    if emergence:
        response_data["iclEmergence"] = {
            "emergenceDetected": emergence.emergence_detected,
            "emergenceToken": emergence.emergence_token,
            "emergenceLayer": emergence.emergence_layer,
            "confidence": emergence.confidence,
            "inductionHeads": [
                {
                    "layer": h.layer,
                    "head": h.head,
                    "strength": h.strength,
                    "patternType": h.pattern_type,
                    "emergencePoint": h.emergence_point,
                }
                for h in emergence.induction_heads
            ],
            "attentionEntropyDrop": emergence.attention_entropy_drop,
            "patternConsistency": emergence.pattern_consistency,
        }

    return response_data


@app.post("/analyze/pipeline")
async def analyze_pipeline(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Analyze the complete transformer pipeline step by step.

    When ``max_tokens == 1`` the legacy single-pipeline fields (``steps``,
    ``total_steps``) are returned alongside the multi-token fields.
    """
    from .pipeline_analyzer import TransformerPipelineAnalyzer
    from dataclasses import asdict

    # Checked before the try so "not loaded" surfaces as 503, not 500.
    _require_model()

    try:
        analyzer = TransformerPipelineAnalyzer(manager.model, manager.tokenizer)

        text = request.get("text", "def fibonacci(n):\n if n <= 1:\n return n")
        max_tokens = request.get("max_tokens", 1)
        temperature = request.get("temperature", 0.7)
        top_k = request.get("top_k", 50)
        top_p = request.get("top_p", 0.95)

        result = analyzer.analyze_pipeline(
            text,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        pipelines_dict = [[asdict(step) for step in pipeline] for pipeline in result['pipelines']]
        steps_per_token = len(pipelines_dict[0]) if pipelines_dict else 0

        if max_tokens == 1 and pipelines_dict:
            # Backward-compatible single-token format, plus multi-token fields.
            response_data = {
                "steps": pipelines_dict[0],
                "total_steps": steps_per_token,
                "model_name": manager.model_name,
                "input_text": text,
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text'],
            }
        else:
            response_data = {
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text'],
                "num_tokens": result['num_tokens'],
                "total_steps": steps_per_token,
                "model_name": manager.model_name,
                "input_text": text,
            }

        logger.info(f"Pipeline analysis complete: {result['num_tokens']} tokens, {steps_per_token} steps per token")

        return response_data

    except Exception as e:
        logger.error(f"Pipeline analysis error: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/analyze/attention")
async def analyze_attention(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Analyze attention with Q, K, V extraction (every 4th layer returned)."""
    from .qkv_extractor import QKVExtractor

    _require_model()

    extractor = QKVExtractor(manager.model, manager.tokenizer)

    text = request.get("text", "def fibonacci(n):\n if n <= 1:\n return n")
    analysis = extractor.extract_attention_data(text)

    response_data = {
        "tokens": analysis.tokens,
        "tokenIds": analysis.token_ids,
        "layerCount": analysis.layer_count,
        "headCount": analysis.head_count,
        "sequenceLength": analysis.sequence_length,
        "modelDimension": analysis.model_dimension,
        # Only every 4th layer, to avoid overwhelming the frontend.
        "qkvData": [
            {
                "layer": qkv.layer,
                "head": qkv.head,
                "query": qkv.query.tolist(),
                "key": qkv.key.tolist(),
                "value": qkv.value.tolist(),
                "attentionScoresRaw": qkv.attention_scores_raw.tolist(),
                "attentionWeights": qkv.attention_weights.tolist(),
                "headDim": qkv.head_dim,
            }
            for qkv in analysis.qkv_data
            if qkv.layer % 4 == 0
        ],
        "tokenEmbeddings": [
            {
                "token": emb.token,
                "tokenId": emb.token_id,
                "position": emb.position,
                "layer": emb.layer,
                "embedding2D": emb.embedding_2d,
                "embedding3D": emb.embedding_3d,
            }
            for emb in analysis.token_embeddings
            if emb.layer % 4 == 0
        ],
        "attentionFlow": [],
    }

    # Attention flow for the first token as an example.
    if len(analysis.tokens) > 0:
        response_data["attentionFlow"] = extractor.get_attention_flow(analysis, source_token=0)

    if analysis.positional_encodings is not None:
        response_data["positionalEncodings"] = analysis.positional_encodings.tolist()

    return response_data


@app.get("/demos")
async def list_demos(authenticated: bool = Depends(verify_api_key)):
    """List available demo prompts (copies, so callers cannot mutate DEMOS)."""
    return {"demos": [dict(demo) for demo in DEMOS]}


@app.post("/demos/run")
async def run_demo(request: DemoRequest, authenticated: bool = Depends(verify_api_key)):
    """Run a specific demo by id; 404 if the id is unknown."""
    prompts = {demo["id"]: demo["prompt"] for demo in DEMOS}

    if request.demo_id not in prompts:
        raise HTTPException(status_code=404, detail="Demo not found")

    return await manager.generate_with_traces(
        prompt=prompts[request.demo_id],
        max_tokens=100,
        temperature=0.7,
        # Much denser than the /generate default (0.005) so demos visibly stream traces.
        sampling_rate=0.3,
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)