File size: 8,000 Bytes
19f6105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#!/usr/bin/env python
"""

FastAPI Application for ContinuumAgent Project

Serves the model with patched knowledge

Modified for better error handling and compatibility with Hugging Face Spaces

"""

import os
import time
import traceback
from typing import Dict, List, Any, Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks, Query, Path, Depends
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from app.router import ContinuumRouter

# Define API models
class GenerateRequest(BaseModel):
    """Request body for the POST /generate endpoint."""

    prompt: str = Field(..., description="User input prompt")
    system_prompt: Optional[str] = Field(None, description="Optional system prompt")
    # gt=0 rejects requests for zero or negative tokens with a 422 instead of
    # passing a nonsensical value through to the model.
    max_tokens: int = Field(256, gt=0, description="Maximum number of tokens to generate")
    # The descriptions already promised the 0.0-1.0 range; enforce it so
    # invalid sampling parameters are rejected at validation time.
    temperature: float = Field(0.7, ge=0.0, le=1.0, description="Sampling temperature (0.0-1.0)")
    top_p: float = Field(0.95, ge=0.0, le=1.0, description="Top-p sampling parameter (0.0-1.0)")
    auto_route: bool = Field(True, description="Auto-route based on query complexity")
    force_patches: Optional[bool] = Field(None, description="Force usage of patches")

class GenerateResponse(BaseModel):
    """Response body returned by the POST /generate endpoint."""
    text: str = Field(..., description="Generated text")
    elapsed_seconds: float = Field(..., description="Elapsed time in seconds")
    used_patches: bool = Field(..., description="Whether patches were used")
    adapter_paths: List[str] = Field(default_factory=list, description="Paths to used adapters")
    total_tokens: int = Field(0, description="Total tokens used")

class ModelInfo(BaseModel):
    """Description of the loaded model, embedded in StatusResponse.

    Presumably populated from ``ContinuumRouter.get_model_info()`` — the
    router's return shape is not visible here; verify fields stay in sync.
    """
    name: str = Field(..., description="Model name")
    quantization: str = Field(..., description="Quantization format")
    patches: List[Dict[str, Any]] = Field(default_factory=list, description="Available patches")
    using_gpu: bool = Field(False, description="Whether GPU is being used")

class StatusResponse(BaseModel):
    """Response body for the GET / status endpoint."""
    # "running", "initializing", or "error: <first line of load error>"
    status: str = Field(..., description="Service status")
    model_info: Optional[ModelInfo] = Field(None, description="Model information")
    uptime_seconds: float = Field(..., description="Service uptime in seconds")
    processed_requests: int = Field(0, description="Number of processed requests")
    is_model_loaded: bool = Field(False, description="Whether model is successfully loaded")

# Create FastAPI application
app = FastAPI(
    title="ContinuumAgent API",
    description="API for the ContinuumAgent knowledge patching system",
    version="0.1.0",
)

# Add CORS middleware for Hugging Face Spaces
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# the fully-permissive combination — acceptable for a public demo Space, but
# revisit before any production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables
start_time = time.time()  # process start time, used for uptime reporting
request_count = 0  # incremented once per /generate request
continuum_router = None  # set by startup_event once the model loads
model_load_error = None  # human-readable load-failure text, if any

@app.on_event("startup")
async def startup_event():
    """Locate a GGUF model on disk and initialize the global router.

    On any failure the error is recorded in ``model_load_error`` instead of
    raising, so the API still starts and can report its status via GET /.
    """
    global continuum_router, model_load_error

    # Find model path
    model_dir = "models/slow"
    os.makedirs(model_dir, exist_ok=True)
    # Sort so the chosen model is deterministic across restarts;
    # os.listdir() returns entries in arbitrary order.
    model_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".gguf"))

    if not model_files:
        model_load_error = "No GGUF models found. Please run download_model.py first."
        print(f"Error: {model_load_error}")
        return

    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    # Get GPU layers setting from environment; fall back to CPU-only on a
    # malformed value instead of crashing startup with an unhandled ValueError.
    try:
        n_gpu_layers = int(os.environ.get("N_GPU_LAYERS", "0"))
    except ValueError:
        print("Warning: invalid N_GPU_LAYERS value, defaulting to 0")
        n_gpu_layers = 0

    # Initialize router
    try:
        continuum_router = ContinuumRouter(
            model_path=model_path,
            n_gpu_layers=n_gpu_layers
        )

        # Load patches (can be done in background)
        continuum_router.load_latest_patches()

    except Exception as e:
        error_traceback = traceback.format_exc()
        model_load_error = f"Error initializing router: {str(e)}\n{error_traceback}"
        print(model_load_error)

def get_router():
    """FastAPI dependency: return the initialized router or fail with 503."""
    router = continuum_router
    if router is not None:
        return router
    # Startup has not completed (or failed); surface the recorded reason.
    raise HTTPException(
        status_code=503,
        detail=f"Service not fully initialized: {model_load_error or 'Unknown error'}"
    )

@app.get("/", response_model=StatusResponse)
async def get_status():
    """Report service status, uptime, request count and model details."""
    response = StatusResponse(
        status="initializing" if model_load_error else "running",
        uptime_seconds=time.time() - start_time,
        processed_requests=request_count,
        is_model_loaded=continuum_router is not None,
    )

    # Attach model details when the router is up; failures here are non-fatal.
    if continuum_router:
        try:
            response.model_info = continuum_router.get_model_info()
        except Exception as e:
            print(f"Error getting model info: {e}")

    # Surface only the first line of a (possibly multi-line) load error.
    if model_load_error:
        response.status = f"error: {model_load_error.split(chr(10))[0]}"

    return response

@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, router: ContinuumRouter = Depends(get_router)):
    """Run the router on the supplied prompt and return the generation result."""
    global request_count
    # Counted even when generation fails, matching the status endpoint's
    # "processed" semantics.
    request_count += 1

    try:
        return router.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            auto_route=request.auto_route,
            force_patches=request.force_patches,
        )
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error generating text: {str(e)}\n{tb}"
        )

@app.post("/patches/load")
async def load_patches(date_str: Optional[str] = None,
                       router: ContinuumRouter = Depends(get_router)):
    """Load knowledge patches, optionally restricted to a specific date."""
    try:
        loaded = router.load_patches(date_str)
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error loading patches: {str(e)}\n{tb}"
        )
    return {"status": "success", "loaded_patches": loaded}

@app.get("/patches/list")
async def list_patches(router: ContinuumRouter = Depends(get_router)):
    """Return all patches the router knows about."""
    try:
        available = router.list_patches()
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error listing patches: {str(e)}\n{tb}"
        )
    return {"patches": available}

@app.get("/patches/active")
async def get_active_patches(router: ContinuumRouter = Depends(get_router)):
    """Return the patches currently applied by the router."""
    try:
        current = router.get_active_patches()
    except Exception as e:
        tb = traceback.format_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error getting active patches: {str(e)}\n{tb}"
        )
    return {"active_patches": current}

# Health check endpoint for Hugging Face Spaces
@app.get("/health")
async def health_check():
    """Lightweight liveness probe used by Hugging Face Spaces."""
    loaded = continuum_router is not None
    return {"status": "ok", "model_loaded": loaded}