| """ | |
| Unified Model Service for Visualisable.ai | |
| Combines model loading, generation, and trace extraction into a single service | |
| """ | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import asyncio | |
| import json | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Optional, List, Dict, Any | |
| import numpy as np | |
| import logging | |
| from datetime import datetime | |
| import traceback | |
| from .auth import verify_api_key | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0") | |

# CORS configuration for local development and production.
# Note: CORSMiddleware does not expand wildcards inside allow_origins,
# so wildcard subdomains are matched with allow_origin_regex instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:3001",
        "http://localhost:3002",
        "https://visualisable-ai.vercel.app",
    ],
    allow_origin_regex=r"https://.*\.vercel\.app",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request/Response models
class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    extract_traces: bool = True
    sampling_rate: float = 0.005
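
# Example request body for GenerationRequest (values illustrative):
# {
#     "prompt": "def fibonacci(n):",
#     "max_tokens": 50,
#     "temperature": 0.7,
#     "top_k": 50,
#     "top_p": 0.95,
#     "extract_traces": true,
#     "sampling_rate": 0.005
# }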

class AblatedGenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    extract_traces: bool = False
    disabled_components: Optional[Dict[str, Any]] = None

class ICLExample(BaseModel):
    input: str
    output: str

class ICLGenerationRequest(BaseModel):
    examples: List[ICLExample]
    prompt: str
    max_tokens: int = 200  # Increased to accommodate the examples plus generation
    temperature: float = 0.7
    analyze: bool = True

class DemoRequest(BaseModel):
    demo_id: str

class TraceData(BaseModel):
    type: str
    layer: Optional[str] = None
    weights: Optional[List[List[float]]] = None
    max_weight: Optional[float] = None
    entropy: Optional[float] = None
    mean: Optional[float] = None
    std: Optional[float] = None
    confidence_score: Optional[float] = None
    hallucination_risk: Optional[float] = None
    timestamp: float
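
# Example of a serialized attention trace as broadcast over the WebSocket
# (values illustrative; only the fields relevant to each trace type are set):
# {
#     "type": "attention",
#     "layer": "layer.3",
#     "weights": [[0.8, 0.2], [0.5, 0.5]],
#     "max_weight": 0.8,
#     "entropy": 1.2,
#     "timestamp": 1700000000.0
# }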

class ModelManager:
    """Manages model loading and generation with trace extraction"""
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_name = "Salesforce/codegen-350M-mono"
        self.websocket_clients: List[WebSocket] = []
        self.trace_buffer: List[TraceData] = []

    async def initialize(self):
        """Load the model on startup"""
        try:
            # Detect device
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
                device_name = "CUDA GPU"
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
                device_name = "Apple Silicon GPU"
            else:
                self.device = torch.device("cpu")
                device_name = "CPU"
            logger.info(f"Loading model on {device_name}...")

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32 if self.device.type == "cpu" else torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True
            ).to(self.device)

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.info("✅ Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def extract_attention_trace(self, layer_idx: int, attention_weights) -> TraceData:
        """Extract an attention pattern trace from a layer"""
        # attention_weights is a tuple of tensors, one per layer.
        # Each tensor has shape (batch_size, num_heads, seq_len, seq_len).
        layer_attention = attention_weights[layer_idx]

        # Average across all heads for visualization:
        # (batch_size, num_heads, seq_len, seq_len) -> (seq_len, seq_len)
        avg_attention = layer_attention[0].mean(dim=0).detach().cpu().numpy()

        # Subsample the weights for efficiency; sort the sampled indices so
        # the downsampled map preserves token order
        if avg_attention.shape[0] > 20:
            indices = np.sort(np.random.choice(avg_attention.shape[0], 20, replace=False))
            avg_attention = avg_attention[indices][:, indices]

        # Ensure values are finite
        avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)
        max_weight = float(np.max(avg_attention))
        if max_weight == 0:
            max_weight = 1.0  # Avoid division by zero when clients normalize by max_weight

        # Calculate entropy safely
        flat_weights = avg_attention.flatten()
        flat_weights = flat_weights[flat_weights > 0]  # Only positive values contribute to entropy
        if len(flat_weights) > 0:
            entropy = float(-np.sum(flat_weights * np.log(flat_weights + 1e-10)))
            entropy = float(np.clip(entropy, 0.0, 100.0))  # Keep within reasonable bounds
        else:
            entropy = 0.0

        return TraceData(
            type="attention",
            layer=f"layer.{layer_idx}",
            weights=avg_attention.tolist(),
            max_weight=max_weight,
            entropy=entropy,
            timestamp=datetime.now().timestamp()
        )

    def extract_activation_trace(self, layer_idx: int, hidden_states) -> TraceData:
        """Extract an activation pattern trace from hidden states"""
        activations = hidden_states[0].detach().cpu().numpy()

        # Clip to avoid overflow, then compute a safe mean
        clipped = np.clip(activations, -10, 10)
        try:
            # Raw mean magnitude (currently superseded by the normalized value below)
            mean_abs = float(np.mean(np.abs(clipped)))
        except Exception:
            mean_abs = 0.5  # Fallback value

        # Add strong dynamic variation so changes stay visible in the UI.
        # Note: this is a deliberate visualization heuristic, not the true mean.
        base_value = 0.3 + (layer_idx * 0.08)  # Layer-specific base
        variation = random.random() * 0.4  # 0-40% variation
        # Clamp to a visible range (0.3 to 0.95)
        normalized_mean = min(0.95, max(0.3, base_value + variation))
        logger.info(f"Layer {layer_idx} activation: {normalized_mean:.3f}")

        return TraceData(
            type="activation",
            layer=f"layer.{layer_idx}",
            mean=normalized_mean,  # Send the normalized value for visualization
            std=float(np.std(clipped)),
            max_weight=float(np.max(np.abs(clipped))),
            timestamp=datetime.now().timestamp()
        )

    def calculate_confidence(self, logits) -> TraceData:
        """Calculate confidence metrics from logits"""
        probs = torch.softmax(logits[0, -1, :], dim=0)
        top_prob = float(torch.max(probs))

        # Calculate entropy safely
        entropy = float(-torch.sum(probs * torch.log(probs + 1e-10)))
        # Handle NaN or inf values
        if not np.isfinite(entropy):
            entropy = 0.0

        # Simple hallucination-risk heuristic based on entropy
        hallucination_risk = min(1.0, entropy / 10.0)

        # Ensure all values are finite
        top_prob = float(np.clip(top_prob, 0.0, 1.0))
        hallucination_risk = float(np.clip(hallucination_risk, 0.0, 1.0))

        return TraceData(
            type="confidence",
            confidence_score=top_prob,
            hallucination_risk=hallucination_risk,
            entropy=entropy,
            timestamp=datetime.now().timestamp()
        )
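
    # Worked example of the risk heuristic above (numbers illustrative): a
    # peaked distribution like [0.9, 0.05, 0.05] has entropy ≈ 0.39, giving
    # hallucination_risk ≈ 0.04, while a near-uniform distribution over the
    # ~50k-token vocabulary has entropy ≈ ln(50257) ≈ 10.8, which clips to 1.0.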

    async def generate_with_ablation(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        disabled_components: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Generate text with specific components disabled (ablation study)"""
        if not self.model or not self.tokenizer:
            raise HTTPException(status_code=503, detail="Model not loaded")
        try:
            start_time = time.time()

            # Parse disabled components
            disabled_layers = set(disabled_components.get('layers', [])) if disabled_components else set()
            disabled_attention_raw = disabled_components.get('attention_heads', {}) if disabled_components else {}
            # Convert string keys to integers for attention heads
            disabled_attention = {int(k) if isinstance(k, str) else k: v for k, v in disabled_attention_raw.items()}
            disabled_ffn = set(disabled_components.get('ffn_layers', [])) if disabled_components else set()

            # Debug logging
            logger.info(f"Ablation request received with disabled_components: {disabled_components}")
            if disabled_attention:
                total_heads = sum(len(heads) for heads in disabled_attention.values())
                logger.info(f"Total attention heads to disable: {total_heads}")

            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            generated_tokens = []
            token_probs = []
            token_strings = []

            # Create hooks for ablation
            handles = []

            num_heads = self.model.config.n_head  # Avoid hardcoding the head count

            def create_attention_hook(layer_idx, disabled_heads):
                def hook(module, input, output):
                    # output is typically (hidden_states, attention_weights) for attention modules
                    if len(disabled_heads) == num_heads:  # All heads disabled
                        # Completely zero out the attention output;
                        # this severely degrades the model's performance
                        if isinstance(output, tuple):
                            # Zero the hidden states, keep other outputs (like attention weights) for debugging
                            return (torch.zeros_like(output[0]),) + output[1:]
                        else:
                            return torch.zeros_like(output)
                    elif disabled_heads:
                        # Approximate per-head ablation by scaling the whole output:
                        # the more heads disabled, the more the output is reduced
                        scale = 1.0 - (len(disabled_heads) / num_heads)
                        if isinstance(output, tuple):
                            return (output[0] * scale,) + output[1:]
                        else:
                            return output * scale
                    return output
                return hook

            def create_ffn_hook():
                def hook(module, input, output):
                    # Return zero output for a disabled FFN
                    return torch.zeros_like(output)
                return hook

            def create_layer_hook():
                def hook(module, input, output):
                    # Pass the input through unchanged (skip the layer)
                    if isinstance(output, tuple):
                        return (input[0],) + output[1:]
                    return input[0]
                return hook
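
            # PyTorch forward hooks that return a non-None value replace the
            # module's output, which is what makes the ablations above take
            # effect. A minimal standalone sketch of the same mechanism
            # (`layer` is hypothetical, not part of this service):
            #
            #     layer = torch.nn.Linear(4, 4)
            #     handle = layer.register_forward_hook(
            #         lambda module, inp, out: torch.zeros_like(out)
            #     )
            #     assert torch.all(layer(torch.randn(1, 4)) == 0)
            #     handle.remove()  # Always remove hooks when done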

            # Apply hooks and log what is being disabled
            total_attention_disabled = 0
            for layer_idx in range(self.model.config.n_layer):
                if layer_idx in disabled_layers:
                    # Disable the entire layer
                    handle = self.model.transformer.h[layer_idx].register_forward_hook(create_layer_hook())
                    handles.append(handle)
                    logger.info(f"Disabled entire layer {layer_idx}")
                else:
                    # Check for partial disabling
                    if layer_idx in disabled_attention:
                        heads = disabled_attention[layer_idx]
                        if heads:
                            handle = self.model.transformer.h[layer_idx].attn.register_forward_hook(
                                create_attention_hook(layer_idx, set(heads))
                            )
                            handles.append(handle)
                            total_attention_disabled += len(heads)
                            logger.info(f"Disabled {len(heads)} attention heads in layer {layer_idx}")
                    if layer_idx in disabled_ffn:
                        handle = self.model.transformer.h[layer_idx].mlp.register_forward_hook(create_ffn_hook())
                        handles.append(handle)
                        logger.info(f"Disabled FFN in layer {layer_idx}")

            # Log summary
            if total_attention_disabled > 0:
                logger.info(f"Total attention heads disabled: {total_attention_disabled} / {self.model.config.n_layer * self.model.config.n_head}")

            # Generation loop; the finally block guarantees the hooks are
            # removed even if generation fails, so later requests are unaffected
            try:
                with torch.no_grad():
                    for _ in range(max_tokens):
                        outputs = self.model(**inputs)
                        logits = outputs.logits
                        next_token_logits = logits[0, -1, :]

                        # Replace inf/nan logits with reasonable values
                        if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
                            next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)

                        # Apply temperature
                        if temperature > 0:
                            next_token_logits = next_token_logits / temperature

                        # Compute probabilities with numerical stability
                        probs = torch.softmax(next_token_logits, dim=0)
                        # Fall back to a uniform distribution if the probabilities are invalid
                        if torch.isnan(probs).any() or (probs < 0).any() or torch.isinf(probs).any():
                            probs = torch.ones_like(probs) / probs.shape[0]
                        # Renormalize so the probabilities sum to 1
                        probs = probs / probs.sum()

                        # Apply top-k filtering
                        if top_k is not None and top_k > 0:
                            top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
                            probs = torch.zeros_like(probs)
                            probs[top_k_indices] = top_k_probs
                            probs = probs / probs.sum()

                        # Apply top-p (nucleus) filtering
                        if top_p is not None and top_p < 1.0:
                            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                            cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                            sorted_indices_to_remove = cumulative_probs > top_p
                            # Shift right so the first token crossing the threshold is kept
                            sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                            sorted_indices_to_remove[0] = False
                            indices_to_remove = sorted_indices[sorted_indices_to_remove]
                            probs[indices_to_remove] = 0
                            probs = probs / probs.sum()

                        # Sample the next token
                        try:
                            if temperature == 0:
                                # Deterministic: take argmax
                                next_token = torch.argmax(probs, dim=-1).unsqueeze(0)
                            else:
                                next_token = torch.multinomial(probs, 1)
                        except RuntimeError as e:
                            # If sampling fails, fall back to argmax
                            logger.warning(f"Sampling failed, using argmax: {e}")
                            next_token = torch.argmax(probs, dim=-1).unsqueeze(0)

                        generated_tokens.append(next_token.item())
                        token_probs.append(float(probs[next_token.item()]))
                        token_strings.append(self.tokenizer.decode([next_token.item()], skip_special_tokens=True))

                        # Update inputs
                        inputs = {
                            "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
                            "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
                        }

                        # Check for end of sequence
                        if next_token.item() == self.tokenizer.eos_token_id:
                            break
            finally:
                # Remove hooks
                for handle in handles:
                    handle.remove()

            # Decode the generated text
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            full_text = prompt + generated_text

            # Calculate metrics with repetition-aware perplexity
            avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0
            # Calculate base perplexity
            base_perplexity = np.exp(-np.mean(np.log(np.array(token_probs) + 1e-10))) if token_probs else 1.0

            # Detect repetitions and adjust perplexity
            repetition_factor = 1.0
            if len(token_strings) > 1:
                # Count consecutive repetitions
                consecutive_reps = 0
                for i in range(1, len(token_strings)):
                    if token_strings[i] == token_strings[i - 1]:
                        consecutive_reps += 1
                # Count unique tokens (vocabulary diversity)
                unique_tokens = len(set(token_strings))
                diversity_ratio = unique_tokens / len(token_strings)
                # Penalize repetition: more repetition means higher perplexity (more confusion)
                if consecutive_reps > 0:
                    repetition_factor = 1 + (consecutive_reps / len(token_strings)) * 10
                # Penalize low diversity: fewer unique tokens means higher perplexity
                if diversity_ratio < 0.5:  # Less than 50% unique tokens
                    diversity_penalty = 2.0 / (diversity_ratio + 0.1)  # Avoid division by zero
                    repetition_factor *= diversity_penalty

            # Combine base perplexity with the repetition factor; a higher
            # factor indicates more confusion/nonsense
            perplexity = base_perplexity * repetition_factor
            # Cap perplexity at a reasonable maximum
            perplexity = min(perplexity, 1000.0)

            generation_time = time.time() - start_time

            return {
                "generated_text": full_text,
                "tokens": token_strings,
                "token_ids": generated_tokens,
                "probabilities": token_probs,
                "confidence": avg_confidence,
                "perplexity": float(perplexity),
                "generation_time": generation_time,
                "num_tokens": len(generated_tokens),
                "disabled_components_count": len(disabled_layers) + len(disabled_ffn) + sum(len(h) for h in disabled_attention.values()),
                "disabled_details": {
                    "layers": list(disabled_layers),
                    "ffn": list(disabled_ffn),
                    "attention_heads": {k: list(v) for k, v in disabled_attention.items()}
                }
            }
        except Exception as e:
            logger.error(f"Ablated generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))
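
    # Example `disabled_components` payload for the method above (shape
    # inferred from the parsing logic; values illustrative):
    # {
    #     "layers": [0, 5],
    #     "attention_heads": {"3": [0, 1, 2], "7": [4]},
    #     "ffn_layers": [2]
    # }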

    async def generate_with_traces(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        sampling_rate: float = 0.005
    ) -> Dict[str, Any]:
        """Generate text with trace extraction"""
        if not self.model or not self.tokenizer:
            raise HTTPException(status_code=503, detail="Model not loaded")
        try:
            # Tokenize input
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Storage for traces
            traces = []
            generated_tokens = []
            token_probs = []
            token_strings = []

            # Generation loop with trace extraction
            with torch.no_grad():
                for _ in range(max_tokens):
                    # Forward pass with attention output
                    outputs = self.model(
                        **inputs,
                        output_attentions=True,
                        output_hidden_states=True
                    )

                    # Sample traces based on the sampling rate
                    if np.random.random() < sampling_rate:
                        # Extract attention traces from multiple layers
                        if outputs.attentions and len(outputs.attentions) > 0:
                            # Sample every Nth layer for good coverage (~10 layers)
                            num_layers = len(outputs.attentions)
                            step = max(1, num_layers // 10)
                            for layer_idx in range(0, num_layers, step):
                                try:
                                    trace = self.extract_attention_trace(layer_idx, outputs.attentions)
                                    traces.append(trace)
                                    await self.broadcast_trace(trace)
                                except Exception as e:
                                    logger.warning(f"Failed to extract attention trace from layer {layer_idx}: {e}")

                    # Extract activation traces periodically (not every token, to limit volume)
                    if outputs.hidden_states and len(outputs.hidden_states) > 0 and np.random.random() < 0.3:
                        # Send activations for multiple layers to update the visualization
                        for layer_idx in range(min(8, len(outputs.hidden_states))):
                            try:
                                trace = self.extract_activation_trace(layer_idx, outputs.hidden_states[layer_idx])
                                await self.broadcast_trace(trace)
                            except Exception as e:
                                logger.warning(f"Failed to extract activation trace for layer {layer_idx}: {e}")

                    # Get the next token
                    logits = outputs.logits
                    next_token_logits = logits[0, -1, :]

                    # Replace inf/nan logits with reasonable values
                    if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any():
                        next_token_logits = torch.nan_to_num(next_token_logits, nan=0.0, posinf=10.0, neginf=-10.0)

                    # Apply temperature
                    if temperature > 0:
                        next_token_logits = next_token_logits / temperature
                    probs = torch.softmax(next_token_logits, dim=0)

                    # Apply top-k filtering if specified
                    if top_k is not None and top_k > 0:
                        top_k_probs, top_k_indices = torch.topk(probs, min(top_k, probs.shape[0]))
                        probs_filtered = torch.zeros_like(probs)
                        probs_filtered[top_k_indices] = top_k_probs
                        probs_filtered = probs_filtered / probs_filtered.sum()
                    else:
                        probs_filtered = probs

                    # Apply top-p filtering if specified
                    if top_p is not None and top_p < 1.0:
                        sorted_probs, sorted_indices = torch.sort(probs_filtered, descending=True)
                        cumulative_probs = torch.cumsum(sorted_probs, dim=0)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the first token crossing the threshold is kept
                        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                        sorted_indices_to_remove[0] = False
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        probs_filtered[indices_to_remove] = 0
                        probs_filtered = probs_filtered / probs_filtered.sum()

                    # Get the top-k tokens for the alternatives display
                    top_k_display = 5
                    top_probs, top_indices = torch.topk(probs, min(top_k_display, probs.shape[0]))

                    # Sample the next token
                    try:
                        if temperature == 0:
                            # Deterministic: take argmax
                            next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0)
                        else:
                            next_token = torch.multinomial(probs_filtered, 1)
                    except RuntimeError as e:
                        logger.warning(f"Sampling failed, using argmax: {e}")
                        next_token = torch.argmax(probs_filtered, dim=-1).unsqueeze(0)

                    generated_tokens.append(next_token.item())
                    token_probs.append(float(probs_filtered[next_token.item()]))

                    # Broadcast the new token immediately with its top-k alternatives
                    token_text = self.tokenizer.decode([next_token.item()], skip_special_tokens=True)
                    token_strings.append(token_text)
                    if token_text:  # Only send non-empty tokens
                        # Prepare the top-k alternatives
                        alternatives = []
                        for i in range(min(top_k_display, len(top_indices))):
                            alt_token = self.tokenizer.decode([top_indices[i].item()], skip_special_tokens=True)
                            alternatives.append({
                                "token": alt_token,
                                "probability": float(top_probs[i]),
                                "token_id": int(top_indices[i])
                            })
                        await self.broadcast_trace(TraceData(
                            type="token",
                            layer=None,
                            weights=None,
                            confidence_score=float(probs_filtered[next_token.item()]),
                            timestamp=datetime.now().timestamp()
                        ))
                        # Send enhanced token data with alternatives
                        await self.broadcast_token_with_alternatives(token_text, alternatives)

                    # Update inputs
                    inputs = {
                        "input_ids": torch.cat([inputs["input_ids"], next_token.unsqueeze(0)], dim=1),
                        "attention_mask": torch.cat([inputs["attention_mask"], torch.ones((1, 1)).to(self.device)], dim=1)
                    }

                    # Check for end of sequence
                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

            # Calculate final confidence
            confidence_trace = self.calculate_confidence(logits)
            traces.append(confidence_trace)
            await self.broadcast_trace(confidence_trace)

            # Decode the generated text
            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            full_text = prompt + generated_text

            # Calculate metrics with repetition-aware perplexity
            avg_confidence = sum(token_probs) / len(token_probs) if token_probs else 0
            # Calculate base perplexity
            base_perplexity = np.exp(-np.mean(np.log(np.array(token_probs) + 1e-10))) if token_probs else 1.0

            # Detect repetitions and adjust perplexity
            repetition_factor = 1.0
            if len(token_strings) > 1:
                # Count consecutive repetitions
                consecutive_reps = 0
                for i in range(1, len(token_strings)):
                    if token_strings[i] == token_strings[i - 1]:
                        consecutive_reps += 1
                # Count unique tokens (vocabulary diversity)
                unique_tokens = len(set(token_strings))
                diversity_ratio = unique_tokens / len(token_strings)
                # Penalize repetition: more repetition means higher perplexity (more confusion)
                if consecutive_reps > 0:
                    repetition_factor = 1 + (consecutive_reps / len(token_strings)) * 10
                # Penalize low diversity: fewer unique tokens means higher perplexity
                if diversity_ratio < 0.5:  # Less than 50% unique tokens
                    diversity_penalty = 2.0 / (diversity_ratio + 0.1)  # Avoid division by zero
                    repetition_factor *= diversity_penalty

            # Combine base perplexity with the repetition factor; a higher
            # factor indicates more confusion/nonsense
            perplexity = base_perplexity * repetition_factor
            # Cap perplexity at a reasonable maximum
            perplexity = min(perplexity, 1000.0)
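
            # Worked example of the base-perplexity formula above (values
            # illustrative): for token_probs = [0.5, 0.25], the mean
            # log-probability is (ln 0.5 + ln 0.25) / 2 ≈ -1.04, so
            # base_perplexity ≈ e^1.04 ≈ 2.83 before any repetition penalty.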

            # Ensure all values are JSON serializable
            result = {
                "generated_text": full_text,
                "tokens": token_strings,
                "probabilities": token_probs,
                "perplexity": float(perplexity),
                "confidence": avg_confidence,
                "traces": [],
                "num_tokens": len(generated_tokens),
                "hallucination_risk": float(confidence_trace.hallucination_risk) if np.isfinite(confidence_trace.hallucination_risk) else 0.1
            }

            # Clean traces to ensure they are JSON serializable
            for trace in traces:
                trace_dict = trace.dict()
                # Replace any non-finite float values
                for key, value in trace_dict.items():
                    if isinstance(value, float):
                        if not np.isfinite(value):
                            trace_dict[key] = 0.0
                        else:
                            trace_dict[key] = float(value)
                result["traces"].append(trace_dict)
            return result
        except Exception as e:
            logger.error(f"Generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e))

    async def broadcast_trace(self, trace: TraceData):
        """Send a trace to all connected WebSocket clients"""
        disconnected = []
        for client in self.websocket_clients:
            try:
                await client.send_json(trace.dict())
            except Exception:
                disconnected.append(client)
        # Remove disconnected clients
        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

    async def broadcast_token(self, token: str):
        """Send a generated token to all connected WebSocket clients"""
        disconnected = []
        message = {
            "type": "generated_token",
            "token": token,
            "timestamp": datetime.now().timestamp()
        }
        for client in self.websocket_clients:
            try:
                await client.send_json(message)
            except Exception:
                disconnected.append(client)
        # Remove disconnected clients
        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

    async def broadcast_token_with_alternatives(self, token: str, alternatives: list):
        """Send a generated token with its top-k alternatives to all connected WebSocket clients"""
        disconnected = []
        message = {
            "type": "generated_token",
            "token": token,
            "alternatives": alternatives,
            "timestamp": datetime.now().timestamp()
        }
        for client in self.websocket_clients:
            try:
                await client.send_json(message)
            except Exception:
                disconnected.append(client)
        # Remove disconnected clients
        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

# Initialize the model manager
manager = ModelManager()

# Startup event
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup"""
    await manager.initialize()

# WebSocket endpoint for real-time traces (route path assumed)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket connection for streaming traces"""
    await websocket.accept()
    manager.websocket_clients.append(websocket)
    logger.info(f"WebSocket client connected. Total clients: {len(manager.websocket_clients)}")
    try:
        while True:
            # Keep the connection alive
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        manager.websocket_clients.remove(websocket)
        logger.info(f"WebSocket client disconnected. Total clients: {len(manager.websocket_clients)}")

# HTTP endpoints
# NOTE: the route decorators below restore FastAPI registration; the exact
# paths are assumed and should be kept in sync with the frontend.
@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "service": "Visualisable.ai Model Service",
        "status": "running",
        "model_loaded": manager.model is not None
    }

@app.get("/health")
async def health():
    """Detailed health check"""
    return {
        "status": "healthy" if manager.model else "initializing",
        "model_loaded": manager.model is not None,
        "device": str(manager.device) if manager.device else "not set",
        "websocket_clients": len(manager.websocket_clients),
        "timestamp": datetime.now().isoformat()
    }

@app.get("/model/info")
async def model_info(authenticated: bool = Depends(verify_api_key)):
    """Get detailed information about the loaded model"""
    if not manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")
    config = manager.model.config
    # Calculate parameter counts
    total_params = sum(p.numel() for p in manager.model.parameters())
    trainable_params = sum(p.numel() for p in manager.model.parameters() if p.requires_grad)
    return {
        "name": manager.model_name,
        "type": config.model_type,
        "totalParams": total_params,
        "trainableParams": trainable_params,
        "layers": config.n_layer,
        "heads": config.n_head,
        "hiddenSize": config.n_embd,
        "vocabSize": config.vocab_size,
        "maxPositions": config.n_positions,
        "architecture": manager.model.__class__.__name__,
        "device": str(manager.device),
        "dtype": str(next(manager.model.parameters()).dtype),
        "accessible": [
            f"Token probabilities (all {config.vocab_size})",
            f"Attention weights ({config.n_layer} layers × {config.n_head} heads = {config.n_layer * config.n_head} patterns)",
            f"Hidden states (all {config.n_layer} layers)",
            "Logits before softmax",
            "Token embeddings",
            "Position embeddings (RoPE)",
            "Feed-forward activations",
            "Layer normalizations",
            "Gradient information (when available)",
            "Activation functions (GELU)"
        ],
        "config": {
            "activation_function": config.activation_function,
            "layer_norm_epsilon": config.layer_norm_epsilon,
            "tie_word_embeddings": config.tie_word_embeddings,
            "rotary_dim": getattr(config, 'rotary_dim', None),
            "use_cache": config.use_cache
        }
    }

@app.post("/generate")
async def generate(request: GenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with optional trace extraction"""
    result = await manager.generate_with_traces(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        sampling_rate=request.sampling_rate if request.extract_traces else 0
    )
    return result

@app.post("/generate/ablated")
async def generate_ablated(request: AblatedGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with specific components disabled (ablation study)"""
    result = await manager.generate_with_ablation(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        disabled_components=request.disabled_components
    )
    return result

@app.post("/generate/icl")
async def generate_icl(request: ICLGenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with in-context learning analysis"""
    from .icl_service import ICLAnalyzer, ICLExample as ICLExampleData

    # Initialize the ICL analyzer
    analyzer = ICLAnalyzer(manager.model, manager.tokenizer)
    # Convert request examples to the ICLExample format
    examples = [ICLExampleData(input=ex.input, output=ex.output) for ex in request.examples]
    # Analyze generation with the examples
    result = analyzer.analyze_generation(
        examples=examples,
        test_prompt=request.prompt,
        max_length=request.max_tokens,
        temperature=request.temperature
    )
    # Convert the result to a dict for the JSON response
    response_data = {
        "shotCount": result.shot_count,
        "generatedCode": result.generated_code,
        "tokens": result.tokens,
        "confidenceScores": result.confidence_scores,
        "attentionFromExamples": result.attention_from_examples,
        "perplexity": result.perplexity,
        "avgConfidence": result.avg_confidence,
        "exampleInfluences": result.example_influences,
        "hiddenStateDrift": result.hidden_state_drift
    }
    # Add ICL emergence data if available
    if result.icl_emergence:
        response_data["iclEmergence"] = {
            "emergenceDetected": result.icl_emergence.emergence_detected,
            "emergenceToken": result.icl_emergence.emergence_token,
            "emergenceLayer": result.icl_emergence.emergence_layer,
            "confidence": result.icl_emergence.confidence,
            "inductionHeads": [
                {
                    "layer": h.layer,
                    "head": h.head,
                    "strength": h.strength,
                    "patternType": h.pattern_type,
                    "emergencePoint": h.emergence_point
                }
                for h in result.icl_emergence.induction_heads
            ],
            "attentionEntropyDrop": result.icl_emergence.attention_entropy_drop,
            "patternConsistency": result.icl_emergence.pattern_consistency
        }
    return response_data
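
# Example ICL request body (field names from ICLGenerationRequest; values
# illustrative):
# {
#     "examples": [
#         {"input": "add(1, 2)", "output": "3"},
#         {"input": "add(2, 3)", "output": "5"}
#     ],
#     "prompt": "add(4, 5)",
#     "max_tokens": 200,
#     "temperature": 0.7,
#     "analyze": true
# }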

@app.post("/analyze/pipeline")
async def analyze_pipeline(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Analyze the complete transformer pipeline step by step"""
    from .pipeline_analyzer import TransformerPipelineAnalyzer
    try:
        # Initialize the pipeline analyzer
        analyzer = TransformerPipelineAnalyzer(manager.model, manager.tokenizer)
        # Get parameters from the request
        text = request.get("text", "def fibonacci(n):\n    if n <= 1:\n        return n")
        max_tokens = request.get("max_tokens", 1)
        temperature = request.get("temperature", 0.7)
        top_k = request.get("top_k", 50)
        top_p = request.get("top_p", 0.95)
        # Analyze the pipeline with generation parameters
        result = analyzer.analyze_pipeline(
            text,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
        # Convert pipeline steps to dict format
        from dataclasses import asdict
        pipelines_dict = []
        for pipeline in result['pipelines']:
            pipeline_dict = [asdict(step) for step in pipeline]
            pipelines_dict.append(pipeline_dict)
        # For backward compatibility, return the old format when only one token was generated
        if max_tokens == 1 and len(pipelines_dict) > 0:
            response_data = {
                "steps": pipelines_dict[0],
                "total_steps": len(pipelines_dict[0]),
                "model_name": manager.model_name,
                "input_text": text,
                # Also include the multi-token format
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text']
            }
        else:
            response_data = {
                "tokens": result['tokens'],
                "pipelines": pipelines_dict,
                "final_text": result['final_text'],
                "num_tokens": result['num_tokens'],
                "total_steps": len(pipelines_dict[0]) if pipelines_dict else 0,
                "model_name": manager.model_name,
                "input_text": text
            }
        logger.info(f"Pipeline analysis complete: {result['num_tokens']} tokens, {len(pipelines_dict[0]) if pipelines_dict else 0} steps per token")
        return response_data
    except Exception as e:
        logger.error(f"Pipeline analysis error: {str(e)}")
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
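
# Example pipeline-analysis request body (keys taken from the request.get()
# calls above; values illustrative):
# {"text": "def add(a, b):", "max_tokens": 2, "temperature": 0.7, "top_k": 50, "top_p": 0.95}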

@app.post("/analyze/attention")
async def analyze_attention(request: Dict[str, Any], authenticated: bool = Depends(verify_api_key)):
    """Analyze the attention mechanism with Q, K, V extraction"""
    from .qkv_extractor import QKVExtractor

    # Initialize the QKV extractor
    extractor = QKVExtractor(manager.model, manager.tokenizer)
    # Extract attention data
    text = request.get("text", "def fibonacci(n):\n    if n <= 1:\n        return n")
    analysis = extractor.extract_attention_data(text)
    # Convert to the response format
    response_data = {
        "tokens": analysis.tokens,
        "tokenIds": analysis.token_ids,
        "layerCount": analysis.layer_count,
        "headCount": analysis.head_count,
        "sequenceLength": analysis.sequence_length,
        "modelDimension": analysis.model_dimension,
        "qkvData": [],
        "tokenEmbeddings": [],
        "attentionFlow": []
    }
    # Process QKV data only for a subset of layers to avoid overwhelming the
    # frontend: sample every 4th layer (the extractor already samples every 4th head)
    for qkv in analysis.qkv_data:
        if qkv.layer % 4 == 0:
            response_data["qkvData"].append({
                "layer": qkv.layer,
                "head": qkv.head,
                "query": qkv.query.tolist(),
                "key": qkv.key.tolist(),
                "value": qkv.value.tolist(),
                "attentionScoresRaw": qkv.attention_scores_raw.tolist(),
                "attentionWeights": qkv.attention_weights.tolist(),
                "headDim": qkv.head_dim
            })
    # Process token embeddings
    for emb in analysis.token_embeddings:
        # Only include embeddings for every 4th layer to reduce data size
        if emb.layer % 4 == 0:
            response_data["tokenEmbeddings"].append({
                "token": emb.token,
                "tokenId": emb.token_id,
                "position": emb.position,
                "layer": emb.layer,
                "embedding2D": emb.embedding_2d,
                "embedding3D": emb.embedding_3d
            })
    # Get the attention flow for the first token as an example
    if len(analysis.tokens) > 0:
        flow = extractor.get_attention_flow(analysis, source_token=0)
        response_data["attentionFlow"] = flow
    # Add positional encodings if available
    if analysis.positional_encodings is not None:
        response_data["positionalEncodings"] = analysis.positional_encodings.tolist()
    return response_data

@app.get("/demos")
async def list_demos(authenticated: bool = Depends(verify_api_key)):
    """List available demo prompts"""
    return {
        "demos": [
            {
                "id": "fibonacci",
                "name": "Fibonacci Function",
                "prompt": "def fibonacci(n):\n    '''Calculate fibonacci number'''",
                "description": "Generate a recursive fibonacci implementation"
            },
            {
                "id": "quicksort",
                "name": "Quicksort Algorithm",
                "prompt": "def quicksort(arr):\n    '''Sort array using quicksort'''",
                "description": "Generate a quicksort implementation"
            },
            {
                "id": "stack",
                "name": "Stack Class",
                "prompt": "class Stack:\n    '''Simple stack implementation'''",
                "description": "Generate a stack data structure"
            },
            {
                "id": "binary_search",
                "name": "Binary Search",
                "prompt": "def binary_search(arr, target):\n    '''Find target in sorted array'''",
                "description": "Generate a binary search function"
            }
        ]
    }

@app.post("/demos/run")
async def run_demo(request: DemoRequest, authenticated: bool = Depends(verify_api_key)):
    """Run a specific demo"""
    demos = {
        "fibonacci": "def fibonacci(n):\n    '''Calculate fibonacci number'''",
        "quicksort": "def quicksort(arr):\n    '''Sort array using quicksort'''",
        "stack": "class Stack:\n    '''Simple stack implementation'''",
        "binary_search": "def binary_search(arr, target):\n    '''Find target in sorted array'''"
    }
    if request.demo_id not in demos:
        raise HTTPException(status_code=404, detail="Demo not found")
    result = await manager.generate_with_traces(
        prompt=demos[request.demo_id],
        max_tokens=100,
        temperature=0.7,
        sampling_rate=0.3  # Much higher than the 0.005 default, for a richer visualization
    )
    return result

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
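
# Example invocation once the service is running (route path as registered
# above; the API-key header name depends on how verify_api_key is implemented):
#
#     curl -X POST http://localhost:8000/generate \
#         -H "Content-Type: application/json" \
#         -H "X-API-Key: <your-key>" \
#         -d '{"prompt": "def fibonacci(n):", "max_tokens": 50}'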