Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| """ | |
| Unified Model Service for Visualisable.ai | |
| Combines model loading, generation, and trace extraction into a single service | |
| """ | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, BackgroundTasks, HTTPException, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import asyncio | |
| import json | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Optional, List, Dict, Any | |
| import numpy as np | |
| import logging | |
| from datetime import datetime | |
| import traceback | |
| from .auth import verify_api_key | |
# Configure logging: INFO level via basicConfig, plus a module-level logger
# that the rest of this file uses for its messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application object for the service; title and version are passed to FastAPI.
app = FastAPI(title="Visualisable.ai Model Service", version="0.1.0")

# CORS configuration for local development: only the three localhost origins
# listed below are allowed; credentials, all methods and all headers are
# permitted for those origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:3001", "http://localhost:3002"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # Request/Response models | |
class GenerationRequest(BaseModel):
    """Request body for text generation.

    The fields are forwarded to ``ModelManager.generate_with_traces`` by the
    ``generate`` handler in this file.
    """

    # Text the model continues; it is also prepended to the returned text.
    prompt: str
    # Upper bound on the number of generation steps.
    max_tokens: int = 100
    # Divisor applied to the next-token logits before softmax sampling.
    temperature: float = 0.7
    # When False, the handler passes a sampling rate of 0 (no trace sampling).
    extract_traces: bool = True
    # Per-step probability that attention/activation traces are extracted.
    sampling_rate: float = 0.005
class DemoRequest(BaseModel):
    """Request body for running one of the predefined demo prompts."""

    # Key of the demo to run; ``run_demo`` answers 404 for unknown ids.
    demo_id: str
class TraceData(BaseModel):
    """One trace record returned to callers and streamed to WebSocket clients.

    Only ``type`` and ``timestamp`` are required; the other fields are filled
    in depending on the kind of trace being reported.
    """

    # Kind of record; this file emits "attention", "activation",
    # "confidence" and "token".
    type: str
    # Layer label, formatted as "layer.<index>" by the extractors in this file.
    layer: Optional[str] = None
    # Head-averaged attention matrix (attention traces only).
    weights: Optional[List[List[float]]] = None
    # Largest attention weight, or largest absolute clipped activation.
    max_weight: Optional[float] = None
    # Entropy of the attention weights or of the next-token distribution.
    entropy: Optional[float] = None
    # Display value for activation traces, clamped to [0.3, 0.95].
    mean: Optional[float] = None
    # Standard deviation of the clipped activations.
    std: Optional[float] = None
    # Probability of the top (confidence) or sampled (token) token.
    confidence_score: Optional[float] = None
    # Entropy-derived risk score, clipped to [0, 1].
    hallucination_risk: Optional[float] = None
    # Creation time as a POSIX timestamp from datetime.now().timestamp().
    timestamp: float
class ModelManager:
    """Manages model loading and generation with trace extraction.

    One instance holds the causal language model, its tokenizer, the device
    they live on, and the list of WebSocket clients that receive traces and
    tokens while text is being generated.
    """

    # Hugging Face checkpoint used for both the model and the tokenizer.
    MODEL_NAME = "Salesforce/codegen-350M-mono"
    # Attention matrices with more rows than this are sub-sampled before sending.
    MAX_ATTENTION_TOKENS = 20
    # Number of alternative tokens reported next to every generated token.
    TOP_K_ALTERNATIVES = 5
    # Activations are clipped to +/- this value before statistics are computed.
    ACTIVATION_CLIP = 10.0
    # Number of leading hidden-state layers reported in activation traces.
    MAX_ACTIVATION_LAYERS = 8
    # Chance that activation traces accompany a sampled attention trace step.
    ACTIVATION_TRACE_PROBABILITY = 0.3

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.websocket_clients: List[WebSocket] = []
        self.trace_buffer: List[TraceData] = []

    async def initialize(self):
        """Select a device and load the model and tokenizer.

        CUDA is preferred, then Apple MPS, then CPU. The model is loaded in
        float32 on CPU and float16 otherwise.

        Raises:
            Exception: whatever the loading calls raise, re-raised after it
                has been logged.
        """
        try:
            # getattr guard: torch builds without an MPS backend lack the attribute.
            mps_backend = getattr(torch.backends, "mps", None)
            if torch.cuda.is_available():
                device = torch.device("cuda")
                device_name = "CUDA GPU"
            elif mps_backend is not None and mps_backend.is_available():
                device = torch.device("mps")
                device_name = "Apple Silicon GPU"
            else:
                device = torch.device("cpu")
                device_name = "CPU"
            self.device = device

            logger.info(f"Loading model on {device_name}...")

            # NOTE(review): trust_remote_code=True lets the checkpoint's
            # repository run its own code at load time - confirm it is needed.
            model = AutoModelForCausalLM.from_pretrained(
                self.MODEL_NAME,
                torch_dtype=torch.float32 if device.type == "cpu" else torch.float16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            ).to(device)

            tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
            tokenizer.pad_token = tokenizer.eos_token

            # Publish both together so a tokenizer failure cannot leave the
            # manager looking loaded with only half of its state.
            self.model = model
            self.tokenizer = tokenizer
            logger.info("✅ Model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def extract_attention_trace(self, layer_idx: int, attention_weights) -> TraceData:
        """Build an attention trace for one layer.

        Args:
            layer_idx: index into ``attention_weights``.
            attention_weights: per-layer attention tensors; each is indexed
                as (batch, heads, seq, seq) - only batch item 0 is used.

        Returns:
            A ``TraceData`` of type "attention" with the head-averaged matrix
            (sub-sampled to at most ``MAX_ATTENTION_TOKENS`` positions), its
            maximum weight and its entropy.
        """
        layer_attention = attention_weights[layer_idx]
        # Average over heads in float32 so half-precision inputs do not lose accuracy.
        avg_attention = layer_attention[0].float().mean(dim=0).detach().cpu().numpy()

        seq_len = avg_attention.shape[0]
        if seq_len > self.MAX_ATTENTION_TOKENS:
            # Sort the sampled positions: an unsorted sample would shuffle the
            # rows and columns and destroy the token order of the matrix.
            indices = np.sort(
                np.random.choice(seq_len, self.MAX_ATTENTION_TOKENS, replace=False)
            )
            avg_attention = avg_attention[np.ix_(indices, indices)]

        # Ensure values are finite before they are serialized.
        avg_attention = np.nan_to_num(avg_attention, nan=0.0, posinf=1.0, neginf=0.0)

        max_weight = float(np.max(avg_attention))
        if max_weight == 0:
            max_weight = 1.0  # Lets consumers divide by it safely.

        # Entropy over the strictly positive weights only.
        positive = avg_attention[avg_attention > 0]
        if positive.size > 0:
            raw_entropy = -np.sum(positive * np.log(positive + 1e-10))
            entropy = float(np.clip(raw_entropy, 0.0, 100.0))
        else:
            entropy = 0.0

        return TraceData(
            type="attention",
            layer=f"layer.{layer_idx}",
            weights=avg_attention.tolist(),
            max_weight=max_weight,
            entropy=entropy,
            timestamp=datetime.now().timestamp(),
        )

    def extract_activation_trace(self, layer_idx: int, hidden_states) -> TraceData:
        """Build an activation trace for one layer's hidden states.

        Args:
            layer_idx: layer index, used for the label and the display value.
            hidden_states: hidden-state tensor of one layer; only batch item 0
                is used.

        Returns:
            A ``TraceData`` of type "activation". ``std`` and ``max_weight``
            are measured on the clipped activations.

        NOTE(review): ``mean`` is NOT measured from the activations. It is a
        layer-dependent base plus random jitter, clamped to [0.3, 0.95], kept
        as-is because the visualization relies on that range - confirm this is
        still wanted.
        """
        limit = self.ACTIVATION_CLIP
        activations = hidden_states[0].detach().float().cpu().numpy()
        # Replace NaN/inf first: np.clip leaves NaN untouched and NaN would
        # otherwise reach the JSON sent to clients.
        clipped = np.clip(
            np.nan_to_num(activations, nan=0.0, posinf=limit, neginf=-limit),
            -limit,
            limit,
        )

        base_value = 0.3 + (layer_idx * 0.08)  # Layer-specific base
        variation = float(np.random.random()) * 0.4  # 0-40% variation
        normalized_mean = min(0.95, max(0.3, base_value + variation))

        logger.info(f"Layer {layer_idx} activation: {normalized_mean:.3f}")

        return TraceData(
            type="activation",
            layer=f"layer.{layer_idx}",
            mean=normalized_mean,
            std=float(np.std(clipped)),
            max_weight=float(np.max(np.abs(clipped))),
            timestamp=datetime.now().timestamp(),
        )

    def calculate_confidence(self, logits) -> TraceData:
        """Derive confidence metrics from the logits of the last position.

        Args:
            logits: logits tensor indexed as [batch, position, vocab]; batch
                item 0 and the final position are used.

        Returns:
            A ``TraceData`` of type "confidence" holding the top probability,
            the distribution entropy, and ``hallucination_risk`` defined as
            entropy / 10 clipped to [0, 1].
        """
        # float32: in float16 the 1e-10 epsilon underflows to zero, so
        # log(0) = -inf and 0 * -inf = NaN for every zero-probability token.
        probs = torch.softmax(logits[0, -1, :].float(), dim=0)
        top_prob = float(torch.max(probs))

        entropy = float(-torch.sum(probs * torch.log(probs + 1e-10)))
        if not np.isfinite(entropy):
            entropy = 0.0

        hallucination_risk = float(np.clip(min(1.0, entropy / 10.0), 0.0, 1.0))
        top_prob = float(np.clip(top_prob, 0.0, 1.0))

        return TraceData(
            type="confidence",
            confidence_score=top_prob,
            hallucination_risk=hallucination_risk,
            entropy=entropy,
            timestamp=datetime.now().timestamp(),
        )

    async def generate_with_traces(
        self,
        prompt: str,
        max_tokens: int = 100,
        temperature: float = 0.7,
        sampling_rate: float = 0.005
    ) -> Dict[str, Any]:
        """Generate text token by token, streaming tokens and sampled traces.

        Args:
            prompt: text to continue; must tokenize to at least one token.
            max_tokens: maximum number of tokens to generate (at least 1).
            temperature: softmax temperature; values <= 0 select the most
                likely token instead of sampling.
            sampling_rate: per-step probability of extracting traces.

        Returns:
            A dict with ``generated_text`` (prompt plus continuation),
            ``traces`` (attention traces and the final confidence trace as
            dicts), ``num_tokens``, ``confidence`` and ``hallucination_risk``.

        Raises:
            HTTPException: 503 when the model is not loaded, 400 for an
                unusable ``max_tokens`` or prompt, 500 for any other failure.
        """
        if self.model is None or self.tokenizer is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        # Without at least one step there are no logits to score afterwards.
        if max_tokens < 1:
            raise HTTPException(status_code=400, detail="max_tokens must be at least 1")

        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            if inputs["input_ids"].shape[-1] == 0:
                raise HTTPException(status_code=400, detail="Prompt produced no tokens")

            traces: List[TraceData] = []
            generated_tokens: List[int] = []
            eos_token_id = self.tokenizer.eos_token_id
            logits = None

            with torch.no_grad():
                for _ in range(max_tokens):
                    # The full sequence is re-run each step so attention
                    # matrices cover every position.
                    outputs = self.model(
                        **inputs,
                        output_attentions=True,
                        output_hidden_states=True,
                    )

                    if np.random.random() < sampling_rate:
                        traces.extend(await self._emit_sampled_traces(outputs))

                    logits = outputs.logits
                    probs, next_token = self._pick_next_token(logits[0, -1, :], temperature)
                    token_id = next_token.item()
                    generated_tokens.append(token_id)

                    await self._emit_token(token_id, probs)

                    if token_id == eos_token_id:
                        break

                    mask = inputs["attention_mask"]
                    inputs = {
                        "input_ids": torch.cat(
                            [inputs["input_ids"], next_token.unsqueeze(0)], dim=1
                        ),
                        # Match the existing mask dtype/device; a default
                        # float32 tensor would change the mask's dtype.
                        "attention_mask": torch.cat(
                            [mask, torch.ones((1, 1), dtype=mask.dtype, device=mask.device)],
                            dim=1,
                        ),
                    }

            confidence_trace = self.calculate_confidence(logits)
            traces.append(confidence_trace)
            await self.broadcast_trace(confidence_trace)

            generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            confidence = confidence_trace.confidence_score
            risk = confidence_trace.hallucination_risk

            return {
                "generated_text": prompt + generated_text,
                "traces": [self._json_safe(trace.dict()) for trace in traces],
                "num_tokens": len(generated_tokens),
                "confidence": float(confidence) if np.isfinite(confidence) else 0.5,
                "hallucination_risk": float(risk) if np.isfinite(risk) else 0.1,
            }
        except HTTPException:
            # Deliberate HTTP errors must keep their own status code.
            raise
        except Exception as e:
            logger.error(f"Generation error: {e}")
            logger.error(traceback.format_exc())
            raise HTTPException(status_code=500, detail=str(e)) from e

    @staticmethod
    def _pick_next_token(next_token_logits, temperature: float):
        """Return (probabilities, next-token tensor of shape (1,)) for one step."""
        next_token_logits = next_token_logits.float()
        if temperature <= 0:
            # Greedy decoding: dividing by a non-positive temperature would
            # give inf/NaN or an inverted distribution.
            probs = torch.softmax(next_token_logits, dim=0)
            next_token = torch.argmax(probs, dim=0, keepdim=True)
        else:
            probs = torch.softmax(next_token_logits / temperature, dim=0)
            next_token = torch.multinomial(probs, 1)
        return probs, next_token

    async def _emit_sampled_traces(self, outputs) -> List[TraceData]:
        """Broadcast attention/activation traces for one step; return the attention ones."""
        collected: List[TraceData] = []

        attentions = outputs.attentions
        if attentions:
            num_layers = len(attentions)
            step = max(1, num_layers // 10)  # Roughly ten layers per sample
            for layer_idx in range(0, num_layers, step):
                try:
                    trace = self.extract_attention_trace(layer_idx, attentions)
                    collected.append(trace)
                    await self.broadcast_trace(trace)
                except Exception as e:
                    logger.warning(f"Failed to extract attention trace from layer {layer_idx}: {e}")

        hidden_states = outputs.hidden_states
        # Activation traces are only broadcast, never added to the result.
        if hidden_states and np.random.random() < self.ACTIVATION_TRACE_PROBABILITY:
            for layer_idx in range(min(self.MAX_ACTIVATION_LAYERS, len(hidden_states))):
                try:
                    trace = self.extract_activation_trace(layer_idx, hidden_states[layer_idx])
                    await self.broadcast_trace(trace)
                except Exception as e:
                    logger.warning(f"Failed to extract activation trace for layer {layer_idx}: {e}")

        return collected

    async def _emit_token(self, token_id: int, probs) -> None:
        """Broadcast one generated token and its top-k alternatives, if it has text."""
        token_text = self.tokenizer.decode([token_id], skip_special_tokens=True)
        if not token_text:  # Only send non-empty tokens
            return

        top_k = min(self.TOP_K_ALTERNATIVES, probs.shape[0])
        top_probs, top_indices = torch.topk(probs, top_k)
        alternatives = [
            {
                "token": self.tokenizer.decode([alt_id], skip_special_tokens=True),
                "probability": float(alt_prob),
                "token_id": int(alt_id),
            }
            for alt_prob, alt_id in zip(top_probs.tolist(), top_indices.tolist())
        ]

        await self.broadcast_trace(TraceData(
            type="token",
            layer=None,
            weights=None,
            confidence_score=float(probs[token_id]),
            timestamp=datetime.now().timestamp(),
        ))
        await self.broadcast_token_with_alternatives(token_text, alternatives)

    @staticmethod
    def _json_safe(trace_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of a trace dict with non-finite top-level floats set to 0.0."""
        cleaned = dict(trace_dict)
        for key, value in trace_dict.items():
            if isinstance(value, float):
                cleaned[key] = float(value) if np.isfinite(value) else 0.0
        return cleaned

    async def _broadcast_json(self, message: Dict[str, Any]) -> None:
        """Send one JSON message to every client and drop clients whose send fails."""
        disconnected = []
        # Iterate over a copy: the list can change while a send is awaited.
        for client in list(self.websocket_clients):
            try:
                await client.send_json(message)
            except Exception:
                # Not a bare except, so task cancellation still propagates.
                disconnected.append(client)

        for client in disconnected:
            if client in self.websocket_clients:
                self.websocket_clients.remove(client)

    async def broadcast_trace(self, trace: TraceData):
        """Send trace to all connected WebSocket clients."""
        # Non-finite floats would be emitted as NaN/Infinity, which strict
        # JSON parsers reject.
        await self._broadcast_json(self._json_safe(trace.dict()))

    async def broadcast_token(self, token: str):
        """Send a generated token to all connected WebSocket clients."""
        await self._broadcast_json({
            "type": "generated_token",
            "token": token,
            "timestamp": datetime.now().timestamp(),
        })

    async def broadcast_token_with_alternatives(self, token: str, alternatives: list):
        """Send a generated token with its top-k alternatives to all connected WebSocket clients."""
        await self._broadcast_json({
            "type": "generated_token",
            "token": token,
            "alternatives": alternatives,
            "timestamp": datetime.now().timestamp(),
        })
# Module-level ModelManager shared by the handlers below. Constructing it does
# not load anything; the model is loaded later by manager.initialize().
manager = ModelManager()
| # Startup event | |
async def startup_event():
    """Initialize model on startup.

    Awaits ``manager.initialize()``, which loads the model and tokenizer.

    NOTE(review): no registration of this coroutine with ``app`` is visible
    in this view - confirm how it is hooked into application startup.
    """
    await manager.initialize()
| # WebSocket endpoint for real-time traces | |
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket connection for streaming traces.

    Accepts the connection, registers it with ``manager`` so broadcasts reach
    it, and then answers "ping" text messages with "pong" until the client
    goes away. The client is always deregistered on exit.

    NOTE(review): no route decorator is visible on this function in this
    view - confirm how it is registered with ``app``.
    """
    await websocket.accept()
    manager.websocket_clients.append(websocket)
    logger.info(f"WebSocket client connected. Total clients: {len(manager.websocket_clients)}")

    try:
        while True:
            # Keep connection alive
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        # Normal termination; cleanup happens in ``finally``.
        pass
    finally:
        # Runs for every exit path, not just a clean disconnect. The
        # membership test matters because a failed broadcast may already
        # have removed this client, and list.remove would raise ValueError.
        if websocket in manager.websocket_clients:
            manager.websocket_clients.remove(websocket)
        logger.info(f"WebSocket client disconnected. Total clients: {len(manager.websocket_clients)}")
| # HTTP endpoints | |
async def root():
    """Health check endpoint.

    Returns the service name, a fixed "running" status and whether the
    manager currently holds a loaded model.
    """
    model_ready = manager.model is not None
    payload = {
        "service": "Visualisable.ai Model Service",
        "status": "running",
    }
    payload["model_loaded"] = model_ready
    return payload
async def health():
    """Detailed health check.

    Returns:
        A dict with ``status`` ("healthy" once a model is loaded, otherwise
        "initializing"), ``model_loaded``, the ``device`` in use (or
        "not set"), the number of connected WebSocket clients and an
        ISO-formatted timestamp.
    """
    # One explicit None check drives both fields, so ``status`` and
    # ``model_loaded`` cannot disagree and nothing depends on the truthiness
    # of the model or device objects.
    model_loaded = manager.model is not None
    device = manager.device
    return {
        "status": "healthy" if model_loaded else "initializing",
        "model_loaded": model_loaded,
        "device": str(device) if device is not None else "not set",
        "websocket_clients": len(manager.websocket_clients),
        "timestamp": datetime.now().isoformat(),
    }
async def model_info(authenticated: bool = Depends(verify_api_key)):
    """Get detailed information about the loaded model.

    Reports parameter counts, architecture sizes read from the model config,
    the device and dtype, a list of inspectable internals and a few raw
    config values.

    Raises:
        HTTPException: 503 when no model has been loaded yet.
    """
    if not manager.model:
        raise HTTPException(status_code=503, detail="Model not loaded")

    model = manager.model
    config = model.config

    # Count all parameters and the trainable subset in a single pass.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    accessible = [
        f"Token probabilities (all {config.vocab_size})",
        f"Attention weights ({config.n_layer} layers × {config.n_head} heads = {config.n_layer * config.n_head} patterns)",
        f"Hidden states (all {config.n_layer} layers)",
        "Logits before softmax",
        "Token embeddings",
        "Position embeddings (RoPE)",
        "Feed-forward activations",
        "Layer normalizations",
        "Gradient information (when available)",
        "Activation functions (GELU)",
    ]

    config_summary = {
        "activation_function": config.activation_function,
        "layer_norm_epsilon": config.layer_norm_epsilon,
        "tie_word_embeddings": config.tie_word_embeddings,
        # rotary_dim is optional on the config; report None when absent.
        "rotary_dim": getattr(config, 'rotary_dim', None),
        "use_cache": config.use_cache,
    }

    return {
        "name": "Salesforce/codegen-350M-mono",
        "type": config.model_type,
        "totalParams": total_params,
        "trainableParams": trainable_params,
        "layers": config.n_layer,
        "heads": config.n_head,
        "hiddenSize": config.n_embd,
        "vocabSize": config.vocab_size,
        "maxPositions": config.n_positions,
        "architecture": model.__class__.__name__,
        "device": str(manager.device),
        "dtype": str(next(model.parameters()).dtype),
        "accessible": accessible,
        "config": config_summary,
    }
async def generate(request: GenerationRequest, authenticated: bool = Depends(verify_api_key)):
    """Generate text with optional trace extraction.

    Forwards the request fields to ``manager.generate_with_traces`` and
    returns its result unchanged. When ``extract_traces`` is false the
    sampling rate is forced to 0.
    """
    if request.extract_traces:
        effective_rate = request.sampling_rate
    else:
        effective_rate = 0

    return await manager.generate_with_traces(
        prompt=request.prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        sampling_rate=effective_rate,
    )
# Single source of truth for the demo prompts. Both handlers below read from
# this table, so the listing and the runnable demos cannot drift apart.
_DEMOS = [
    {
        "id": "fibonacci",
        "name": "Fibonacci Function",
        "prompt": "def fibonacci(n):\n '''Calculate fibonacci number'''",
        "description": "Generate a recursive fibonacci implementation",
    },
    {
        "id": "quicksort",
        "name": "Quicksort Algorithm",
        "prompt": "def quicksort(arr):\n '''Sort array using quicksort'''",
        "description": "Generate a quicksort implementation",
    },
    {
        "id": "stack",
        "name": "Stack Class",
        "prompt": "class Stack:\n '''Simple stack implementation'''",
        "description": "Generate a stack data structure",
    },
    {
        "id": "binary_search",
        "name": "Binary Search",
        "prompt": "def binary_search(arr, target):\n '''Find target in sorted array'''",
        "description": "Generate a binary search function",
    },
]


async def list_demos(authenticated: bool = Depends(verify_api_key)):
    """List available demo prompts.

    Returns:
        ``{"demos": [...]}`` where each entry has ``id``, ``name``,
        ``prompt`` and ``description``.
    """
    # Hand out copies so a caller mutating the response cannot alter the table.
    return {"demos": [dict(demo) for demo in _DEMOS]}


async def run_demo(request: DemoRequest, authenticated: bool = Depends(verify_api_key)):
    """Run a specific demo.

    Looks up the prompt for ``request.demo_id`` and generates up to 100
    tokens at temperature 0.7.

    Raises:
        HTTPException: 404 when ``demo_id`` does not name a known demo.
    """
    prompts = {demo["id"]: demo["prompt"] for demo in _DEMOS}
    if request.demo_id not in prompts:
        raise HTTPException(status_code=404, detail="Demo not found")

    return await manager.generate_with_traces(
        prompt=prompts[request.demo_id],
        max_tokens=100,
        temperature=0.7,
        # Far above the 0.005 default of GenerationRequest, so demos produce
        # dense traces for the visualization.
        sampling_rate=0.3,
    )
# Script entry point: when this file is executed directly, serve ``app`` with
# uvicorn on all interfaces (0.0.0.0), port 8000.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)