Spaces:
Running
Running
Dmitry Beresnev
Refactor the C++ LLM manager into modular components, move the Python modules under python/, and keep the current control-plane behavior intact. The C++ server now has a clearer separation of config, model lifecycle, runtime services, request parsing, HTTP helpers, and server routing, while the Docker build/runtime paths were updated to compile multiple C++ files and load Python code from the new package folder.
332826f | import subprocess | |
| import signal | |
| import os | |
| import time | |
| import asyncio | |
| from typing import Optional, Dict, List, Any | |
| from dataclasses import dataclass, field | |
| from collections import OrderedDict | |
| from datetime import datetime, timedelta | |
| import hashlib | |
| import json | |
| import uuid | |
| import aiohttp | |
| from fastapi import FastAPI, HTTPException, BackgroundTasks, Request | |
| from fastapi.openapi.utils import get_openapi | |
| from pydantic import BaseModel, Field | |
| from duckduckgo_search import DDGS | |
| from bs4 import BeautifulSoup | |
| from logger import get_logger | |
| logger = get_logger(__name__) | |
# FastAPI application object for the service. The Markdown description, version,
# contact/license details and tag metadata below are passed straight to FastAPI
# for the generated OpenAPI schema.
app = FastAPI(
    title="AGI Multi-Model API",
    description="""
**High-Performance Dynamic Multi-Model LLM API with Web Search**
This API provides:
* π Dynamic model switching with intelligent caching
* π¬ OpenAI-compatible chat completions
* π Web-augmented chat with real-time search
* π Model management and performance monitoring
* β‘ Async/await architecture for maximum throughput
## Available Models
- **deepseek-chat** (default): General purpose conversational model
- **mistral-7b**: Financial analysis and summarization
- **openhermes-7b**: Advanced instruction following
- **deepseek-coder**: Specialized coding assistance
- **llama-7b**: Lightweight and fast responses
## Performance Features
- Parallel model loading
- Connection pooling for HTTP requests
- Web search result caching
- Background model preloading
- Request queuing to prevent overload
- Real-time performance metrics
## Quick Start
1. Check available models: `GET /models`
2. Switch model (optional): `POST /switch-model`
3. Chat: `POST /v1/chat/completions`
4. Chat with web search: `POST /v1/web-chat/completions`
5. View metrics: `GET /metrics`
""",
    version="0.1.0.2026.01.24",
    contact={
        "name": "API Support",
        "email": "support@example.com",
    },
    license_info={
        "name": "MIT",
    },
    # Tag metadata: one entry per endpoint group named in the API docs.
    openapi_tags=[
        {
            "name": "status",
            "description": "System status and health checks",
        },
        {
            "name": "models",
            "description": "Model management and switching operations",
        },
        {
            "name": "chat",
            "description": "Chat completion endpoints (OpenAI-compatible)",
        },
        {
            "name": "monitoring",
            "description": "Performance metrics and monitoring",
        },
        {
            "name": "documentation",
            "description": "API documentation and OpenAPI specification",
        },
    ]
)
# Predefined list of available models (TheBloke only - verified, fits 18GB Space)
# Maps the public model name to the "<repo>:<file>" identifier handed to
# llama-server via its -hf flag.
AVAILABLE_MODELS = {
    # === General Purpose (Default) ===
    "deepseek-chat": "TheBloke/deepseek-llm-7B-chat-GGUF:deepseek-llm-7b-chat.Q4_K_M.gguf",
    # === Financial & Summarization Models ===
    "mistral-7b": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF:mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    "openhermes-7b": "TheBloke/OpenHermes-2.5-Mistral-7B-GGUF:openhermes-2.5-mistral-7b.Q4_K_M.gguf",
    # === Coding Models ===
    "deepseek-coder": "TheBloke/deepseek-coder-6.7B-instruct-GGUF:deepseek-coder-6.7b-instruct.Q4_K_M.gguf",
    # === Lightweight/Fast ===
    "llama-7b": "TheBloke/Llama-2-7B-Chat-GGUF:llama-2-7b-chat.Q4_K_M.gguf",
}


def _parse_csv_env(raw: Optional[str]) -> list[str]:
    """Split a comma-separated environment value into trimmed, non-empty items.

    Args:
        raw: The raw environment value, or ``None``/empty when unset.

    Returns:
        The items in their original order, with surrounding whitespace removed
        and empty segments (e.g. from a trailing comma) dropped.
    """
    if not raw:
        return []
    return [item.strip() for item in raw.split(",") if item.strip()]


# Configuration - now environment-variable driven
MAX_CACHED_MODELS = int(os.getenv("MAX_CACHED_MODELS", "2"))
BASE_PORT = int(os.getenv("BASE_PORT", "8080"))
# Whitespace around names is stripped so "a, b" matches AVAILABLE_MODELS keys.
PRELOAD_MODELS = _parse_csv_env(os.getenv("PRELOAD_MODELS"))
WEB_SEARCH_CACHE_TTL = int(os.getenv("WEB_SEARCH_CACHE_TTL", "3600"))  # 1 hour
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "300"))  # 5 minutes
LOG_REQUEST_BODY = os.getenv("LOG_REQUEST_BODY", "1") == "1"
LOG_REQUEST_BODY_MAX_CHARS = int(os.getenv("LOG_REQUEST_BODY_MAX_CHARS", "2000"))
CONTEXT_SIZE = int(os.getenv("CONTEXT_SIZE", "2048"))
PROMPT_MARGIN_TOKENS = int(os.getenv("PROMPT_MARGIN_TOKENS", "256"))
CHARS_PER_TOKEN_EST = float(os.getenv("CHARS_PER_TOKEN_EST", "4.0"))
SYSTEM_PROMPT_MAX_TOKENS = int(os.getenv("SYSTEM_PROMPT_MAX_TOKENS", "512"))
ALLOW_LONG_SYSTEM_PROMPT = os.getenv("ALLOW_LONG_SYSTEM_PROMPT", "0") == "1"
HARD_REQUEST_TIMEOUT = int(os.getenv("HARD_REQUEST_TIMEOUT", "300"))
def _estimate_tokens(text: str) -> int:
    """Approximate how many tokens ``text`` occupies.

    Falsy input (empty string or ``None``) counts as zero tokens. Anything
    else is its character count divided by ``CHARS_PER_TOKEN_EST``, truncated
    to an integer and increased by one.
    """
    return int(len(text) / CHARS_PER_TOKEN_EST) + 1 if text else 0
def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
    """Cut ``text`` down to roughly ``max_tokens`` tokens.

    Returns an empty string for falsy text or a non-positive budget. Text that
    already fits is returned unchanged; longer text is sliced to the character
    limit derived from ``CHARS_PER_TOKEN_EST`` and suffixed with a truncation
    marker.
    """
    if not text or max_tokens <= 0:
        return ""
    char_limit = int(max_tokens * CHARS_PER_TOKEN_EST)
    fits = len(text) <= char_limit
    return text if fits else text[:char_limit] + "...[truncated]"
def _compact_messages(messages: list[dict], max_tokens: int) -> list[dict]:
    """
    Compact messages to fit within the prompt budget.

    The budget is ``CONTEXT_SIZE - max_tokens - PROMPT_MARGIN_TOKENS``. When it
    is not positive, or ``messages`` is empty, the input is returned untouched.

    Strategy:
    - Cap system message content size (or blank it when it exceeds
      ``SYSTEM_PROMPT_MAX_TOKENS`` and long system prompts are not allowed).
    - Drop oldest non-system messages until within budget, always keeping the
      most recent non-system message.
    - As a last resort, truncate that remaining non-system message.

    Args:
        messages: Chat messages as dicts with ``role`` and ``content`` keys.
        max_tokens: Tokens reserved for the completion.

    Returns:
        A new list of shallow-copied messages (the caller's list and dicts are
        not mutated), or ``messages`` itself on the early-return paths.
    """
    if not messages:
        return messages
    prompt_budget = CONTEXT_SIZE - max_tokens - PROMPT_MARGIN_TOKENS
    if prompt_budget <= 0:
        return messages

    # Work on a copy to avoid mutating caller input
    compacted = [dict(m) for m in messages]

    # Cap system messages
    system_cap = min(1024, max(256, prompt_budget // 3))
    for msg in compacted:
        if msg.get("role") != "system" or "content" not in msg:
            continue
        content = str(msg["content"])
        if not ALLOW_LONG_SYSTEM_PROMPT and _estimate_tokens(content) > SYSTEM_PROMPT_MAX_TOKENS:
            msg["content"] = ""
        else:
            msg["content"] = _truncate_text_to_tokens(content, system_cap)

    def message_tokens(msg: dict) -> int:
        return _estimate_tokens(str(msg.get("content", "")))

    def non_system_indexes() -> list[int]:
        return [i for i, m in enumerate(compacted) if m.get("role") != "system"]

    # Running total so each drop costs O(1) instead of re-summing every message.
    total = sum(message_tokens(m) for m in compacted)

    # Drop oldest non-system messages until under budget, but never the last
    # one: removing every user/assistant turn would leave nothing to answer and
    # would make the truncation step below unreachable.
    while total > prompt_budget:
        candidates = non_system_indexes()
        if len(candidates) <= 1:
            break
        total -= message_tokens(compacted.pop(candidates[0]))

    # Last resort: truncate the remaining non-system content
    if total > prompt_budget:
        candidates = non_system_indexes()
        if candidates:
            idx = candidates[0]
            content = str(compacted[idx].get("content", ""))
            other_tokens = total - _estimate_tokens(content)
            remaining_budget = max(1, prompt_budget - other_tokens)
            compacted[idx]["content"] = _truncate_text_to_tokens(content, remaining_budget)
    return compacted
def _estimate_messages_tokens(messages: list[dict]) -> int:
    """Add up the estimated token count of every message's ``content``.

    A missing ``content`` key is treated as an empty string; any other value is
    stringified before being measured.
    """
    running_total = 0
    for message in messages:
        running_total += _estimate_tokens(str(message.get("content", "")))
    return running_total
@dataclass
class CachedModel:
    """Represents a cached model with its process and connection info.

    Declared as a dataclass because ``ModelCache.put`` builds instances with
    keyword arguments; without the generated ``__init__`` that call fails.
    """
    name: str  # public model name (cache key)
    model_id: str  # identifier the llama-server process was started with
    process: subprocess.Popen  # handle of the llama-server process
    port: int  # port the process was started on
    url: str  # base URL used to reach the process
    last_used: float  # time.time() of the most recent cache lookup
    load_time: float = 0.0  # seconds the server took to become ready
    request_count: int = 0  # requests served by this model
    total_latency: float = 0.0  # summed request latency in seconds
@dataclass
class PerformanceMetrics:
    """Performance metrics for monitoring.

    Declared as a dataclass so that ``field(default_factory=dict)`` gives each
    instance its own ``model_metrics`` dict; on a plain class the attribute
    would be a ``Field`` object and ``record_request`` would fail.
    """
    total_requests: int = 0
    total_switches: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    total_web_searches: int = 0
    web_search_cache_hits: int = 0
    # Per-model stats keyed by model name: requests, total_latency, avg_latency.
    model_metrics: Dict[str, Dict] = field(default_factory=dict)
    startup_time: float = 0.0

    def record_request(self, model_name: str, latency: float):
        """Record one completed request for ``model_name``.

        Args:
            model_name: Name of the model that served the request.
            latency: Request latency in seconds.
        """
        self.total_requests += 1
        stats = self.model_metrics.setdefault(
            model_name,
            {"requests": 0, "total_latency": 0.0, "avg_latency": 0.0},
        )
        stats["requests"] += 1
        stats["total_latency"] += latency
        stats["avg_latency"] = stats["total_latency"] / stats["requests"]
@dataclass
class WebSearchCacheEntry:
    """Cache entry for web search results.

    Declared as a dataclass because ``WebSearchCache.put`` constructs entries
    with ``results=...`` and ``timestamp=...`` keyword arguments.
    """
    results: List[dict]  # cached search results
    timestamp: float  # time.time() when the entry was stored
    ttl: int = WEB_SEARCH_CACHE_TTL  # lifetime in seconds

    def is_expired(self) -> bool:
        """Return True once more than ``ttl`` seconds have passed since ``timestamp``."""
        return time.time() - self.timestamp > self.ttl
class WebSearchCache:
    """LRU cache for web search results.

    Entries are kept in an ``OrderedDict`` whose first item is the least
    recently used. Expired entries are removed lazily when they are looked up.
    """

    def __init__(self, max_size: int = 100):
        """Create an empty cache that evicts once ``max_size`` entries are held."""
        self.max_size = max_size
        self.cache: OrderedDict[str, WebSearchCacheEntry] = OrderedDict()

    def _get_cache_key(self, query: str, max_results: int) -> str:
        """Generate cache key from query."""
        key = f"{query}:{max_results}"
        # MD5 is used only to derive a compact dictionary key, not for security.
        return hashlib.md5(key.encode()).hexdigest()

    def get(self, query: str, max_results: int) -> Optional[List[dict]]:
        """Get cached search results if available and not expired.

        Returns:
            The cached results, or ``None`` on a miss or when the entry expired
            (in which case the entry is removed).
        """
        key = self._get_cache_key(query, max_results)
        entry = self.cache.get(key)
        if entry is None:
            return None
        if entry.is_expired():
            # Remove expired entry
            del self.cache[key]
            return None
        # Move to end (most recently used)
        self.cache.move_to_end(key)
        return entry.results

    def put(self, query: str, max_results: int, results: List[dict]):
        """Cache search results, evicting the least recently used entry if full."""
        key = self._get_cache_key(query, max_results)
        if key in self.cache:
            # Refreshing an entry: drop the old one so the new value lands at
            # the most-recently-used end instead of keeping its stale position.
            del self.cache[key]
        elif self.cache and len(self.cache) >= self.max_size:
            # Evict oldest if cache is full; the emptiness guard keeps
            # max_size <= 0 from raising on an empty cache.
            self.cache.popitem(last=False)
        self.cache[key] = WebSearchCacheEntry(
            results=results,
            timestamp=time.time()
        )

    def clear(self):
        """Clear all cached results."""
        self.cache.clear()
class ModelCache:
    """
    High-performance in-memory LRU cache for loaded models.

    Features:
    - Manages multiple llama-server processes on different ports
    - LRU eviction when cache is full
    - Parallel model loading support
    - Performance metrics tracking

    The cache is an ``OrderedDict`` whose first item is the least recently
    used model. Evicting or clearing a model stops its process and returns
    its port to the pool.
    """

    # After the polite stop signal, poll this many times at this interval
    # before escalating to a forced kill (10 x 0.1 s = ~1 s grace period).
    _STOP_POLL_ATTEMPTS = 10
    _STOP_POLL_INTERVAL = 0.1

    def __init__(self, max_size: int = MAX_CACHED_MODELS):
        """Create an empty cache holding at most ``max_size`` models."""
        self.max_size = max_size
        self.cache: OrderedDict[str, CachedModel] = OrderedDict()
        self.port_counter = BASE_PORT
        self.used_ports = set()
        self._loading_lock = asyncio.Lock()
        self._loading_models: Dict[str, asyncio.Task] = {}

    def _get_next_port(self) -> int:
        """Get next available port for a model and mark it as used."""
        while self.port_counter in self.used_ports:
            self.port_counter += 1
        port = self.port_counter
        self.used_ports.add(port)
        self.port_counter += 1
        return port

    def _release_port(self, port: int):
        """Release a port back to the pool."""
        self.used_ports.discard(port)

    @staticmethod
    def _signal_process(process: subprocess.Popen, force: bool) -> None:
        """Send a stop (or, with ``force``, kill) request to a model process.

        On POSIX the whole process group is signalled; on Windows only the
        process itself. ``signal.SIGKILL`` is only touched on the POSIX branch
        because it does not exist on Windows.
        """
        if os.name != 'nt':
            sig = signal.SIGKILL if force else signal.SIGTERM
            os.killpg(os.getpgid(process.pid), sig)
        elif force:
            process.kill()
        else:
            process.terminate()

    async def _stop_process(self, model_name: str, process: subprocess.Popen) -> None:
        """Stop a model process gracefully, force-killing it if it lingers.

        Errors are logged rather than raised so that eviction and shutdown can
        always continue with the remaining models.
        """
        if process.poll() is not None:
            return  # already exited, nothing to stop
        try:
            self._signal_process(process, force=False)
            # Wait asynchronously for process to stop
            for _ in range(self._STOP_POLL_ATTEMPTS):
                if process.poll() is not None:
                    return
                await asyncio.sleep(self._STOP_POLL_INTERVAL)
            # Force kill if not stopped
            self._signal_process(process, force=True)
        except ProcessLookupError:
            # The process exited between the poll and the signal.
            logger.debug(f"Model process already gone: {model_name}")
        except Exception as e:
            logger.error(f"Error stopping model {model_name}: {e}")

    async def _evict_lru(self):
        """Evict the least recently used model, stopping its process."""
        if not self.cache:
            return
        # Get the first (oldest) item
        model_name, cached_model = self.cache.popitem(last=False)
        logger.info(f"Evicting model from cache: {model_name}")
        await self._stop_process(model_name, cached_model.process)
        # Release the port
        self._release_port(cached_model.port)

    def get(self, model_name: str) -> Optional[CachedModel]:
        """Get a model from cache, updating its last used time.

        Returns:
            The cached model (now marked most recently used), or ``None``.
        """
        if model_name in self.cache:
            cached_model = self.cache[model_name]
            cached_model.last_used = time.time()
            # Move to end (most recently used)
            self.cache.move_to_end(model_name)
            logger.debug(f"Cache hit for model: {model_name}")
            return cached_model
        logger.debug(f"Cache miss for model: {model_name}")
        return None

    async def put(self, model_name: str, model_id: str, process: subprocess.Popen, port: int, load_time: float = 0.0):
        """Add a model to the cache, evicting least recently used models if full.

        Args:
            model_name: Public model name used as the cache key.
            model_id: Identifier the llama-server process was started with.
            process: The running llama-server process.
            port: Port the process listens on.
            load_time: Seconds the server took to become ready.
        """
        # Replacing an entry of the same name must not orphan its old process.
        previous = self.cache.pop(model_name, None)
        if previous is not None and previous.process is not process:
            logger.info(f"Replacing cached model: {model_name}")
            await self._stop_process(model_name, previous.process)
            if previous.port != port:
                self._release_port(previous.port)
        # Evict if cache is full. The emptiness guard prevents an endless loop
        # when max_size is zero or negative.
        while self.cache and len(self.cache) >= self.max_size:
            await self._evict_lru()
        url = f"http://localhost:{port}"
        cached_model = CachedModel(
            name=model_name,
            model_id=model_id,
            process=process,
            port=port,
            url=url,
            last_used=time.time(),
            load_time=load_time
        )
        self.cache[model_name] = cached_model
        logger.info(f"Cached model: {model_name} on port {port} (load time: {load_time:.2f}s)")

    async def clear(self):
        """Clear all cached models, stopping every process and freeing its port."""
        logger.info("Clearing model cache...")
        # Iterate over a snapshot: stopping a process awaits, and the cache
        # must not change size underneath the loop.
        for model_name, cached_model in list(self.cache.items()):
            await self._stop_process(model_name, cached_model.process)
            self._release_port(cached_model.port)
        self.cache.clear()

    def get_cache_info(self) -> Dict:
        """Get information about cached models.

        Returns:
            A dict with ``max_size``, ``current_size`` and one summary dict per
            cached model (least recently used first).
        """
        return {
            "max_size": self.max_size,
            "current_size": len(self.cache),
            "cached_models": [
                {
                    "name": name,
                    "port": model.port,
                    "url": model.url,
                    "last_used": model.last_used,
                    "load_time": model.load_time,
                    "request_count": model.request_count,
                    "avg_latency": model.total_latency / model.request_count if model.request_count > 0 else 0.0
                }
                for name, model in self.cache.items()
            ]
        }
# Global state
# Name of the active model; used as a key into AVAILABLE_MODELS and model_cache.
current_model = "deepseek-chat"  # Default model
# Process-wide caches and counters shared by the handlers below.
model_cache = ModelCache(max_size=MAX_CACHED_MODELS)
web_search_cache = WebSearchCache(max_size=100)
metrics = PerformanceMetrics()
# HTTP session for connection pooling (will be initialized in startup)
http_session: Optional[aiohttp.ClientSession] = None
# Lower-case header names whose values are replaced before logging.
SENSITIVE_HEADERS = {"authorization", "proxy-authorization", "x-api-key", "api-key"}
# Lower-case JSON keys whose values are replaced before logging.
SENSITIVE_FIELDS = {"authorization", "api_key", "api-key", "password", "token"}
def _redact_headers(headers: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of ``headers`` that is safe to log.

    Any header whose lower-cased name appears in ``SENSITIVE_HEADERS`` has its
    value replaced by a placeholder; all other headers are copied unchanged.
    """
    return {
        name: "[redacted]" if name.lower() in SENSITIVE_HEADERS else value
        for name, value in headers.items()
    }
def _redact_json(obj: Any) -> Any:
    """Recursively mask sensitive values in a decoded JSON structure.

    Lists are processed element by element and dicts key by key: a key whose
    lower-cased string form is in ``SENSITIVE_FIELDS`` gets a placeholder
    value, every other value is redacted recursively. Scalars are returned
    as they are. New containers are built; the input is not modified.
    """
    if isinstance(obj, list):
        return [_redact_json(element) for element in obj]
    if not isinstance(obj, dict):
        return obj
    return {
        field_name: (
            "[redacted]"
            if str(field_name).lower() in SENSITIVE_FIELDS
            else _redact_json(field_value)
        )
        for field_name, field_value in obj.items()
    }
def _format_body_for_log(content_type: str, body: bytes) -> str:
    """Render a request body as a bounded, redacted string for the log.

    An empty body yields an empty string. A body declared as JSON is parsed,
    passed through ``_redact_json`` and re-serialised as ASCII; if that fails,
    or for any other content type, the bytes are decoded as UTF-8 with
    replacement characters. The result is cut to
    ``LOG_REQUEST_BODY_MAX_CHARS`` characters plus a truncation marker.
    """
    if not body:
        return ""

    preview: Optional[str] = None
    if "application/json" in (content_type or ""):
        try:
            decoded = json.loads(body.decode("utf-8"))
            preview = json.dumps(_redact_json(decoded), ensure_ascii=True)
        except Exception:
            # Not valid JSON after all; fall back to the plain-text preview.
            preview = None
    if preview is None:
        preview = body.decode("utf-8", errors="replace")

    if len(preview) <= LOG_REQUEST_BODY_MAX_CHARS:
        return preview
    return preview[:LOG_REQUEST_BODY_MAX_CHARS] + "...[truncated]"
# NOTE(review): no middleware registration is visible for this function in
# this view of the file - confirm it is registered with the app.
async def log_received_request(request: Request, call_next):
    """Log one line per incoming request and one per response.

    A short random id is stored on ``request.state.request_id`` and used in
    both log lines. When ``LOG_REQUEST_BODY`` is enabled the body is read,
    logged in redacted form, and the request is rebuilt so that downstream
    handlers can read the body again. Exceptions from downstream are logged
    with the elapsed time and re-raised.
    """
    request_id = uuid.uuid4().hex[:12]
    request.state.request_id = request_id
    started_at = time.perf_counter()

    def elapsed_ms() -> float:
        return (time.perf_counter() - started_at) * 1000

    body_text = ""
    if LOG_REQUEST_BODY:
        body_bytes = await request.body()
        body_text = _format_body_for_log(request.headers.get("content-type", ""), body_bytes)

        async def receive():
            return {"type": "http.request", "body": body_bytes, "more_body": False}

        # Recreate request so downstream can read body again
        request = Request(request.scope, receive)

    headers = _redact_headers(dict(request.headers))
    if request.client:
        client_host = request.client.host
    else:
        client_host = "-"
    if request.url.query:
        query = f"?{request.url.query}"
    else:
        query = ""
    logger.info(
        f"β‘οΈ {request_id} {request.method} {request.url.path}{query} "
        f"from {client_host} ua={headers.get('user-agent', '-')}"
    )
    if body_text:
        logger.info(f" body={body_text}")

    try:
        response = await call_next(request)
    except Exception:
        elapsed_ms = elapsed_ms()
        logger.exception(f"β¬ οΈ {request_id} 500 {elapsed_ms:.1f}ms unhandled error")
        raise
    elapsed_ms = elapsed_ms()
    logger.info(f"β¬ οΈ {request_id} {response.status_code} {elapsed_ms:.1f}ms")
    return response
class ModelSwitchRequest(BaseModel):
    """Request to switch the active LLM model.

    Carries a single required field, ``model_name``.
    """
    # Required: the name of the model to activate.
    model_name: str = Field(
        ...,
        description="Name of the model to switch to",
        examples=["deepseek-chat", "mistral-7b", "deepseek-coder"]
    )
    # Example request bodies attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {"model_name": "deepseek-coder"},
                {"model_name": "mistral-7b"}
            ]
        }
    }
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request.

    Validation bounds: ``max_tokens`` in [1, 4096] and ``temperature`` in
    [0.0, 2.0].
    """
    # Required conversation history, one dict per message.
    messages: list[dict] = Field(
        ...,
        description="Array of message objects with 'role' and 'content' fields",
        examples=[[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"}
        ]]
    )
    # NOTE(review): the upper bound (4096) can exceed the CONTEXT_SIZE default
    # (2048), which makes the prompt budget non-positive - confirm intended.
    max_tokens: int = Field(
        default=256,
        description="Maximum number of tokens to generate",
        ge=1,
        le=4096
    )
    temperature: float = Field(
        default=0.7,
        description="Sampling temperature (0.0 to 2.0). Higher values make output more random.",
        ge=0.0,
        le=2.0
    )
    # Accepted for client compatibility; per its description it is ignored.
    model: Optional[str] = Field(
        default=None,
        description="Optional model name (ignored by this server; use /switch-model)."
    )
    # Example request body attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        {"role": "user", "content": "What is the capital of France?"}
                    ],
                    "max_tokens": 100,
                    "temperature": 0.7
                }
            ]
        }
    }
class WebChatRequest(BaseModel):
    """Chat completion request with web search augmentation.

    Validation bounds: ``max_tokens`` in [1, 4096], ``temperature`` in
    [0.0, 2.0] and ``max_search_results`` in [1, 10].
    """
    # Required conversation history, one dict per message.
    messages: list[dict] = Field(
        ...,
        description="Array of message objects. The last user message is used for web search.",
        examples=[[
            {"role": "user", "content": "What are the latest developments in AI?"}
        ]]
    )
    # Default is larger than ChatCompletionRequest's (512 vs 256).
    max_tokens: int = Field(
        default=512,
        description="Maximum number of tokens to generate",
        ge=1,
        le=4096
    )
    temperature: float = Field(
        default=0.7,
        description="Sampling temperature (0.0 to 2.0)",
        ge=0.0,
        le=2.0
    )
    # Upper limit on search results added to the prompt context.
    max_search_results: int = Field(
        default=5,
        description="Maximum number of web search results to include in context",
        ge=1,
        le=10
    )
    # Example request body attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "messages": [
                        {"role": "user", "content": "What's the weather like today in San Francisco?"}
                    ],
                    "max_tokens": 512,
                    "temperature": 0.7,
                    "max_search_results": 5
                }
            ]
        }
    }
class StatusResponse(BaseModel):
    """API status response.

    All three fields are required; the keys match the dict returned by the
    ``root`` handler.
    """
    status: str = Field(..., description="Current API status")
    current_model: str = Field(..., description="Currently active model")
    available_models: list[str] = Field(..., description="List of available models")
class ModelsResponse(BaseModel):
    """Available models response.

    Both fields are required; the keys match the dict returned by the
    ``list_models`` handler.
    """
    current_model: str = Field(..., description="Currently active model")
    available_models: list[str] = Field(..., description="List of all available models")
class ModelSwitchResponse(BaseModel):
    """Model switch response.

    Both fields are required; the keys match the dicts returned by the
    ``switch_model`` handler.
    """
    message: str = Field(..., description="Status message")
    model: str = Field(..., description="New active model name")
def _drain_process_output(process: subprocess.Popen, tail) -> None:
    """Read the child's merged stdout/stderr line by line into ``tail``.

    Runs in a background thread for the life of the process. Without a reader
    the OS pipe buffer eventually fills and the child blocks on its next write.
    """
    stream = process.stdout
    if stream is None:
        return
    try:
        with stream:
            for line in stream:
                tail.append(line.rstrip("\n"))
    except (OSError, ValueError) as exc:
        # The pipe was closed while the process was being shut down.
        logger.debug(f"Stopped reading llama-server output: {exc}")


def _kill_process_group(process: subprocess.Popen) -> None:
    """Force-kill a llama-server that never became healthy so it cannot leak."""
    if process.poll() is not None:
        return
    try:
        if os.name != 'nt':
            os.killpg(os.getpgid(process.pid), signal.SIGKILL)
        else:
            process.kill()
    except OSError as exc:
        logger.error(f"Failed to kill llama-server (pid {process.pid}): {exc}")


async def start_llama_server(model_id: str, port: int) -> tuple[subprocess.Popen, float]:
    """
    Start llama-server with specified model on a specific port.

    The server's ``/health`` endpoint is polled with exponential backoff
    (100 ms growing by 1.5x up to 2 s) until it answers or five minutes of
    wall-clock time have passed.

    Args:
        model_id: Model identifier passed to llama-server's ``-hf`` flag.
        port: Port for the server to listen on.

    Returns:
        Tuple of (process, load_time_seconds).

    Raises:
        RuntimeError: If the shared HTTP session is not available, the process
            exits before becoming healthy, or it is not healthy in time. In
            the latter cases the process is killed before the error propagates.
    """
    # Local imports keep this block self-contained.
    import threading
    from collections import deque

    if http_session is None or http_session.closed:
        # Fail before spawning anything; the health check below needs the session.
        raise RuntimeError("HTTP session not initialized")

    start_time = time.monotonic()
    cmd = [
        "llama-server",
        "-hf", model_id,
        "--host", "0.0.0.0",
        "--port", str(port),
        "-c", str(CONTEXT_SIZE),  # Context size
        "-t", "4",  # CPU threads
        "-ngl", "0",  # GPU layers (0 for CPU-only)
        "--cont-batching",  # Enable continuous batching
        "-b", "512",  # Batch size
    ]
    logger.info(f"Starting llama-server with model: {model_id} on port {port}")
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        # Same effect as preexec_fn=os.setsid (own process group so the whole
        # group can be signalled later) without running Python in the child.
        start_new_session=(os.name != 'nt'),
        text=True,
        errors="replace",  # undecodable output must not stop the reader
        bufsize=1
    )

    # Keep only the most recent output for diagnostics; the reader thread
    # drains the pipe continuously so the server never blocks on logging.
    output_tail = deque(maxlen=200)
    reader = threading.Thread(
        target=_drain_process_output,
        args=(process, output_tail),
        name=f"llama-server-output-{port}",
        daemon=True,
    )
    reader.start()

    # Wait for server to be ready with exponential backoff
    server_url = f"http://localhost:{port}"
    max_wait_time = 300  # 5 minutes
    backoff_time = 0.1  # Start with 100ms
    max_backoff = 2.0  # Max 2 seconds between checks
    # A wall-clock deadline also counts time spent inside the health checks.
    deadline = start_time + max_wait_time
    ready = False
    try:
        while time.monotonic() < deadline:
            # Check if process died
            if process.poll() is not None:
                # Let the reader collect the final lines without blocking the loop.
                await asyncio.get_running_loop().run_in_executor(None, reader.join, 1.0)
                stdout = "\n".join(output_tail)
                logger.error(f"llama-server exited with code {process.returncode}")
                logger.error(f"Model ID: {model_id}")
                logger.error(f"Port: {port}")
                logger.error(f"Output:\n{stdout}")
                # Provide helpful error message
                error_msg = f"llama-server process died (exit code {process.returncode})"
                if "HTTPS is not supported" in stdout:
                    error_msg += "\n\nHTTPS support is missing. The llama-server binary needs to be rebuilt with CURL/SSL support."
                    error_msg += "\nAdd -DLLAMA_CURL=ON to the cmake build flags."
                elif "no usable GPU found" in stdout:
                    error_msg += "\n\nNote: Running on CPU only (no GPU detected)."
                raise RuntimeError(error_msg)
            try:
                # Use aiohttp for async health check
                async with http_session.get(f"{server_url}/health", timeout=aiohttp.ClientTimeout(total=2)) as response:
                    if response.status in (200, 404):  # 404 is ok, means server is up
                        load_time = time.monotonic() - start_time
                        logger.info(f"llama-server ready after {load_time:.2f}s")
                        ready = True
                        return process, load_time
            except (aiohttp.ClientError, asyncio.TimeoutError):
                # Server not ready yet
                pass
            # Exponential backoff
            await asyncio.sleep(backoff_time)
            backoff_time = min(backoff_time * 1.5, max_backoff)
        raise RuntimeError("llama-server failed to start within 5 minutes")
    finally:
        if not ready:
            # Timeout, early exit or cancellation: never leave a server behind.
            _kill_process_group(process)
async def preload_models_background():
    """Background task to preload the models listed in ``PRELOAD_MODELS``.

    Unknown names and already-cached models are skipped. A model is also
    skipped when the cache is already full, because inserting it would evict
    the least recently used model, which can be the active one. A failure to
    load one model is logged and does not stop the remaining preloads.
    """
    if not PRELOAD_MODELS:
        return
    logger.info(f"Preloading models in background: {PRELOAD_MODELS}")
    for model_name in PRELOAD_MODELS:
        if model_name not in AVAILABLE_MODELS:
            logger.warning(f"Preload model not found: {model_name}")
            continue
        if model_cache.get(model_name):
            logger.info(f"Model already cached: {model_name}")
            continue
        if len(model_cache.cache) >= model_cache.max_size:
            logger.warning(f"Model cache is full; skipping preload of: {model_name}")
            continue
        model_id = AVAILABLE_MODELS[model_name]
        port = model_cache._get_next_port()
        try:
            process, load_time = await start_llama_server(model_id, port)
            await model_cache.put(model_name, model_id, process, port, load_time)
            logger.info(f"Preloaded model: {model_name}")
        except Exception as e:
            # Give the reserved port back, as the other load paths do on failure.
            model_cache._release_port(port)
            logger.error(f"Failed to preload model {model_name}: {e}")
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references, so an unreferenced task can be garbage-collected mid-run.
_background_tasks: set = set()


# NOTE(review): no startup-hook registration is visible for this function in
# this view of the file - confirm it is registered with the app.
async def startup_event():
    """Initialize HTTP session and start with default model.

    Creates the shared pooled ``aiohttp`` session, loads the default model
    into the cache, records the startup duration and schedules background
    preloading.

    Raises:
        Exception: Whatever loading the default model raised, after the HTTP
            session has been closed and the reserved port released.
    """
    global current_model, http_session
    startup_start = time.time()
    logger.info("Application startup initiated")
    # Initialize aiohttp session with connection pooling
    connector = aiohttp.TCPConnector(
        limit=100,  # Max total connections
        limit_per_host=10,  # Max connections per host
        ttl_dns_cache=300  # DNS cache TTL
    )
    http_session = aiohttp.ClientSession(
        connector=connector,
        timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT)
    )
    # Start default model
    model_id = AVAILABLE_MODELS[current_model]
    port = model_cache._get_next_port()
    try:
        process, load_time = await start_llama_server(model_id, port)
        await model_cache.put(current_model, model_id, process, port, load_time)
        metrics.startup_time = time.time() - startup_start
        logger.info(f"Started with default model: {current_model} (total startup: {metrics.startup_time:.2f}s)")
        # Start preloading in background, keeping a reference until it finishes
        preload_task = asyncio.create_task(preload_models_background())
        _background_tasks.add(preload_task)
        preload_task.add_done_callback(_background_tasks.discard)
    except Exception as e:
        # Clean up on startup failure
        logger.error(f"Startup failed: {e}")
        if http_session:
            await http_session.close()
        model_cache._release_port(port)
        raise
async def shutdown_event():
    """Tear the service down: stop cached models, then close the HTTP session.

    Both steps are best-effort; a failure in either one is logged and does not
    propagate.
    """
    logger.info("Application shutdown initiated")

    # Clear model cache first
    try:
        await model_cache.clear()
    except Exception as e:
        logger.error(f"Error clearing cache during shutdown: {e}")

    # Nothing more to do when the session was never created or is already closed.
    if not http_session or http_session.closed:
        return

    # Close HTTP session
    try:
        await http_session.close()
        # Give it a moment to close gracefully
        await asyncio.sleep(0.1)
    except Exception as e:
        logger.error(f"Error closing HTTP session: {e}")
async def root():
    """
    Returns the current status of the AGI Multi-Model API.
    This endpoint provides information about:
    - Current API status
    - Currently active LLM model
    - List of all available models
    """
    # Snapshot the configured model names in declaration order.
    model_names = list(AVAILABLE_MODELS)
    payload = {
        "status": "AGI Multi-Model API - High Performance Edition",
        "current_model": current_model,
        "available_models": model_names,
    }
    return payload
async def health_check():
    """Health check endpoint."""
    # Reports the current wall-clock time and how many models are cached.
    now = time.time()
    cached_count = len(model_cache.cache)
    return {
        "status": "healthy",
        "timestamp": now,
        "cached_models": cached_count,
        "current_model": current_model,
    }
async def list_models():
    """
    List all available LLM models.
    Returns:
    - current_model: The model currently in use
    - available_models: Array of all available model names
    """
    # Names come straight from the AVAILABLE_MODELS registry, in declaration order.
    model_names = list(AVAILABLE_MODELS)
    return {"current_model": current_model, "available_models": model_names}
async def switch_model(request: ModelSwitchRequest):
    """
    Switch to a different LLM model with intelligent caching.
    **Performance optimizations:**
    - Instant switching for cached models
    - Async model loading with exponential backoff
    - Connection pooling for health checks
    - Background preloading of popular models
    """
    # Responds 400 for an unknown model name and 500 when loading fails.
    global current_model
    target = request.model_name
    if target not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{target}' not found. Available: {list(AVAILABLE_MODELS.keys())}"
        )
    # Only short-circuit while the active model is still resident. If it was
    # evicted from the cache, fall through and reload it so chat can recover.
    if target == current_model and target in model_cache.cache:
        return {"message": f"Already using model: {current_model}", "model": current_model}
    metrics.total_switches += 1
    # Try to get from cache
    cached_model = model_cache.get(target)
    if cached_model:
        # Model is cached, instant switch
        metrics.cache_hits += 1
        current_model = target
        return {
            "message": f"Switched to model: {current_model} (from cache, instant)",
            "model": current_model
        }
    # Model not cached, need to load it
    metrics.cache_misses += 1
    model_id = AVAILABLE_MODELS[target]
    port = model_cache._get_next_port()
    try:
        process, load_time = await start_llama_server(model_id, port)
        await model_cache.put(target, model_id, process, port, load_time)
        current_model = target
        return {
            "message": f"Switched to model: {current_model} (loaded in {load_time:.2f}s)",
            "model": current_model
        }
    except Exception as e:
        # Release port if failed
        model_cache._release_port(port)
        logger.error(f"Failed to load model {target}: {e}")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}") from e
async def chat_completions(request: ChatCompletionRequest, raw_request: Request):
    """
    OpenAI-compatible chat completions with performance optimizations.

    **Performance features:**
    - Async/await for non-blocking I/O
    - HTTP connection pooling
    - Request metrics tracking

    The messages are passed through `_compact_messages` and forwarded to the
    llama-server instance of the currently selected model; the upstream JSON
    body is returned as-is.

    Raises:
        HTTPException: 500 when the shared HTTP session or the current model
            is unavailable, or on a transport-level aiohttp error; the
            upstream status code when llama-server answers with an error
            status; 504 when the call exceeds `HARD_REQUEST_TIMEOUT`.
    """
    request_id = getattr(raw_request.state, "request_id", "-")
    slow_task: Optional[asyncio.Task] = None
    # Capture the active model name once. The module-level name can be rebound
    # by a model switch while this coroutine is suspended, which would credit
    # the request metrics to a model that did not serve it.
    model_name = current_model
    try:
        # Monotonic clock: latency must not jump when the wall clock is adjusted.
        request_start = time.monotonic()
        if not http_session or http_session.closed:
            raise HTTPException(status_code=500, detail="HTTP session not initialized")

        # Get current model from cache
        cached_model = model_cache.get(model_name)
        if not cached_model:
            raise HTTPException(status_code=500, detail="Current model not loaded")

        # Tokens left for the prompt after reserving the completion and a margin.
        prompt_budget = CONTEXT_SIZE - request.max_tokens - PROMPT_MARGIN_TOKENS
        original_tokens = _estimate_messages_tokens(request.messages)
        if prompt_budget > 0 and original_tokens > prompt_budget:
            logger.warning(
                f"request_id={request_id} prompt_compaction "
                f"original_tokens≈{original_tokens} budget≈{prompt_budget}"
            )

        # Report the first system prompt above the cap unless long prompts are
        # allowed. NOTE(review): this only logs; the actual drop is presumed to
        # happen inside _compact_messages — confirm.
        if not ALLOW_LONG_SYSTEM_PROMPT:
            for msg in request.messages:
                if msg.get("role") != "system":
                    continue
                system_tokens = _estimate_tokens(str(msg.get("content", "")))
                if system_tokens > SYSTEM_PROMPT_MAX_TOKENS:
                    logger.warning(
                        f"request_id={request_id} system_prompt_dropped "
                        f"tokens≈{system_tokens} cap≈{SYSTEM_PROMPT_MAX_TOKENS}"
                    )
                    break

        compacted_messages = _compact_messages(request.messages, request.max_tokens)
        compacted_tokens = _estimate_messages_tokens(compacted_messages)
        if compacted_tokens < original_tokens:
            logger.info(
                f"request_id={request_id} prompt_compacted "
                f"tokens≈{original_tokens}->{compacted_tokens}"
            )

        async def _slow_request_logger():
            # Emits one warning if the request is still running after 30 s;
            # cancelled in the `finally` clause otherwise.
            await asyncio.sleep(30)
            elapsed = time.monotonic() - request_start
            logger.warning(f"request_id={request_id} slow_request {elapsed:.1f}s")

        slow_task = asyncio.create_task(_slow_request_logger())

        payload = {
            "messages": compacted_messages,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
        }

        async def _do_request():
            # Forward to llama-server using aiohttp
            async with http_session.post(
                f"{cached_model.url}/v1/chat/completions",
                json=payload,
            ) as response:
                if response.status >= 400:
                    # Log a bounded slice of the upstream body before raising.
                    error_text = await response.text()
                    logger.error(
                        f"request_id={request_id} llama-server {response.status} "
                        f"error_body={error_text[:1000]}"
                    )
                response.raise_for_status()
                return await response.json()

        result = await asyncio.wait_for(_do_request(), timeout=HARD_REQUEST_TIMEOUT)

        # Update metrics
        request_latency = time.monotonic() - request_start
        cached_model.request_count += 1
        cached_model.total_latency += request_latency
        metrics.record_request(model_name, request_latency)

        # Log elapsed time and token rate (if usage available)
        usage = result.get("usage") if isinstance(result, dict) else None
        if usage and usage.get("completion_tokens"):
            completion_tokens = usage.get("completion_tokens", 0)
            tok_per_sec = completion_tokens / max(request_latency, 1e-6)
            logger.info(
                f"request_id={request_id} done "
                f"time={request_latency:.2f}s tokens={completion_tokens} tok/s={tok_per_sec:.1f}"
            )
        else:
            logger.info(f"request_id={request_id} done time={request_latency:.2f}s")
        return result
    except aiohttp.ClientResponseError as e:
        # Must precede ClientError (its base class) so the upstream status is kept.
        logger.exception(f"request_id={request_id} llama-server error")
        raise HTTPException(status_code=e.status, detail=f"llama-server error: {e.message}") from e
    except aiohttp.ClientError as e:
        logger.exception(f"request_id={request_id} llama-server error")
        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}") from e
    except asyncio.TimeoutError:
        logger.error(f"request_id={request_id} timeout after {HARD_REQUEST_TIMEOUT}s")
        raise HTTPException(status_code=504, detail="Upstream model timed out. Please retry.")
    except HTTPException:
        # Deliberate HTTP errors raised above are not unexpected failures;
        # pass them through without a traceback, as web_chat_completions does.
        raise
    except Exception:
        logger.exception(f"request_id={request_id} chat_completions error")
        raise
    finally:
        if slow_task and not slow_task.done():
            slow_task.cancel()
async def search_web_async(query: str, max_results: int = 5) -> list[dict]:
    """
    Search the web using DuckDuckGo with result caching.

    Results are looked up in `web_search_cache` first; on a miss the blocking
    DuckDuckGo client runs in the event loop's default executor and the
    results are stored back into the cache.

    Args:
        query: Search text passed to DuckDuckGo.
        max_results: Upper bound on the number of results requested.

    Returns:
        The list of result dicts, or an empty list if the search raised.
        Failed searches are not cached.
    """
    # Check cache first
    cached_results = web_search_cache.get(query, max_results)
    if cached_results is not None:
        metrics.web_search_cache_hits += 1
        logger.debug(f"Web search cache hit for: {query}")
        return cached_results

    # Perform search
    try:
        logger.debug(f"Performing web search: {query}")
        # Run blocking DDGS in the default executor to avoid blocking the
        # event loop. get_running_loop() is the supported call inside a
        # coroutine; get_event_loop() is deprecated in this position.
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(
            None,
            lambda: list(DDGS().text(query, max_results=max_results)),
        )
        # Cache results
        web_search_cache.put(query, max_results, results)
        # Counts only searches actually performed (cache misses that succeeded).
        metrics.total_web_searches += 1
        logger.debug(f"Found {len(results)} search results")
        return results
    except Exception as e:
        # Best effort: a failed search degrades to "no results" for the caller.
        logger.error(f"Search error: {e}")
        return []
def format_search_context(query: str, search_results: list[dict]) -> str:
    """Render search results as a Markdown-style context block for the LLM.

    Args:
        query: The search text, echoed in the heading.
        search_results: Result dicts; the keys "title", "body" and "href" are
            read, each with a fallback when absent.

    Returns:
        A single string with one numbered section per result, or a short
        "no results" sentence when `search_results` is empty.
    """
    if not search_results:
        return f"No web results found for: {query}"

    # Collect the pieces and join once instead of growing a string in place.
    pieces = [f"# Web Search Results for: {query}\n\n"]
    for position, hit in enumerate(search_results, start=1):
        heading = hit.get("title", "No title")
        snippet = hit.get("body", "No description")
        link = hit.get("href", "")
        pieces.append(f"## Result {position}: {heading}\n")
        pieces.append(f"{snippet}\n")
        if link:
            # The source line is omitted entirely when there is no URL.
            pieces.append(f"Source: {link}\n")
        pieces.append("\n")
    return "".join(pieces)
async def web_chat_completions(request: WebChatRequest, raw_request: Request):
    """
    Chat completions with web search augmentation.

    **Performance optimizations:**
    - Async web search
    - LRU cache for search results (1 hour TTL)
    - Parallel execution where possible

    The content of the last user message is used as the search query. The
    formatted results are inserted as a system message directly before that
    user message, the conversation is compacted, and the request is forwarded
    to the current model's llama-server. The upstream JSON is returned with
    an added `web_search` entry describing the search.

    Raises:
        HTTPException: 400 when no user message is present; 500 when the
            shared HTTP session or the current model is unavailable, on an
            aiohttp error, or on any unexpected failure; 504 when the call
            exceeds `HARD_REQUEST_TIMEOUT`.
    """
    request_id = getattr(raw_request.state, "request_id", "-")
    slow_task: Optional[asyncio.Task] = None
    try:
        # Get the last user message as search query. Its position is kept so
        # the web context can be placed right before it even when later
        # non-user messages follow.
        user_positions = [
            index
            for index, msg in enumerate(request.messages)
            if msg.get("role") == "user"
        ]
        if not user_positions:
            raise HTTPException(status_code=400, detail="No user message found")
        last_user_index = user_positions[-1]
        search_query = request.messages[last_user_index].get("content", "")

        # Record whether THIS query is already cached. This is the same lookup
        # search_web_async uses to detect a hit; a process-wide hit counter
        # cannot answer the question for a single request.
        served_from_cache = (
            web_search_cache.get(search_query, request.max_search_results) is not None
        )

        # Perform web search (async with caching)
        logger.info(f"Web chat: Searching for '{search_query}'")
        search_results = await search_web_async(search_query, request.max_search_results)

        # Format search results as context
        web_context = format_search_context(search_query, search_results)

        # Create augmented messages with web context (shallow copy, so the
        # request's own list is not mutated by the insert).
        augmented_messages = list(request.messages)
        system_prompt = {
            "role": "system",
            "content": (
                "You are a helpful assistant with access to current web information.\n"
                f"{web_context}\n"
                "Use the above search results to provide accurate, up-to-date information in your response.\n"
                "Always cite sources when using information from the search results."
            ),
        }
        # Insert web context as a system message before the last user message
        augmented_messages.insert(last_user_index, system_prompt)

        if not http_session or http_session.closed:
            raise HTTPException(status_code=500, detail="HTTP session not initialized")

        # Compact messages to fit within context
        prompt_budget = CONTEXT_SIZE - request.max_tokens - PROMPT_MARGIN_TOKENS
        original_tokens = _estimate_messages_tokens(augmented_messages)
        if prompt_budget > 0 and original_tokens > prompt_budget:
            logger.warning(
                f"request_id={request_id} prompt_compaction "
                f"original_tokens≈{original_tokens} budget≈{prompt_budget}"
            )

        # Report the first system prompt above the cap unless long prompts are
        # allowed. NOTE(review): this only logs; the actual drop is presumed to
        # happen inside _compact_messages — confirm.
        if not ALLOW_LONG_SYSTEM_PROMPT:
            for msg in augmented_messages:
                if msg.get("role") != "system":
                    continue
                system_tokens = _estimate_tokens(str(msg.get("content", "")))
                if system_tokens > SYSTEM_PROMPT_MAX_TOKENS:
                    logger.warning(
                        f"request_id={request_id} system_prompt_dropped "
                        f"tokens≈{system_tokens} cap≈{SYSTEM_PROMPT_MAX_TOKENS}"
                    )
                    break

        augmented_messages = _compact_messages(augmented_messages, request.max_tokens)
        compacted_tokens = _estimate_messages_tokens(augmented_messages)
        if compacted_tokens < original_tokens:
            logger.info(
                f"request_id={request_id} prompt_compacted "
                f"tokens≈{original_tokens}->{compacted_tokens}"
            )

        async def _slow_request_logger():
            # Emits one warning if the upstream call is still pending after
            # 30 s; cancelled in the `finally` clause otherwise.
            await asyncio.sleep(30)
            logger.warning(f"request_id={request_id} slow_request 30.0s")

        slow_task = asyncio.create_task(_slow_request_logger())

        # Get current model from cache
        cached_model = model_cache.get(current_model)
        if not cached_model:
            raise HTTPException(status_code=500, detail="Current model not loaded")

        async def _do_request():
            # Forward to llama-server with augmented context
            async with http_session.post(
                f"{cached_model.url}/v1/chat/completions",
                json={
                    "messages": augmented_messages,
                    "max_tokens": request.max_tokens,
                    "temperature": request.temperature,
                },
            ) as response:
                response.raise_for_status()
                return await response.json()

        result = await asyncio.wait_for(_do_request(), timeout=HARD_REQUEST_TIMEOUT)

        # Add metadata about search results
        result["web_search"] = {
            "query": search_query,
            "results_count": len(search_results),
            "sources": [r.get("href", "") for r in search_results if r.get("href")],
            "cached": served_from_cache,
        }
        return result
    except aiohttp.ClientError as e:
        logger.exception(f"request_id={request_id} llama-server error")
        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}") from e
    except asyncio.TimeoutError:
        logger.error(f"request_id={request_id} timeout after {HARD_REQUEST_TIMEOUT}s")
        raise HTTPException(status_code=504, detail="Upstream model timed out. Please retry.")
    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.exception(f"request_id={request_id} web_chat_completions error")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e
    finally:
        if slow_task and not slow_task.done():
            slow_task.cancel()
async def get_cache_info():
    """Return whatever the model cache reports about itself."""
    cache_report = model_cache.get_cache_info()
    return cache_report
async def get_metrics():
    """
    Get performance metrics for monitoring and optimization.

    Returns:
    - Request counts and latencies
    - Cache hit/miss ratios
    - Model-specific statistics
    - Web search cache stats
    - Startup time
    """
    model_lookups = metrics.cache_hits + metrics.cache_misses
    cache_hit_rate = metrics.cache_hits / model_lookups if model_lookups > 0 else 0.0

    # `total_web_searches` is incremented only when a search is actually
    # performed (a cache miss), while hits are counted separately. Total
    # lookups are therefore hits + searches; dividing hits by searches alone
    # can yield a "rate" above 1.0.
    web_lookups = metrics.web_search_cache_hits + metrics.total_web_searches
    web_cache_hit_rate = (
        metrics.web_search_cache_hits / web_lookups if web_lookups > 0 else 0.0
    )

    return {
        # NOTE(review): `startup_time` is used here as a timestamp but is
        # exported below under a "_seconds" name as if it were a duration —
        # confirm which one it is.
        "uptime_seconds": time.time() - (metrics.startup_time or time.time()),
        "startup_time_seconds": metrics.startup_time,
        "total_requests": metrics.total_requests,
        "total_model_switches": metrics.total_switches,
        "cache_stats": {
            "hits": metrics.cache_hits,
            "misses": metrics.cache_misses,
            "hit_rate": cache_hit_rate,
            "current_size": len(model_cache.cache),
            "max_size": model_cache.max_size,
        },
        "web_search_stats": {
            "total_searches": metrics.total_web_searches,
            "cache_hits": metrics.web_search_cache_hits,
            "cache_hit_rate": web_cache_hit_rate,
            "cache_size": len(web_search_cache.cache),
        },
        "model_metrics": metrics.model_metrics,
        "cached_models": model_cache.get_cache_info()["cached_models"],
    }
async def clear_cache():
    """Clear the model cache and return a confirmation message."""
    confirmation = {"message": "Cache cleared successfully"}
    await model_cache.clear()
    return confirmation
async def clear_web_search_cache():
    """Clear the web search cache and reset its two counters to zero."""
    web_search_cache.clear()
    # Both counters restart together with the now-empty cache.
    metrics.web_search_cache_hits = metrics.total_web_searches = 0
    return {"message": "Web search cache cleared successfully"}
async def get_openapi_spec():
    """Return the OpenAPI schema generated by the application object."""
    openapi_schema = app.openapi()
    return openapi_schema