File size: 8,000 Bytes
19f6105 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 |
#!/usr/bin/env python
"""
FastAPI Application for ContinuumAgent Project
Serves the model with patched knowledge
Modified for better error handling and compatibility with Hugging Face Spaces
"""
import os
import time
import traceback
from typing import Dict, List, Any, Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks, Query, Path, Depends
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from app.router import ContinuumRouter
# Define API models
class GenerateRequest(BaseModel):
    """Request body for the /generate endpoint."""
    prompt: str = Field(..., description="User input prompt")
    system_prompt: Optional[str] = Field(None, description="Optional system prompt")
    # The bounds below enforce the ranges the descriptions already promise,
    # so invalid values fail fast with a 422 instead of reaching the model.
    max_tokens: int = Field(256, gt=0, description="Maximum number of tokens to generate")
    temperature: float = Field(0.7, ge=0.0, le=1.0, description="Sampling temperature (0.0-1.0)")
    top_p: float = Field(0.95, ge=0.0, le=1.0, description="Top-p sampling parameter (0.0-1.0)")
    auto_route: bool = Field(True, description="Auto-route based on query complexity")
    force_patches: Optional[bool] = Field(None, description="Force usage of patches")
class GenerateResponse(BaseModel):
    """Response payload returned by the /generate endpoint.

    Populated from the dict returned by ``ContinuumRouter.generate``.
    """
    text: str = Field(..., description="Generated text")
    elapsed_seconds: float = Field(..., description="Elapsed time in seconds")
    used_patches: bool = Field(..., description="Whether patches were used")
    # Empty when generation ran without any adapters.
    adapter_paths: List[str] = Field(default_factory=list, description="Paths to used adapters")
    total_tokens: int = Field(0, description="Total tokens used")
class ModelInfo(BaseModel):
    """Metadata about the loaded model, as reported by the router.

    NOTE(review): assumed to match the dict shape returned by
    ``ContinuumRouter.get_model_info`` — confirm against the router.
    """
    name: str = Field(..., description="Model name")
    quantization: str = Field(..., description="Quantization format")
    patches: List[Dict[str, Any]] = Field(default_factory=list, description="Available patches")
    using_gpu: bool = Field(False, description="Whether GPU is being used")
class StatusResponse(BaseModel):
    """Payload for the root ("/") status endpoint."""
    status: str = Field(..., description="Service status")
    # None until the router has loaded and get_model_info succeeds.
    model_info: Optional[ModelInfo] = Field(None, description="Model information")
    uptime_seconds: float = Field(..., description="Service uptime in seconds")
    processed_requests: int = Field(0, description="Number of processed requests")
    is_model_loaded: bool = Field(False, description="Whether model is successfully loaded")
# Create FastAPI application
app = FastAPI(
    title="ContinuumAgent API",
    description="API for the ContinuumAgent knowledge patching system",
    version="0.1.0",
)
# Add CORS middleware so browser clients (e.g. the Hugging Face Spaces UI)
# can call the API from any origin.
# NOTE(review): browsers reject credentialed requests when the server answers
# with a wildcard origin, so allow_origins=["*"] together with
# allow_credentials=True may not work as intended — confirm whether
# credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables (module-level service state)
start_time = time.time()  # process start time, used for uptime reporting
request_count = 0  # number of /generate requests served
continuum_router = None  # set by startup_event once the model loads
model_load_error = None  # human-readable reason the router failed to load
# NOTE(review): @app.on_event is deprecated in newer FastAPI releases in
# favor of lifespan handlers — consider migrating when upgrading.
@app.on_event("startup")
async def startup_event():
    """Locate a GGUF model on disk and initialize the ContinuumRouter.

    On failure the error (with traceback) is recorded in ``model_load_error``
    so that endpoints can report a 503 instead of crashing the whole app.
    """
    global continuum_router, model_load_error
    # Find model path
    model_dir = "models/slow"
    os.makedirs(model_dir, exist_ok=True)
    # Sort candidates so the chosen model is deterministic across restarts;
    # os.listdir returns entries in arbitrary order.
    model_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".gguf"))
    if not model_files:
        model_load_error = "No GGUF models found. Please run download_model.py first."
        print(f"Error: {model_load_error}")
        return
    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")
    # Get GPU layers setting from environment; 0 keeps inference on CPU,
    # the safe default for Hugging Face Spaces.
    n_gpu_layers = int(os.environ.get("N_GPU_LAYERS", "0"))
    # Initialize router
    try:
        continuum_router = ContinuumRouter(
            model_path=model_path,
            n_gpu_layers=n_gpu_layers,
        )
        # Load patches (can be done in background)
        continuum_router.load_latest_patches()
    except Exception as e:
        error_traceback = traceback.format_exc()
        model_load_error = f"Error initializing router: {str(e)}\n{error_traceback}"
        print(model_load_error)
def get_router():
    """FastAPI dependency yielding the initialized ContinuumRouter.

    Raises:
        HTTPException: 503 while the model is still loading or failed to load.
    """
    if continuum_router is not None:
        return continuum_router
    # Surface the recorded startup failure (if any) to the caller.
    reason = model_load_error or "Unknown error"
    raise HTTPException(
        status_code=503,
        detail=f"Service not fully initialized: {reason}"
    )
@app.get("/", response_model=StatusResponse)
async def get_status():
    """Report service status: uptime, request count, and model info."""
    global start_time, request_count, continuum_router, model_load_error
    # Derive the status string: only the first line of a (possibly
    # multi-line) load error is surfaced.
    if model_load_error:
        first_line = model_load_error.split(chr(10))[0]
        status = f"error: {first_line}"
    else:
        status = "running"
    response = StatusResponse(
        status=status,
        uptime_seconds=time.time() - start_time,
        processed_requests=request_count,
        is_model_loaded=continuum_router is not None
    )
    # Attach model info when the router is available; failures here are
    # logged but never break the status endpoint.
    if continuum_router:
        try:
            response.model_info = continuum_router.get_model_info()
        except Exception as e:
            print(f"Error getting model info: {e}")
    return response
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, router: ContinuumRouter = Depends(get_router)):
    """Run text generation through the router and return its result.

    Raises:
        HTTPException: 500 with the error and traceback if generation fails.
    """
    global request_count
    request_count += 1
    try:
        # Forward all sampling/routing options straight to the router.
        return router.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            auto_route=request.auto_route,
            force_patches=request.force_patches
        )
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error generating text: {str(e)}\n{tb}"
        )
@app.post("/patches/load")
async def load_patches(date_str: Optional[str] = None,
                       router: ContinuumRouter = Depends(get_router)):
    """Load patches for a specific date.

    Raises:
        HTTPException: 500 with the error and traceback if loading fails.
    """
    try:
        loaded = router.load_patches(date_str)
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error loading patches: {str(e)}\n{tb}"
        )
    return {"status": "success", "loaded_patches": loaded}
@app.get("/patches/list")
async def list_patches(router: ContinuumRouter = Depends(get_router)):
    """List available patches.

    Raises:
        HTTPException: 500 with the error and traceback on failure.
    """
    try:
        available = router.list_patches()
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error listing patches: {str(e)}\n{tb}"
        )
    return {"patches": available}
@app.get("/patches/active")
async def get_active_patches(router: ContinuumRouter = Depends(get_router)):
    """Get currently active patches.

    Raises:
        HTTPException: 500 with the error and traceback on failure.
    """
    try:
        currently_active = router.get_active_patches()
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error getting active patches: {str(e)}\n{tb}"
        )
    return {"active_patches": currently_active}
# Health check endpoint for Hugging Face Spaces
# Health check endpoint for Hugging Face Spaces
@app.get("/health")
async def health_check():
    """Lightweight liveness probe: reports whether the model has loaded."""
    return {
        "status": "ok",
        "model_loaded": continuum_router is not None,
    }