#!/usr/bin/env python
"""
FastAPI Application for ContinuumAgent Project
Serves the model with patched knowledge
Modified for better error handling and compatibility with Hugging Face Spaces
"""

import os
import time
import traceback
from typing import Dict, List, Any, Optional

from fastapi import FastAPI, HTTPException, BackgroundTasks, Query, Path, Depends
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from app.router import ContinuumRouter


# Define API models
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    prompt: str = Field(..., description="User input prompt")
    system_prompt: Optional[str] = Field(None, description="Optional system prompt")
    max_tokens: int = Field(256, description="Maximum number of tokens to generate")
    temperature: float = Field(0.7, description="Sampling temperature (0.0-1.0)")
    top_p: float = Field(0.95, description="Top-p sampling parameter (0.0-1.0)")
    auto_route: bool = Field(True, description="Auto-route based on query complexity")
    force_patches: Optional[bool] = Field(None, description="Force usage of patches")


class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``."""

    text: str = Field(..., description="Generated text")
    elapsed_seconds: float = Field(..., description="Elapsed time in seconds")
    used_patches: bool = Field(..., description="Whether patches were used")
    adapter_paths: List[str] = Field(default_factory=list, description="Paths to used adapters")
    total_tokens: int = Field(0, description="Total tokens used")


class ModelInfo(BaseModel):
    """Model metadata embedded in the status response."""

    name: str = Field(..., description="Model name")
    quantization: str = Field(..., description="Quantization format")
    patches: List[Dict[str, Any]] = Field(default_factory=list, description="Available patches")
    using_gpu: bool = Field(False, description="Whether GPU is being used")


class StatusResponse(BaseModel):
    """Response body for ``GET /``."""

    status: str = Field(..., description="Service status")
    model_info: Optional[ModelInfo] = Field(None, description="Model information")
    uptime_seconds: float = Field(..., description="Service uptime in seconds")
    processed_requests: int = Field(0, description="Number of processed requests")
    is_model_loaded: bool = Field(False, description="Whether model is successfully loaded")


# Create FastAPI application
app = FastAPI(
    title="ContinuumAgent API",
    description="API for the ContinuumAgent knowledge patching system",
    version="0.1.0",
)

# Add CORS middleware for Hugging Face Spaces
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state shared by the handlers below.
start_time = time.time()  # set at module import; used to compute uptime
request_count = 0  # /generate calls that reached the handler (failures included)
continuum_router = None  # assigned by startup_event once the router is constructed
model_load_error = None  # human-readable startup error message, or None


def _find_model_path(model_dir: str) -> Optional[str]:
    """Return the path of the alphabetically-first ``.gguf`` file in *model_dir*.

    Creates *model_dir* if it does not exist. Returns None when no ``.gguf``
    file is present.

    Raises:
        OSError: if the directory cannot be created or listed.
    """
    os.makedirs(model_dir, exist_ok=True)
    # os.listdir returns entries in arbitrary order; sort so the choice is
    # deterministic when several models are present.
    model_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".gguf"))
    if not model_files:
        return None
    return os.path.join(model_dir, model_files[0])


@app.on_event("startup")
async def startup_event():
    """Initialize the router on startup.

    Every failure is recorded in ``model_load_error`` instead of being raised.
    The server therefore keeps running and reports the problem through
    ``GET /`` and the 503 returned by ``get_router``.
    """
    global continuum_router, model_load_error

    # Find model path
    model_dir = "models/slow"
    try:
        model_path = _find_model_path(model_dir)
    except OSError as e:
        model_load_error = f"Cannot access model directory {model_dir!r}: {e}"
        print(f"Error: {model_load_error}")
        return
    if model_path is None:
        model_load_error = "No GGUF models found. Please run download_model.py first."
        print(f"Error: {model_load_error}")
        return
    print(f"Using model: {model_path}")

    # Get GPU layers setting from environment. Parse it before touching the
    # model, so a bad value is reported instead of crashing startup.
    raw_gpu_layers = os.environ.get("N_GPU_LAYERS", "0")
    try:
        n_gpu_layers = int(raw_gpu_layers)
    except ValueError:
        model_load_error = f"Invalid N_GPU_LAYERS value {raw_gpu_layers!r}; expected an integer."
        print(f"Error: {model_load_error}")
        return

    # Initialize router
    try:
        continuum_router = ContinuumRouter(
            model_path=model_path,
            n_gpu_layers=n_gpu_layers,
        )
    except Exception as e:
        error_traceback = traceback.format_exc()
        model_load_error = f"Error initializing router: {str(e)}\n{error_traceback}"
        print(model_load_error)
        return

    # Load patches synchronously. The router is already assigned at this
    # point, so a failure here leaves /generate usable but is still reported
    # as an error in the status endpoint.
    try:
        continuum_router.load_latest_patches()
    except Exception as e:
        error_traceback = traceback.format_exc()
        model_load_error = f"Error loading patches: {str(e)}\n{error_traceback}"
        print(model_load_error)


def get_router():
    """FastAPI dependency returning the loaded router.

    Raises:
        HTTPException: 503 if the router has not been constructed.
    """
    if continuum_router is None:
        raise HTTPException(
            status_code=503,
            detail=f"Service not fully initialized: {model_load_error or 'Unknown error'}",
        )
    return continuum_router


def _internal_error(action: str, exc: Exception) -> HTTPException:
    """Build the HTTP 500 raised by the endpoints when *exc* escapes.

    Must be called from inside an ``except`` block, because
    ``traceback.format_exc()`` reads the exception currently being handled.
    """
    # NOTE(review): the full server traceback is returned to the client.
    # That is handy for debugging on Spaces, but it exposes internals; consider
    # gating it behind a debug flag.
    return HTTPException(
        status_code=500,
        detail=f"Error {action}: {str(exc)}\n{traceback.format_exc()}",
    )


@app.get("/", response_model=StatusResponse)
async def get_status():
    """Get service status.

    ``status`` is ``"error: <first line of the startup error>"`` if startup
    recorded an error. Otherwise it is ``"initializing"`` while no router is
    loaded, and ``"running"`` once one is.
    """
    # Add model info if available. Validate it into ModelInfo here so a
    # malformed payload is dropped (and logged) instead of breaking the
    # response serialization.
    model_info = None
    if continuum_router is not None:
        try:
            raw_info = continuum_router.get_model_info()
            model_info = raw_info if isinstance(raw_info, ModelInfo) else ModelInfo(**raw_info)
        except Exception as e:
            print(f"Error getting model info: {e}")

    if model_load_error:
        status = f"error: {model_load_error.split(chr(10))[0]}"
    elif continuum_router is None:
        status = "initializing"
    else:
        status = "running"

    return StatusResponse(
        status=status,
        model_info=model_info,
        uptime_seconds=time.time() - start_time,
        processed_requests=request_count,
        is_model_loaded=continuum_router is not None,
    )


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, router: ContinuumRouter = Depends(get_router)):
    """Generate text from model.

    Raises:
        HTTPException: 500 if generation raises.
    """
    global request_count
    request_count += 1

    try:
        return router.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            auto_route=request.auto_route,
            force_patches=request.force_patches,
        )
    except Exception as e:
        raise _internal_error("generating text", e) from e


@app.post("/patches/load")
async def load_patches(date_str: Optional[str] = None, router: ContinuumRouter = Depends(get_router)):
    """Load patches for a specific date (*date_str* is passed to the router unchanged)."""
    try:
        loaded = router.load_patches(date_str)
    except Exception as e:
        raise _internal_error("loading patches", e) from e
    return {"status": "success", "loaded_patches": loaded}


@app.get("/patches/list")
async def list_patches(router: ContinuumRouter = Depends(get_router)):
    """List available patches."""
    try:
        patches = router.list_patches()
    except Exception as e:
        raise _internal_error("listing patches", e) from e
    return {"patches": patches}


@app.get("/patches/active")
async def get_active_patches(router: ContinuumRouter = Depends(get_router)):
    """Get currently active patches."""
    try:
        active = router.get_active_patches()
    except Exception as e:
        raise _internal_error("getting active patches", e) from e
    return {"active_patches": active}


# Health check endpoint for Hugging Face Spaces
@app.get("/health")
async def health_check():
    """Health check endpoint; always 200, reporting whether the router is loaded."""
    return {"status": "ok", "model_loaded": continuum_router is not None}