# ContinuumAgent — app/main.py
# (Hugging Face Hub upload metadata: uploaded by deasdutta, revision 19f6105, verified)
#!/usr/bin/env python
"""
FastAPI Application for ContinuumAgent Project
Serves the model with patched knowledge
Modified for better error handling and compatibility with Hugging Face Spaces
"""
import os
import time
import traceback
from typing import Dict, List, Any, Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks, Query, Path, Depends
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from app.router import ContinuumRouter
# Define API models
# Request body for POST /generate. Field descriptions feed the OpenAPI schema.
class GenerateRequest(BaseModel):
    # Required user prompt; everything else has a default.
    prompt: str = Field(..., description="User input prompt")
    system_prompt: Optional[str] = Field(None, description="Optional system prompt")
    max_tokens: int = Field(256, description="Maximum number of tokens to generate")
    # NOTE(review): ranges in the descriptions are not enforced (no ge/le
    # constraints) — out-of-range values pass validation. Confirm intent.
    temperature: float = Field(0.7, description="Sampling temperature (0.0-1.0)")
    top_p: float = Field(0.95, description="Top-p sampling parameter (0.0-1.0)")
    auto_route: bool = Field(True, description="Auto-route based on query complexity")
    # None = let the router decide; True/False overrides its choice.
    force_patches: Optional[bool] = Field(None, description="Force usage of patches")
# Response body for POST /generate (shape produced by ContinuumRouter.generate).
class GenerateResponse(BaseModel):
    text: str = Field(..., description="Generated text")
    elapsed_seconds: float = Field(..., description="Elapsed time in seconds")
    used_patches: bool = Field(..., description="Whether patches were used")
    adapter_paths: List[str] = Field(default_factory=list, description="Paths to used adapters")
    total_tokens: int = Field(0, description="Total tokens used")
class ModelInfo(BaseModel):
name: str = Field(..., description="Model name")
quantization: str = Field(..., description="Quantization format")
patches: List[Dict[str, Any]] = Field(default_factory=list, description="Available patches")
using_gpu: bool = Field(False, description="Whether GPU is being used")
# Response body for GET / (overall service status).
class StatusResponse(BaseModel):
    # "running", or "error: <first line of the load error>" when startup failed.
    status: str = Field(..., description="Service status")
    # None until the router is initialized and get_model_info() succeeds.
    model_info: Optional[ModelInfo] = Field(None, description="Model information")
    uptime_seconds: float = Field(..., description="Service uptime in seconds")
    processed_requests: int = Field(0, description="Number of processed requests")
    is_model_loaded: bool = Field(False, description="Whether model is successfully loaded")
# Create FastAPI application
app = FastAPI(
    title="ContinuumAgent API",
    description="API for the ContinuumAgent knowledge patching system",
    version="0.1.0",
)
# Add CORS middleware for Hugging Face Spaces
# NOTE(review): wildcard origins combined with allow_credentials=True is
# maximally permissive — acceptable for a public demo Space, but should be
# tightened for any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables (module-level mutable state shared across requests)
start_time = time.time()  # process start timestamp, used for uptime reporting
request_count = 0  # incremented once per /generate request
continuum_router = None  # set by startup_event() once the model loads
model_load_error = None  # initialization failure text, surfaced via get_router()
@app.on_event("startup")
async def startup_event():
    """Locate a GGUF model on disk and initialize the global ContinuumRouter.

    Any failure is recorded in ``model_load_error`` (and printed) rather than
    raised, so the app still starts and ``get_router`` can answer with a 503.
    """
    global continuum_router, model_load_error
    # Find model path
    model_dir = "models/slow"
    os.makedirs(model_dir, exist_ok=True)
    # Sort for a deterministic choice when several GGUF files are present;
    # os.listdir order is filesystem-dependent.
    model_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".gguf"))
    if not model_files:
        model_load_error = "No GGUF models found. Please run download_model.py first."
        print(f"Error: {model_load_error}")
        return
    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")
    # Get GPU layers setting from environment. A malformed value should
    # degrade to CPU-only (0 layers), not crash startup with a ValueError.
    try:
        n_gpu_layers = int(os.environ.get("N_GPU_LAYERS", "0"))
    except ValueError:
        print("Warning: invalid N_GPU_LAYERS value, defaulting to 0")
        n_gpu_layers = 0
    # Initialize router
    try:
        continuum_router = ContinuumRouter(
            model_path=model_path,
            n_gpu_layers=n_gpu_layers
        )
        # Load patches (can be done in background)
        continuum_router.load_latest_patches()
    except Exception as e:
        error_traceback = traceback.format_exc()
        model_load_error = f"Error initializing router: {str(e)}\n{error_traceback}"
        print(model_load_error)
def get_router():
    """FastAPI dependency: return the shared router or fail with HTTP 503.

    Raises:
        HTTPException: 503 while the router has not finished initializing
            (or failed to initialize).
    """
    router = continuum_router
    if router is not None:
        return router
    raise HTTPException(
        status_code=503,
        detail=f"Service not fully initialized: {model_load_error or 'Unknown error'}"
    )
@app.get("/", response_model=StatusResponse)
async def get_status():
    """Report service status: uptime, request count, and model info if loaded.

    Returns "running" normally, or "error: <first line of the load error>"
    when startup failed. Never raises — model info is best-effort.
    """
    # The original set status to "initializing" on error and then always
    # overwrote it below; build the final value directly instead. (No
    # ``global`` needed: these module-level names are only read here.)
    status_response = StatusResponse(
        status="running",
        uptime_seconds=time.time() - start_time,
        processed_requests=request_count,
        is_model_loaded=continuum_router is not None
    )
    # Add model info if available (best-effort: a status probe must not fail).
    if continuum_router:
        try:
            status_response.model_info = continuum_router.get_model_info()
        except Exception as e:
            print(f"Error getting model info: {e}")
    # Surface only the first line of the (possibly multi-line) error text.
    if model_load_error:
        status_response.status = f"error: {model_load_error.splitlines()[0]}"
    return status_response
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, router: ContinuumRouter = Depends(get_router)):
    """Run the prompt through the router and return the generation result.

    Raises:
        HTTPException: 500 with the traceback text when generation fails.
    """
    global request_count
    request_count += 1
    try:
        return router.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            auto_route=request.auto_route,
            force_patches=request.force_patches,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating text: {str(exc)}\n{traceback.format_exc()}"
        )
@app.post("/patches/load")
async def load_patches(date_str: Optional[str] = None,
                       router: ContinuumRouter = Depends(get_router)):
    """Load knowledge patches for a specific date string.

    Raises:
        HTTPException: 500 with the traceback text when loading fails.
    """
    try:
        loaded = router.load_patches(date_str)
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error loading patches: {str(exc)}\n{traceback.format_exc()}"
        )
    return {"status": "success", "loaded_patches": loaded}
@app.get("/patches/list")
async def list_patches(router: ContinuumRouter = Depends(get_router)):
    """Return all patches known to the router.

    Raises:
        HTTPException: 500 with the traceback text on failure.
    """
    try:
        available = router.list_patches()
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error listing patches: {str(exc)}\n{traceback.format_exc()}"
        )
    return {"patches": available}
@app.get("/patches/active")
async def get_active_patches(router: ContinuumRouter = Depends(get_router)):
    """Return the patches currently applied by the router.

    Raises:
        HTTPException: 500 with the traceback text on failure.
    """
    try:
        current = router.get_active_patches()
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Error getting active patches: {str(exc)}\n{traceback.format_exc()}"
        )
    return {"active_patches": current}
# Health check endpoint for Hugging Face Spaces
@app.get("/health")
async def health_check():
    """Lightweight liveness probe; reports whether the model has loaded."""
    loaded = continuum_router is not None
    return {"status": "ok", "model_loaded": loaded}