"""HTTP API for the compact AI model.

Exposes OpenAI-style (``/v1/chat/completions``, ``/v1/completions``) and
Anthropic-style (``/v1/messages``) generation endpoints, plus model listing,
health and root endpoints. The model is loaded once at startup by
``lifespan``; the tokenizer is a whitespace/hash placeholder.
"""

import asyncio
import json
import logging
import os
import threading
import time
import uuid
import zlib
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import GPUtil
import psutil
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from ..configs.config import Config, get_balanced_config
from ..architecture.model import create_compact_model, CompactAIModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Maximum number of prompt tokens fed to the model. The decoding loop also
# keeps only this many trailing tokens as context while generating.
MAX_CONTEXT_TOKENS = 512

# Global model instance, set by ``lifespan`` at startup (None if loading failed).
model: Optional[CompactAIModel] = None
# Placeholder; replaced by a SimpleTokenizer instance further down.
tokenizer = None  # We'll use a simple tokenizer for now

# Generation runs in a worker thread so the event loop stays responsive. This
# lock keeps at most one generation in flight, preserving the previous
# one-at-a-time inference behaviour (and bounding GPU memory use).
_generation_lock = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model on startup and log on shutdown.

    Reads ``MODEL_SIZE`` (default ``"small"``) and, optionally,
    ``MODEL_CHECKPOINT`` from the environment. Any failure while loading is
    logged with its traceback and leaves the global ``model`` as ``None``, so
    the server still starts and generation endpoints answer 503.
    """
    global model

    logger.info("Loading Compact AI Model...")
    try:
        model_size = os.getenv("MODEL_SIZE", "small")
        model = create_compact_model(model_size)

        checkpoint_path = os.getenv("MODEL_CHECKPOINT")
        if checkpoint_path and os.path.exists(checkpoint_path):
            # NOTE(review): torch.load unpickles arbitrary objects. The path comes
            # from operator configuration (env/CLI), not from requests, but only
            # point it at trusted checkpoints; consider weights_only=True if the
            # file is a plain state dict.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint)
            logger.info(f"Loaded model checkpoint from {checkpoint_path}")
        elif checkpoint_path:
            logger.warning(
                f"MODEL_CHECKPOINT {checkpoint_path} does not exist; continuing without it"
            )

        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()

        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.exception(f"Failed to load model: {e}")
        model = None

    yield

    logger.info("Shutting down...")


app = FastAPI(
    title="Compact AI Model API",
    description="API for the compact AI model with interleaved thinking",
    version="1.0.0",
    lifespan=lifespan,
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests; tighten before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Request / response schemas
# ---------------------------------------------------------------------------

class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message (user/assistant/system)")
    content: str = Field(..., description="Content of the message")


class ChatCompletionRequest(BaseModel):
    model: str = Field(default="compact-ai-v1", description="Model name")
    messages: List[ChatMessage] = Field(..., description="List of messages")
    max_tokens: Optional[int] = Field(default=100, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0, description="Top-p sampling")
    reasoning_depth: Optional[Union[str, int]] = Field(default="adaptive", description="Reasoning depth")
    early_stop_threshold: Optional[float] = Field(default=0.85, description="Early stop threshold")
    thinking_visualization: Optional[bool] = Field(default=False, description="Include thinking visualization")


class CompletionRequest(BaseModel):
    model: str = Field(default="compact-ai-v1", description="Model name")
    prompt: str = Field(..., description="Input prompt")
    max_tokens: Optional[int] = Field(default=50, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=0.8, ge=0.0, le=2.0, description="Sampling temperature")
    # NOTE(review): accepted but not currently passed to generation.
    reasoning_tokens: Optional[int] = Field(default=100, description="Maximum reasoning tokens")


class AnthropicMessageRequest(BaseModel):
    model: str = Field(default="compact-ai-v1", description="Model name")
    messages: List[ChatMessage] = Field(..., description="List of messages")
    max_tokens: int = Field(default=1024, description="Maximum tokens to generate")
    system: Optional[str] = Field(default=None, description="System message")
    # Keys read by anthropic_messages: "reasoning_depth" (default "complex")
    # and "thinking_visualization" (default True).
    thinking_config: Optional[Dict[str, Any]] = Field(default=None, description="Thinking configuration")


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str
    thinking_trace: Optional[Dict[str, Any]] = None


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Dict[str, int]


class CompletionChoice(BaseModel):
    text: str
    index: int
    finish_reason: str
    thinking_tokens: Optional[int] = None


class CompletionResponse(BaseModel):
    id: str
    object: str = "text_completion"
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: Dict[str, int]


class AnthropicMessageResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[Dict[str, Any]]
    model: str
    usage: Dict[str, int]


class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "compact-ai"


class ModelListResponse(BaseModel):
    object: str = "list"
    data: List[ModelInfo]


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    gpu_available: bool
    memory_usage: Dict[str, Any]
    uptime: str


# ---------------------------------------------------------------------------
# Tokenizer placeholder
# ---------------------------------------------------------------------------

class SimpleTokenizer:
    """Whitespace tokenizer used as a stand-in until a real tokenizer is wired in.

    Each word maps to an ID in ``[100, vocab_size)`` via a stable CRC32 hash;
    IDs 0, 1 and 2 are reserved for pad, eos and bos. The mapping is one-way,
    so ``decode`` cannot recover the original words.
    """

    def __init__(self, vocab_size=32000):
        self.vocab_size = vocab_size
        self.pad_token_id = 0
        self.eos_token_id = 1
        self.bos_token_id = 2

    def encode(self, text: str, max_length=None, truncation=True, padding=False):
        """Split *text* on whitespace and map each word to a token ID.

        Args:
            text: Input text.
            max_length: Optional length limit used for truncation/padding.
            truncation: Cut the result to ``max_length`` when it is longer.
            padding: Right-pad with ``pad_token_id`` up to ``max_length``.

        Returns:
            List of int token IDs.
        """
        # zlib.crc32 is stable across processes; the built-in str hash() is
        # salted per interpreter (PYTHONHASHSEED), which gave every worker and
        # restart a different word -> ID mapping.
        token_ids = [
            zlib.crc32(word.encode("utf-8")) % (self.vocab_size - 100) + 100
            for word in text.split()
        ]
        if truncation and max_length and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        if padding and max_length:
            token_ids += [self.pad_token_id] * (max_length - len(token_ids))
        return token_ids

    def decode(self, token_ids: List[int]):
        """Render token IDs as space-separated ``<token_ID>`` placeholders."""
        return " ".join(f"<token_{tid}>" for tid in token_ids)


tokenizer = SimpleTokenizer()


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

# Named reasoning depths -> model ``max_reasoning_depth`` (None lets the model decide).
_REASONING_DEPTHS = {"adaptive": None, "simple": 1, "complex": 4}


def _resolve_reasoning_depth(reasoning_depth: Union[str, int]) -> Optional[int]:
    """Map a named depth to ``max_reasoning_depth``; unknown names -> 2, ints pass through."""
    if isinstance(reasoning_depth, str):
        return _REASONING_DEPTHS.get(reasoning_depth, 2)
    return reasoning_depth


def _to_int(value: Any) -> int:
    """Coerce a model-reported count (None, number or 1-element tensor) to int."""
    if value is None:
        return 0
    if torch.is_tensor(value):
        return int(value.item())
    return int(value)


def _sample_next_token(logits: torch.Tensor, temperature: float, top_p: float = 1.0) -> int:
    """Pick the next token ID from a 1-D logits vector.

    ``temperature <= 0`` selects the argmax (greedy). Otherwise the logits are
    divided by ``temperature`` and, when ``top_p < 1``, sampling is restricted
    to the smallest set of most-probable tokens whose cumulative probability
    reaches ``top_p`` (nucleus sampling).
    """
    if temperature <= 0:
        return int(logits.argmax().item())

    probs = torch.softmax(logits.float() / temperature, dim=-1)
    if top_p >= 1.0:
        return int(torch.multinomial(probs, 1).item())

    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token once the mass *before* it already reaches top_p; always keep
    # the most probable token so top_p=0 degrades to greedy.
    remove = (cumulative - sorted_probs) >= top_p
    remove[0] = False
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    choice = torch.multinomial(sorted_probs, 1)
    return int(sorted_idx[choice].item())


def generate_text(
    prompt: str,
    max_tokens: int = 50,
    temperature: float = 0.8,
    reasoning_depth: Union[str, int] = "adaptive",
    early_stop_threshold: float = 0.85,
    use_thinking: bool = True,
    top_p: float = 1.0,
) -> Dict[str, Any]:
    """Generate a continuation of *prompt* with the global model.

    Decoding is autoregressive: after each sampled token the model is re-run on
    the extended context (its last ``MAX_CONTEXT_TOKENS`` tokens). There is no
    KV cache, so every step costs a full forward pass. This function blocks;
    async callers should run it in a worker thread.

    Args:
        prompt: Input text; tokenized and truncated to ``MAX_CONTEXT_TOKENS``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        reasoning_depth: "adaptive", "simple", "complex", any other string
            (-> 2), or an int passed to the model as ``max_reasoning_depth``.
        early_stop_threshold: Accepted for API compatibility; not used here.
        use_thinking: Forwarded to the model on every forward pass.
        top_p: Nucleus-sampling probability mass; 1.0 disables the filter.

    Returns:
        Dict with "generated_text", "thinking_results" and "reasoning_tokens"
        (the latter two from the first forward pass), "input_tokens",
        "output_tokens" and "finish_reason" ("stop" on EOS, else "length").

    Raises:
        HTTPException: 503 if the model is not loaded; 500 on any generation error.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        input_ids = tokenizer.encode(prompt, max_length=MAX_CONTEXT_TOKENS, truncation=True)
        max_reasoning_depth = _resolve_reasoning_depth(reasoning_depth)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def forward(token_ids: List[int]) -> Dict[str, Any]:
            input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
            return model(
                input_tensor,
                use_thinking=use_thinking,
                max_reasoning_depth=max_reasoning_depth,
            )

        # Never feed the model an empty sequence (e.g. a blank prompt).
        context = list(input_ids) or [tokenizer.bos_token_id]
        generated_tokens: List[int] = []
        finish_reason = "length"

        with _generation_lock, torch.no_grad():
            outputs = forward(context)
            thinking_results = outputs["thinking_results"]
            reasoning_tokens = _to_int(outputs.get("final_tokens", 0))

            for step in range(max_tokens):
                # assumes logits are (batch, seq, vocab): last position of batch 0
                next_token = _sample_next_token(outputs["logits"][0][-1], temperature, top_p)
                if next_token == tokenizer.eos_token_id:
                    finish_reason = "stop"
                    break
                generated_tokens.append(next_token)
                if step + 1 < max_tokens:
                    # Sliding window keeps the input within the prompt length limit.
                    context = (context + [next_token])[-MAX_CONTEXT_TOKENS:]
                    outputs = forward(context)

        return {
            "generated_text": tokenizer.decode(generated_tokens),
            "thinking_results": thinking_results,
            "reasoning_tokens": reasoning_tokens,
            "input_tokens": len(input_ids),
            "output_tokens": len(generated_tokens),
            "finish_reason": finish_reason,
        }
    except Exception as e:
        logger.exception(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e


def _or_default(value, default):
    """Return *value* unless it is None (unlike ``value or default``, keeps 0 / 0.0)."""
    return default if value is None else value


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    The prompt is the last ``user`` message, prefixed by the first ``system``
    message if present; earlier turns are not included.
    """
    start_time = time.time()

    user_messages = [msg for msg in request.messages if msg.role == "user"]
    if not user_messages:
        raise HTTPException(status_code=400, detail="No user message found")
    prompt = user_messages[-1].content

    system_messages = [msg for msg in request.messages if msg.role == "system"]
    if system_messages:
        prompt = f"System: {system_messages[0].content}\n\n{prompt}"

    # Inference blocks; run it off the event loop.
    result = await run_in_threadpool(
        generate_text,
        prompt=prompt,
        max_tokens=_or_default(request.max_tokens, 100),
        temperature=_or_default(request.temperature, 0.7),
        top_p=_or_default(request.top_p, 1.0),
        reasoning_depth=request.reasoning_depth or "adaptive",
        early_stop_threshold=_or_default(request.early_stop_threshold, 0.85),
    )

    thinking_trace = None
    if request.thinking_visualization and result["thinking_results"]:
        thinking_trace = {
            "reasoning_paths": len(result["thinking_results"]),
            "reasoning_tokens": result["reasoning_tokens"],
            # NOTE(review): placeholder values, not derived from model output.
            "confidence_scores": [0.85, 0.78, 0.92],  # Mock data
        }

    response = ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=result["generated_text"]),
                finish_reason=result["finish_reason"],
                thinking_trace=thinking_trace,
            )
        ],
        usage={
            "prompt_tokens": result["input_tokens"],
            "completion_tokens": result["output_tokens"],
            "total_tokens": result["input_tokens"] + result["output_tokens"],
            "reasoning_tokens": result["reasoning_tokens"],
        },
    )

    logger.info(f"Chat completion took {time.time() - start_time:.2f}s")
    return response


@app.post("/v1/completions", response_model=CompletionResponse)
async def completions(request: CompletionRequest):
    """OpenAI-compatible text completions endpoint (fixed reasoning depth 2)."""
    start_time = time.time()

    result = await run_in_threadpool(
        generate_text,
        prompt=request.prompt,
        max_tokens=_or_default(request.max_tokens, 50),
        temperature=_or_default(request.temperature, 0.8),
        reasoning_depth=2,  # Default for completions
        early_stop_threshold=0.8,
    )

    response = CompletionResponse(
        id=f"cmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=request.model,
        choices=[
            CompletionChoice(
                text=result["generated_text"],
                index=0,
                finish_reason=result["finish_reason"],
                thinking_tokens=result["reasoning_tokens"],
            )
        ],
        usage={
            "prompt_tokens": result["input_tokens"],
            "completion_tokens": result["output_tokens"],
            "total_tokens": result["input_tokens"] + result["output_tokens"],
        },
    )

    logger.info(f"Completion took {time.time() - start_time:.2f}s")
    return response


@app.post("/v1/messages", response_model=AnthropicMessageResponse)
async def anthropic_messages(request: AnthropicMessageRequest):
    """Anthropic-compatible messages endpoint.

    The whole conversation is flattened into a "Human:/Assistant:" transcript
    (messages with other roles are dropped), optionally preceded by the
    ``system`` text. Temperature is fixed at 0.7.
    """
    start_time = time.time()

    messages = []
    for msg in request.messages:
        if msg.role == "user":
            messages.append(f"Human: {msg.content}")
        elif msg.role == "assistant":
            messages.append(f"Assistant: {msg.content}")

    if request.system:
        messages.insert(0, f"System: {request.system}")

    prompt = "\n\n".join(messages)

    thinking_config = request.thinking_config or {}
    reasoning_depth = thinking_config.get("reasoning_depth", "complex")
    visualization = thinking_config.get("thinking_visualization", True)

    result = await run_in_threadpool(
        generate_text,
        prompt=prompt,
        max_tokens=request.max_tokens,
        temperature=0.7,  # Default for Anthropic
        reasoning_depth=reasoning_depth,
        early_stop_threshold=0.85,
    )

    content = [{"type": "text", "text": result["generated_text"]}]
    if visualization and result["thinking_results"]:
        thinking_text = f"\n\nThinking process used {result['reasoning_tokens']} reasoning tokens across {len(result['thinking_results'])} layers."
        # The thinking summary is placed before the generated text.
        content.insert(0, {"type": "text", "text": thinking_text})

    response = AnthropicMessageResponse(
        id=f"msg_{uuid.uuid4().hex}",
        model=request.model,
        content=content,
        usage={
            "input_tokens": result["input_tokens"],
            "output_tokens": result["output_tokens"],
            "total_tokens": result["input_tokens"] + result["output_tokens"],
        },
    )

    logger.info(f"Anthropic message took {time.time() - start_time:.2f}s")
    return response


@app.get("/v1/models", response_model=ModelListResponse)
async def list_models():
    """List available models (only ``compact-ai-v1``)."""
    return ModelListResponse(
        data=[
            ModelInfo(
                id="compact-ai-v1",
                created=int(time.time()),
            )
        ]
    )


@app.get("/v1/models/{model_id}")
async def get_model(model_id: str):
    """Get model information; 404 for any id other than ``compact-ai-v1``."""
    if model_id != "compact-ai-v1":
        raise HTTPException(status_code=404, detail="Model not found")

    return ModelInfo(
        id=model_id,
        created=int(time.time()),
    )


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report model status, RAM usage, first-GPU stats (if any) and host uptime."""
    memory_info = psutil.virtual_memory()

    gpu_info = {}
    try:
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            gpu_info = {
                "gpu_name": gpu.name,
                "gpu_memory_used": gpu.memoryUsed,
                "gpu_memory_total": gpu.memoryTotal,
                "gpu_memory_free": gpu.memoryFree,
                "gpu_utilization": gpu.load * 100,
            }
    except Exception as e:
        # GPU stats are best-effort; a failure here must not fail the health check.
        logger.debug(f"GPU stats unavailable: {e}")

    return HealthResponse(
        status="healthy" if model is not None else "unhealthy",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        memory_usage={
            "ram_used": memory_info.used,
            "ram_total": memory_info.total,
            "ram_percent": memory_info.percent,
            **gpu_info,
        },
        # Time since host boot (psutil.boot_time), not since this process started.
        uptime=str(datetime.now() - datetime.fromtimestamp(psutil.boot_time())),
    )


@app.get("/")
async def root():
    """Root endpoint."""
    return {"message": "Compact AI Model API", "version": "1.0.0"}


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run Compact AI Model API")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--workers", type=int, default=1, help="Number of workers")
    parser.add_argument("--model-size", default="small", choices=["tiny", "small", "medium"], help="Model size")
    parser.add_argument("--checkpoint", help="Path to model checkpoint")

    args = parser.parse_args()

    # Configuration reaches lifespan() through the environment, which also
    # works for worker processes spawned by uvicorn.
    os.environ["MODEL_SIZE"] = args.model_size
    if args.checkpoint:
        os.environ["MODEL_CHECKPOINT"] = args.checkpoint

    # NOTE(review): this module uses package-relative imports (``from ..configs``),
    # so it cannot be imported as a top-level "main" module; the import string
    # below probably needs the full package path — confirm.
    uvicorn.run(
        "main:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=False,
        log_level="info",
    )