|
|
|
|
|
"""
|
|
|
FastAPI Application for ContinuumAgent Project
|
|
|
Serves the model with patched knowledge
|
|
|
Modified for better error handling and compatibility with Hugging Face Spaces
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import time
|
|
|
import traceback
|
|
|
from typing import Dict, List, Any, Optional
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, Query, Path, Depends
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from pydantic import BaseModel, Field
|
|
|
from app.router import ContinuumRouter
|
|
|
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request payload for the POST /generate endpoint."""

    prompt: str = Field(..., description="User input prompt")
    system_prompt: Optional[str] = Field(None, description="Optional system prompt")
    # Sampling parameters forwarded verbatim to router.generate().
    max_tokens: int = Field(256, description="Maximum number of tokens to generate")
    temperature: float = Field(0.7, description="Sampling temperature (0.0-1.0)")
    top_p: float = Field(0.95, description="Top-p sampling parameter (0.0-1.0)")
    # Routing controls: auto_route lets the router decide from query
    # complexity; force_patches (when not None) overrides that decision.
    auto_route: bool = Field(True, description="Auto-route based on query complexity")
    force_patches: Optional[bool] = Field(None, description="Force usage of patches")
|
|
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body returned by POST /generate."""

    text: str = Field(..., description="Generated text")
    elapsed_seconds: float = Field(..., description="Elapsed time in seconds")
    # Whether the router applied knowledge patches for this generation,
    # and which adapter files (if any) were involved.
    used_patches: bool = Field(..., description="Whether patches were used")
    adapter_paths: List[str] = Field(default_factory=list, description="Paths to used adapters")
    total_tokens: int = Field(0, description="Total tokens used")
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Description of the loaded model, embedded in StatusResponse."""

    name: str = Field(..., description="Model name")
    quantization: str = Field(..., description="Quantization format")
    patches: List[Dict[str, Any]] = Field(default_factory=list, description="Available patches")
    using_gpu: bool = Field(False, description="Whether GPU is being used")
|
|
|
|
|
|
class StatusResponse(BaseModel):
    """Service status payload returned by GET /."""

    status: str = Field(..., description="Service status")
    # Populated only when the router loaded successfully and
    # get_model_info() did not raise.
    model_info: Optional[ModelInfo] = Field(None, description="Model information")
    uptime_seconds: float = Field(..., description="Service uptime in seconds")
    processed_requests: int = Field(0, description="Number of processed requests")
    is_model_loaded: bool = Field(False, description="Whether model is successfully loaded")
|
|
|
|
|
|
|
|
|
# FastAPI application instance; metadata is surfaced in the OpenAPI docs.
app = FastAPI(
    title="ContinuumAgent API",
    description="API for the ContinuumAgent knowledge patching system",
    version="0.1.0",
)

# Fully permissive CORS — suitable for a demo/Spaces deployment.
# NOTE(review): tighten allow_origins before any production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level service state, initialized here and mutated by the
# startup hook and request handlers.
start_time = time.time()   # process start, used for uptime reporting
request_count = 0          # number of /generate requests served
continuum_router = None    # ContinuumRouter instance once loaded
model_load_error = None    # human-readable startup failure, if any
|
|
|
|
|
|
# NOTE(review): on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; kept here to preserve the existing structure.
@app.on_event("startup")
async def startup_event():
    """Initialize the ContinuumRouter on application startup.

    Locates a GGUF model under ``models/slow``, constructs the router and
    loads the latest knowledge patches. Any failure is recorded in
    ``model_load_error`` (instead of raising) so the app still starts and
    endpoints can report a 503 with the reason.
    """
    global continuum_router, model_load_error

    model_dir = "models/slow"
    os.makedirs(model_dir, exist_ok=True)
    # Sort for a deterministic choice when several GGUF files are present;
    # os.listdir() order is arbitrary and filesystem-dependent.
    model_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".gguf"))

    if not model_files:
        model_load_error = "No GGUF models found. Please run download_model.py first."
        print(f"Error: {model_load_error}")
        return

    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    # Default to 0 GPU layers so the service also runs on CPU-only hosts.
    n_gpu_layers = int(os.environ.get("N_GPU_LAYERS", "0"))

    try:
        continuum_router = ContinuumRouter(
            model_path=model_path,
            n_gpu_layers=n_gpu_layers
        )
        continuum_router.load_latest_patches()
    except Exception as e:
        # Record the full traceback for later display via GET / and the
        # 503 detail raised by get_router().
        error_traceback = traceback.format_exc()
        model_load_error = f"Error initializing router: {str(e)}\n{error_traceback}"
        print(model_load_error)
|
|
|
|
|
|
def get_router():
    """FastAPI dependency: return the initialized router or fail with 503."""
    if continuum_router is not None:
        return continuum_router
    # Router never got built — report why (startup error, or still loading).
    reason = model_load_error or 'Unknown error'
    raise HTTPException(
        status_code=503,
        detail=f"Service not fully initialized: {reason}"
    )
|
|
|
|
|
|
@app.get("/", response_model=StatusResponse)
async def get_status():
    """Report service status: uptime, request count and model information."""
    global start_time, request_count, continuum_router, model_load_error

    # Derive a single status string up front:
    #  - "error: ..." carries only the first line of the stored message,
    #    which may include a multi-line traceback;
    #  - "initializing" while the router has not been built yet (the
    #    original reported "running" here, which was misleading);
    #  - "running" once the model is loaded.
    if model_load_error:
        status = f"error: {model_load_error.splitlines()[0]}"
    elif continuum_router is None:
        status = "initializing"
    else:
        status = "running"

    status_response = StatusResponse(
        status=status,
        uptime_seconds=time.time() - start_time,
        processed_requests=request_count,
        is_model_loaded=continuum_router is not None
    )

    if continuum_router:
        try:
            status_response.model_info = continuum_router.get_model_info()
        except Exception as e:
            # Best effort: the status payload is still useful without
            # model details, so log and continue.
            print(f"Error getting model info: {e}")

    return status_response
|
|
|
|
|
|
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, router: ContinuumRouter = Depends(get_router)):
    """Run text generation through the router and count the request."""
    global request_count
    request_count += 1

    try:
        # Forward all generation parameters straight through to the router.
        return router.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            auto_route=request.auto_route,
            force_patches=request.force_patches
        )
    except Exception as e:
        # Include the traceback in the 500 detail to ease remote debugging.
        raise HTTPException(
            status_code=500,
            detail=f"Error generating text: {str(e)}\n{traceback.format_exc()}"
        )
|
|
|
|
|
|
@app.post("/patches/load")
async def load_patches(date_str: Optional[str] = None,
                       router: ContinuumRouter = Depends(get_router)):
    """Load patches for a specific date."""
    try:
        loaded = router.load_patches(date_str)
    except Exception as e:
        # Surface the traceback in the 500 detail for remote debugging.
        raise HTTPException(
            status_code=500,
            detail=f"Error loading patches: {str(e)}\n{traceback.format_exc()}"
        )
    return {"status": "success", "loaded_patches": loaded}
|
|
|
|
|
|
@app.get("/patches/list")
async def list_patches(router: ContinuumRouter = Depends(get_router)):
    """List available patches"""
    try:
        available = router.list_patches()
    except Exception as e:
        # Surface the traceback in the 500 detail for remote debugging.
        raise HTTPException(
            status_code=500,
            detail=f"Error listing patches: {str(e)}\n{traceback.format_exc()}"
        )
    return {"patches": available}
|
|
|
|
|
|
@app.get("/patches/active")
async def get_active_patches(router: ContinuumRouter = Depends(get_router)):
    """Get currently active patches"""
    try:
        current = router.get_active_patches()
    except Exception as e:
        # Surface the traceback in the 500 detail for remote debugging.
        raise HTTPException(
            status_code=500,
            detail=f"Error getting active patches: {str(e)}\n{traceback.format_exc()}"
        )
    return {"active_patches": current}
|
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always ok, plus whether the model finished loading."""
    loaded = continuum_router is not None
    return {"status": "ok", "model_loaded": loaded}