deasdutta committed on
Commit
19f6105
·
verified ·
1 Parent(s): 61f7450

Upload app\main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app//main.py +222 -0
app//main.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ FastAPI Application for ContinuumAgent Project
4
+ Serves the model with patched knowledge
5
+ Modified for better error handling and compatibility with Hugging Face Spaces
6
+ """
7
+
8
+ import os
9
+ import time
10
+ import traceback
11
+ from typing import Dict, List, Any, Optional
12
+ from fastapi import FastAPI, HTTPException, BackgroundTasks, Query, Path, Depends
13
+ from fastapi.responses import JSONResponse
14
+ from fastapi.middleware.cors import CORSMiddleware
15
+ from pydantic import BaseModel, Field
16
+ from app.router import ContinuumRouter
17
+
18
+ # Define API models
19
+ class GenerateRequest(BaseModel):
20
+ prompt: str = Field(..., description="User input prompt")
21
+ system_prompt: Optional[str] = Field(None, description="Optional system prompt")
22
+ max_tokens: int = Field(256, description="Maximum number of tokens to generate")
23
+ temperature: float = Field(0.7, description="Sampling temperature (0.0-1.0)")
24
+ top_p: float = Field(0.95, description="Top-p sampling parameter (0.0-1.0)")
25
+ auto_route: bool = Field(True, description="Auto-route based on query complexity")
26
+ force_patches: Optional[bool] = Field(None, description="Force usage of patches")
27
+
28
+ class GenerateResponse(BaseModel):
29
+ text: str = Field(..., description="Generated text")
30
+ elapsed_seconds: float = Field(..., description="Elapsed time in seconds")
31
+ used_patches: bool = Field(..., description="Whether patches were used")
32
+ adapter_paths: List[str] = Field(default_factory=list, description="Paths to used adapters")
33
+ total_tokens: int = Field(0, description="Total tokens used")
34
+
35
+ class ModelInfo(BaseModel):
36
+ name: str = Field(..., description="Model name")
37
+ quantization: str = Field(..., description="Quantization format")
38
+ patches: List[Dict[str, Any]] = Field(default_factory=list, description="Available patches")
39
+ using_gpu: bool = Field(False, description="Whether GPU is being used")
40
+
41
+ class StatusResponse(BaseModel):
42
+ status: str = Field(..., description="Service status")
43
+ model_info: Optional[ModelInfo] = Field(None, description="Model information")
44
+ uptime_seconds: float = Field(..., description="Service uptime in seconds")
45
+ processed_requests: int = Field(0, description="Number of processed requests")
46
+ is_model_loaded: bool = Field(False, description="Whether model is successfully loaded")
47
+
48
+ # Create FastAPI application
49
+ app = FastAPI(
50
+ title="ContinuumAgent API",
51
+ description="API for the ContinuumAgent knowledge patching system",
52
+ version="0.1.0",
53
+ )
54
+
55
+ # Add CORS middleware for Hugging Face Spaces
56
+ app.add_middleware(
57
+ CORSMiddleware,
58
+ allow_origins=["*"],
59
+ allow_credentials=True,
60
+ allow_methods=["*"],
61
+ allow_headers=["*"],
62
+ )
63
+
64
+ # Global variables
65
+ start_time = time.time()
66
+ request_count = 0
67
+ continuum_router = None
68
+ model_load_error = None
69
+
70
@app.on_event("startup")
async def startup_event():
    """Locate a GGUF model, initialize the ContinuumRouter, and load patches.

    Runs once at application startup. On failure the error is recorded in
    the module-level ``model_load_error`` so the status endpoints can report
    it; the service still starts so that ``/`` and ``/health`` stay reachable.
    """
    global continuum_router, model_load_error

    # Find model path
    model_dir = "models/slow"
    os.makedirs(model_dir, exist_ok=True)
    model_files = [f for f in os.listdir(model_dir) if f.endswith(".gguf")]

    if not model_files:
        model_load_error = "No GGUF models found. Please run download_model.py first."
        print(f"Error: {model_load_error}")
        return

    # First .gguf found wins; directory is expected to hold a single model.
    model_path = os.path.join(model_dir, model_files[0])
    print(f"Using model: {model_path}")

    # Number of layers to offload to GPU; default 0 keeps everything on CPU.
    n_gpu_layers = int(os.environ.get("N_GPU_LAYERS", "0"))

    # Initialize router
    try:
        continuum_router = ContinuumRouter(
            model_path=model_path,
            n_gpu_layers=n_gpu_layers
        )
    except Exception as e:
        # Construction failed: record the error AND clear the router so a
        # half-initialized instance is never visible to request handlers.
        # (Previously a failure after assignment left continuum_router set
        # while model_load_error reported an error — an inconsistent state.)
        continuum_router = None
        model_load_error = f"Error initializing router: {str(e)}\n{traceback.format_exc()}"
        print(model_load_error)
        return

    # Load patches — best-effort: a patch-load failure should not take down
    # an otherwise working model, so it is logged instead of flagged fatal.
    try:
        continuum_router.load_latest_patches()
    except Exception as e:
        print(f"Warning: failed to load patches: {e}\n{traceback.format_exc()}")
105
+
106
def get_router():
    """Dependency: return the initialized router or fail with HTTP 503.

    Raises HTTPException(503) while the model has not finished loading
    (or failed to load), embedding the recorded startup error if any.
    """
    router = continuum_router
    if router is not None:
        return router
    raise HTTPException(
        status_code=503,
        detail=f"Service not fully initialized: {model_load_error or 'Unknown error'}"
    )
114
+
115
@app.get("/", response_model=StatusResponse)
async def get_status():
    """Report service status: uptime, request count, model/load state."""
    global start_time, request_count, continuum_router, model_load_error

    # Base response; "initializing" is replaced below when an error exists.
    response = StatusResponse(
        status="initializing" if model_load_error else "running",
        uptime_seconds=time.time() - start_time,
        processed_requests=request_count,
        is_model_loaded=continuum_router is not None
    )

    # Attach model metadata when the router is available.
    if continuum_router:
        try:
            response.model_info = continuum_router.get_model_info()
        except Exception as e:
            print(f"Error getting model info: {e}")

    # Surface only the first line of the stored error in the status field.
    if model_load_error:
        first_line, _, _ = model_load_error.partition("\n")
        response.status = f"error: {first_line}"

    return response
141
+
142
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, router: ContinuumRouter = Depends(get_router)):
    """Generate text for the given prompt via the router.

    Increments the global request counter and converts any router failure
    into an HTTP 500 whose detail carries the traceback.
    """
    global request_count
    request_count += 1

    try:
        return router.generate(
            prompt=request.prompt,
            system_prompt=request.system_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            auto_route=request.auto_route,
            force_patches=request.force_patches
        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error generating text: {str(e)}\n{traceback.format_exc()}"
        )
168
+
169
@app.post("/patches/load")
async def load_patches(date_str: Optional[str] = None,
                       router: ContinuumRouter = Depends(get_router)):
    """Load patches for a specific date (or the router's default when None)."""
    try:
        loaded = router.load_patches(date_str)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error loading patches: {str(e)}\n{traceback.format_exc()}"
        )
    return {"status": "success", "loaded_patches": loaded}
185
+
186
@app.get("/patches/list")
async def list_patches(router: ContinuumRouter = Depends(get_router)):
    """Return every patch the router knows about."""
    try:
        available = router.list_patches()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error listing patches: {str(e)}\n{traceback.format_exc()}"
        )
    return {"patches": available}
201
+
202
@app.get("/patches/active")
async def get_active_patches(router: ContinuumRouter = Depends(get_router)):
    """Return the patches currently applied by the router."""
    try:
        current = router.get_active_patches()
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error getting active patches: {str(e)}\n{traceback.format_exc()}"
        )
    return {"active_patches": current}
217
+
218
# Health check endpoint for Hugging Face Spaces
@app.get("/health")
async def health_check():
    """Lightweight liveness probe; reports whether the model has loaded."""
    loaded = continuum_router is not None
    return {"status": "ok", "model_loaded": loaded}