""" AI API Server - FastAPI + Mistral-7B Production-ready API with streaming, authentication, and caching """ import os import logging import asyncio from typing import Optional, Dict, Any, AsyncGenerator from datetime import datetime from functools import lru_cache import hashlib from fastapi import FastAPI, HTTPException, Header, Request, status from fastapi.responses import StreamingResponse, HTMLResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field, validator import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig import uvicorn # ============================================================================ # LOGGING CONFIGURATION # ============================================================================ logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[ logging.StreamHandler(), logging.FileHandler("api.log") ] ) logger = logging.getLogger(__name__) # ============================================================================ # CONFIGURATION # ============================================================================ class Config: """Application configuration""" MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mistral-7B-Instruct-v0.2") API_KEY = os.getenv("API_KEY", "your-secret-api-key-here") MAX_LENGTH = int(os.getenv("MAX_LENGTH", "2048")) TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7")) TOP_P = float(os.getenv("TOP_P", "0.95")) CACHE_SIZE = int(os.getenv("CACHE_SIZE", "100")) PORT = int(os.getenv("PORT", "7860")) HOST = os.getenv("HOST", "0.0.0.0") # Quantization config for 4-bit loading (optimized for free hardware) QUANTIZATION_CONFIG = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) config = Config() # ============================================================================ # PYDANTIC MODELS # 
class ChatRequest(BaseModel):
    """Request model for chat endpoint."""

    prompt: str = Field(..., min_length=1, max_length=4000, description="User prompt")
    language: str = Field(default="en", description="Response language (en, pt, es)")
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=4096)
    stream: bool = Field(default=True, description="Enable streaming response")

    @validator("language")
    def validate_language(cls, v):
        # Only accept languages that have a system prompt defined below.
        allowed = ["en", "pt", "es", "fr", "de", "it", "ja", "zh"]
        if v not in allowed:
            raise ValueError(f"Language must be one of {allowed}")
        return v


class ChatResponse(BaseModel):
    """Response model for chat endpoint."""

    response: str
    language: str
    model: str
    timestamp: str
    cached: bool = False


class HealthResponse(BaseModel):
    """Health check response."""

    status: str
    model_loaded: bool
    timestamp: str


# ============================================================================
# SYSTEM PROMPTS (MULTI-LANGUAGE)
# ============================================================================

SYSTEM_PROMPTS = {
    "en": "You are a helpful, respectful and honest AI assistant. Always answer as helpfully as possible, while being safe. If you don't know the answer, say so instead of making up information.",
    "pt": "Você é um assistente de IA útil, respeitoso e honesto. Sempre responda da forma mais útil possível, mantendo a segurança. Se não souber a resposta, diga isso ao invés de inventar informações.",
    "es": "Eres un asistente de IA útil, respetuoso y honesto. Siempre responde de la manera más útil posible, manteniendo la seguridad. Si no sabes la respuesta, dilo en lugar de inventar información.",
    "fr": "Vous êtes un assistant IA utile, respectueux et honnête. Répondez toujours de la manière la plus utile possible, tout en restant sûr. Si vous ne connaissez pas la réponse, dites-le au lieu d'inventer des informations.",
    "de": "Sie sind ein hilfreicher, respektvoller und ehrlicher KI-Assistent. Antworten Sie immer so hilfreich wie möglich und bleiben Sie dabei sicher. Wenn Sie die Antwort nicht wissen, sagen Sie es, anstatt Informationen zu erfinden.",
    "it": "Sei un assistente AI utile, rispettoso e onesto. Rispondi sempre nel modo più utile possibile, mantenendo la sicurezza. Se non conosci la risposta, dillo invece di inventare informazioni.",
    "ja": "あなたは親切で、礼儀正しく、正直なAIアシスタントです。常に安全を保ちながら、できるだけ役立つように答えてください。答えがわからない場合は、情報を作り上げるのではなく、そう言ってください。",
    "zh": "你是一个乐于助人、尊重他人且诚实的AI助手。在保持安全的同时,始终尽可能有帮助地回答。如果你不知道答案,请说出来,而不是编造信息。",
}

# ============================================================================
# SIMPLE CACHE IMPLEMENTATION
# ============================================================================


class ResponseCache:
    """Simple in-memory cache for responses.

    Keyed on (prompt, language, temperature); evicts the oldest entry when
    full. Note: max_tokens is NOT part of the key, so a cached response may
    have been generated with a different token budget.
    """

    def __init__(self, max_size: int = 100):
        # key -> (response text, insertion timestamp)
        self.cache: Dict[str, tuple[str, datetime]] = {}
        self.max_size = max_size
        logger.info(f"Initialized cache with max size: {max_size}")

    def _generate_key(self, prompt: str, language: str, temperature: float) -> str:
        """Generate cache key from parameters (md5 used for keying, not security)."""
        content = f"{prompt}:{language}:{temperature}"
        return hashlib.md5(content.encode()).hexdigest()

    def get(self, prompt: str, language: str, temperature: float) -> Optional[str]:
        """Retrieve cached response, or None on a miss."""
        key = self._generate_key(prompt, language, temperature)
        if key in self.cache:
            response, timestamp = self.cache[key]
            logger.info(f"Cache HIT for key: {key[:8]}...")
            return response
        logger.info(f"Cache MISS for key: {key[:8]}...")
        return None

    def set(self, prompt: str, language: str, temperature: float, response: str):
        """Store response in cache, evicting the oldest entry when full."""
        if len(self.cache) >= self.max_size:
            # Remove entry with the earliest insertion timestamp.
            oldest_key = min(self.cache.keys(), key=lambda k: self.cache[k][1])
            del self.cache[oldest_key]
            logger.info(f"Cache full, removed oldest entry: {oldest_key[:8]}...")
        key = self._generate_key(prompt, language, temperature)
        self.cache[key] = (response, datetime.now())
        logger.info(f"Cached response for key: {key[:8]}...")


# ============================================================================
# MODEL LOADING
# ============================================================================


class ModelManager:
    """Manages model loading and inference."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device: {self.device}")

    async def load_model(self):
        """Load tokenizer and 4-bit quantized model.

        Returns:
            bool: True on success, False on any load failure (logged).
        """
        try:
            logger.info(f"Loading model: {config.MODEL_NAME}")

            self.tokenizer = AutoTokenizer.from_pretrained(
                config.MODEL_NAME,
                trust_remote_code=True,
            )

            # 4-bit quantization keeps the 7B model inside free-tier memory.
            self.model = AutoModelForCausalLM.from_pretrained(
                config.MODEL_NAME,
                quantization_config=config.QUANTIZATION_CONFIG,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )

            logger.info("Model loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}", exc_info=True)
            return False

    def format_prompt(self, prompt: str, language: str) -> str:
        """Wrap the user prompt in Mistral [INST] format with a per-language system prompt."""
        system_prompt = SYSTEM_PROMPTS.get(language, SYSTEM_PROMPTS["en"])
        return f"[INST] {system_prompt}\n\nUser: {prompt} [/INST]"

    async def generate_stream(
        self,
        prompt: str,
        language: str,
        temperature: float,
        max_tokens: int,
    ) -> AsyncGenerator[str, None]:
        """Generate a response one token at a time, yielding decoded text chunks.

        NOTE(review): each step re-runs model.generate for a single new token,
        which is O(n^2) in sequence length and blocks the event loop while the
        forward pass runs — consider transformers' TextIteratorStreamer in a
        worker thread for production throughput.
        """
        try:
            formatted_prompt = self.format_prompt(prompt, language)

            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=config.MAX_LENGTH,
            ).to(self.device)

            with torch.no_grad():
                for _ in range(max_tokens):
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=1,
                        temperature=temperature,
                        top_p=config.TOP_P,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                    # Decode only the newly generated token.
                    new_token = self.tokenizer.decode(
                        outputs[0][-1:],
                        skip_special_tokens=True,
                    )

                    # Stop cleanly at end-of-sequence.
                    if outputs[0][-1] == self.tokenizer.eos_token_id:
                        break

                    yield new_token

                    # FIX: the original rebuilt inputs as {"input_ids": outputs}
                    # only, silently dropping the attention mask for every step
                    # after the first; carry a full mask forward so generation
                    # attends to the whole sequence.
                    inputs = {
                        "input_ids": outputs,
                        "attention_mask": torch.ones_like(outputs),
                    }

                    # Small delay to pace the stream for clients.
                    await asyncio.sleep(0.01)
        except Exception as e:
            logger.error(f"Generation error: {str(e)}", exc_info=True)
            yield f"\n\n[Error: {str(e)}]"

    async def generate(
        self,
        prompt: str,
        language: str,
        temperature: float,
        max_tokens: int,
    ) -> str:
        """Generate a complete response (non-streaming).

        Raises:
            HTTPException: 500 when generation fails.
        """
        try:
            formatted_prompt = self.format_prompt(prompt, language)

            inputs = self.tokenizer(
                formatted_prompt,
                return_tensors="pt",
                truncation=True,
                max_length=config.MAX_LENGTH,
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=config.TOP_P,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the tokens generated after the prompt.
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True,
            )
            return response.strip()
        except Exception as e:
            logger.error(f"Generation error: {str(e)}", exc_info=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Generation failed: {str(e)}",
            )


# ============================================================================
# FASTAPI APPLICATION
# ============================================================================

app = FastAPI(
    title="AI API - Mistral 7B",
    description="Production-ready AI API with streaming, authentication, and caching",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS middleware — wide open; restrict allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global instances
model_manager = ModelManager()
cache = ResponseCache(max_size=config.CACHE_SIZE)
# ============================================================================
# AUTHENTICATION
# ============================================================================


async def verify_api_key(x_api_key: str = Header(..., alias="X-API-Key")):
    """Verify API key from the X-API-Key header.

    Raises:
        HTTPException: 401 when the key does not match.
    """
    import hmac  # stdlib; function-scope to leave file-level imports untouched

    # FIX: constant-time comparison — plain `!=` short-circuits on the first
    # differing byte and leaks key-prefix timing to an attacker.
    if not hmac.compare_digest(x_api_key, config.API_KEY):
        logger.warning(f"Invalid API key attempt: {x_api_key[:8]}...")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key",
        )
    return x_api_key


# ============================================================================
# STARTUP/SHUTDOWN EVENTS
# ============================================================================


@app.on_event("startup")
async def startup_event():
    """Load model on startup."""
    logger.info("Starting AI API server...")
    success = await model_manager.load_model()
    if not success:
        logger.error("Failed to load model, server may not function correctly")


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown."""
    logger.info("Shutting down AI API server...")
    # Clear cache
    cache.cache.clear()


# ============================================================================
# ROUTES
# ============================================================================


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a minimal landing page documenting the API.

    NOTE(review): the original HTML markup was lost during extraction; this is
    a reconstruction that preserves all of the visible page text — confirm
    against the original template if available.
    """
    html_content = """<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>AI API - Mistral 7B</title>
</head>
<body>
    <h1>🤖 AI API - Mistral 7B</h1>
    <p>Production-ready AI API with streaming responses</p>
    <p>Start a conversation by typing a message below</p>

    <h2>📚 API Documentation</h2>

    <h3>Endpoint</h3>
    <pre>POST /api/chat</pre>

    <h3>Headers</h3>
    <pre>X-API-Key: your-api-key</pre>

    <h3>Example (curl)</h3>
    <pre>curl -X POST "http://localhost:7860/api/chat" \\
  -H "X-API-Key: your-secret-api-key-here" \\
  -H "Content-Type: application/json" \\
  -d '{
    "prompt": "Explain quantum computing",
    "language": "en",
    "stream": false
  }'</pre>
</body>
</html>
"""
    return HTMLResponse(content=html_content)


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        model_loaded=model_manager.model is not None,
        timestamp=datetime.now().isoformat(),
    )


@app.post("/api/chat")
async def chat(
    request: ChatRequest,
    api_key: str = Header(..., alias="X-API-Key"),
):
    """Chat endpoint with streaming support.

    Requires an X-API-Key header for authentication. Returns an SSE stream
    when request.stream is true, otherwise a ChatResponse JSON body.
    """
    import json  # stdlib; used for safe SSE payload encoding

    # Verify API key
    await verify_api_key(api_key)

    # Check if model is loaded
    if model_manager.model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded yet, please try again later",
        )

    # Fall back to configured defaults when the client omits parameters.
    temperature = request.temperature or config.TEMPERATURE
    max_tokens = request.max_tokens or 512

    try:
        # Check cache for non-streaming requests
        if not request.stream:
            cached_response = cache.get(request.prompt, request.language, temperature)
            if cached_response:
                return ChatResponse(
                    response=cached_response,
                    language=request.language,
                    model=config.MODEL_NAME,
                    timestamp=datetime.now().isoformat(),
                    cached=True,
                )

        # Streaming response (server-sent events)
        if request.stream:

            async def generate():
                full_response = ""
                async for token in model_manager.generate_stream(
                    request.prompt,
                    request.language,
                    temperature,
                    max_tokens,
                ):
                    full_response += token
                    # FIX: the original emitted f"data: {{'token': '{token}'}}"
                    # — single-quoted pseudo-JSON that breaks on any token
                    # containing a quote or newline; emit real JSON instead.
                    yield f"data: {json.dumps({'token': token})}\n\n"

                # Cache complete response
                cache.set(request.prompt, request.language, temperature, full_response)
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
            )

        # Non-streaming response
        else:
            response = await model_manager.generate(
                request.prompt,
                request.language,
                temperature,
                max_tokens,
            )
            # Cache response
            cache.set(request.prompt, request.language, temperature, response)
            return ChatResponse(
                response=response,
                language=request.language,
                model=config.MODEL_NAME,
                timestamp=datetime.now().isoformat(),
                cached=False,
            )
    except HTTPException:
        # FIX: let deliberate HTTP errors (e.g. the 500 raised by
        # ModelManager.generate) pass through instead of re-wrapping them.
        raise
    except Exception as e:
        logger.error(f"Chat endpoint error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )


# ============================================================================
# MAIN
# ============================================================================

if __name__ == "__main__":
    logger.info(f"Starting server on {config.HOST}:{config.PORT}")
    uvicorn.run(
        app,
        host=config.HOST,
        port=config.PORT,
        log_level="info",
    )