Bc-AI commited on
Commit
b50dda1
·
verified ·
1 Parent(s): 43c1e33

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -0
app.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM-Z-1 Cluster Head Node
3
+ Receives requests and distributes to worker spaces
4
+ """
5
+
6
+ from fastapi import FastAPI, HTTPException, Request
7
+ from fastapi.responses import StreamingResponse
8
+ from pydantic import BaseModel
9
+ import httpx
10
+ import asyncio
11
+ import json
12
+ import time
13
+ from typing import List, Optional
14
+ import random
15
+
16
+ app = FastAPI(title="SAM-Z-1 Cluster API", version="1.0.0")
17
+
18
+ # ============================================================================
19
+ # Configuration
20
+ # ============================================================================
21
+
22
+ # Add your worker space URLs here
23
+ WORKER_URLS = [
24
+ "https://your-username-sam-z1-worker1.hf.space",
25
+ "https://your-username-sam-z1-worker2.hf.space",
26
+ # Add more workers as needed
27
+ ]
28
+
29
+ # Health check interval (seconds)
30
+ HEALTH_CHECK_INTERVAL = 30
31
+
32
+ # Worker health status
33
+ worker_health = {url: {"healthy": True, "last_check": 0} for url in WORKER_URLS}
34
+
35
+ # ============================================================================
36
+ # Request Models
37
+ # ============================================================================
38
+
39
+ class GenerateRequest(BaseModel):
40
+ prompt: str
41
+ max_tokens: int = 512
42
+ temperature: float = 0.8
43
+ top_k: int = 40
44
+ top_p: float = 0.9
45
+ repetition_penalty: float = 1.1
46
+ stream: bool = False
47
+
48
+ class ChatMessage(BaseModel):
49
+ role: str # "user" or "assistant"
50
+ content: str
51
+
52
+ class ChatRequest(BaseModel):
53
+ messages: List[ChatMessage]
54
+ max_tokens: int = 512
55
+ temperature: float = 0.8
56
+ top_k: int = 40
57
+ top_p: float = 0.9
58
+ repetition_penalty: float = 1.1
59
+ stream: bool = False
60
+
61
+ # ============================================================================
62
+ # Load Balancing & Health Checks
63
+ # ============================================================================
64
+
65
+ def get_healthy_workers() -> List[str]:
66
+ """Get list of healthy workers"""
67
+ return [url for url, status in worker_health.items() if status["healthy"]]
68
+
69
+ def select_worker() -> Optional[str]:
70
+ """Select a worker using round-robin on healthy workers"""
71
+ healthy = get_healthy_workers()
72
+ if not healthy:
73
+ return None
74
+ return random.choice(healthy) # You could also implement round-robin here
75
+
76
+ async def check_worker_health(worker_url: str) -> bool:
77
+ """Check if a worker is healthy"""
78
+ try:
79
+ async with httpx.AsyncClient(timeout=5.0) as client:
80
+ response = await client.get(f"{worker_url}/health")
81
+ return response.status_code == 200
82
+ except:
83
+ return False
84
+
85
+ async def health_check_loop():
86
+ """Background task to check worker health"""
87
+ while True:
88
+ for worker_url in WORKER_URLS:
89
+ healthy = await check_worker_health(worker_url)
90
+ worker_health[worker_url]["healthy"] = healthy
91
+ worker_health[worker_url]["last_check"] = time.time()
92
+
93
+ status = "✅" if healthy else "❌"
94
+ print(f"{status} Worker {worker_url}: {'healthy' if healthy else 'unhealthy'}")
95
+
96
+ await asyncio.sleep(HEALTH_CHECK_INTERVAL)
97
+
98
+ @app.on_event("startup")
99
+ async def startup_event():
100
+ """Start health check loop on startup"""
101
+ asyncio.create_task(health_check_loop())
102
+
103
+ # ============================================================================
104
+ # API Endpoints
105
+ # ============================================================================
106
+
107
+ @app.get("/")
108
+ async def root():
109
+ """API info"""
110
+ healthy_count = len(get_healthy_workers())
111
+ return {
112
+ "name": "SAM-Z-1 Cluster API",
113
+ "version": "1.0.0",
114
+ "workers": len(WORKER_URLS),
115
+ "healthy_workers": healthy_count,
116
+ "endpoints": {
117
+ "generate": "/v1/generate",
118
+ "chat": "/v1/chat",
119
+ "health": "/health",
120
+ "workers": "/workers"
121
+ }
122
+ }
123
+
124
+ @app.get("/health")
125
+ async def health():
126
+ """Health check endpoint"""
127
+ healthy_count = len(get_healthy_workers())
128
+ return {
129
+ "status": "healthy" if healthy_count > 0 else "unhealthy",
130
+ "workers_total": len(WORKER_URLS),
131
+ "workers_healthy": healthy_count
132
+ }
133
+
134
+ @app.get("/workers")
135
+ async def workers_status():
136
+ """Get status of all workers"""
137
+ return {
138
+ "workers": [
139
+ {
140
+ "url": url,
141
+ "healthy": status["healthy"],
142
+ "last_check": status["last_check"]
143
+ }
144
+ for url, status in worker_health.items()
145
+ ]
146
+ }
147
+
148
+ @app.post("/v1/generate")
149
+ async def generate(request: GenerateRequest):
150
+ """Generate text from prompt"""
151
+ worker_url = select_worker()
152
+
153
+ if not worker_url:
154
+ raise HTTPException(
155
+ status_code=503,
156
+ detail="No healthy workers available"
157
+ )
158
+
159
+ try:
160
+ async with httpx.AsyncClient(timeout=300.0) as client:
161
+ if request.stream:
162
+ # Streaming response
163
+ async def stream_from_worker():
164
+ async with client.stream(
165
+ "POST",
166
+ f"{worker_url}/generate",
167
+ json=request.dict()
168
+ ) as response:
169
+ async for chunk in response.aiter_text():
170
+ yield chunk
171
+
172
+ return StreamingResponse(
173
+ stream_from_worker(),
174
+ media_type="text/event-stream"
175
+ )
176
+ else:
177
+ # Non-streaming response
178
+ response = await client.post(
179
+ f"{worker_url}/generate",
180
+ json=request.dict()
181
+ )
182
+ return response.json()
183
+
184
+ except httpx.TimeoutException:
185
+ # Mark worker as unhealthy and retry with another
186
+ worker_health[worker_url]["healthy"] = False
187
+ raise HTTPException(
188
+ status_code=504,
189
+ detail="Worker timeout - request failed"
190
+ )
191
+ except Exception as e:
192
+ raise HTTPException(
193
+ status_code=500,
194
+ detail=f"Worker error: {str(e)}"
195
+ )
196
+
197
+ @app.post("/v1/chat")
198
+ async def chat(request: ChatRequest):
199
+ """Chat completion endpoint"""
200
+ worker_url = select_worker()
201
+
202
+ if not worker_url:
203
+ raise HTTPException(
204
+ status_code=503,
205
+ detail="No healthy workers available"
206
+ )
207
+
208
+ try:
209
+ async with httpx.AsyncClient(timeout=300.0) as client:
210
+ if request.stream:
211
+ # Streaming response
212
+ async def stream_from_worker():
213
+ async with client.stream(
214
+ "POST",
215
+ f"{worker_url}/chat",
216
+ json=request.dict()
217
+ ) as response:
218
+ async for chunk in response.aiter_text():
219
+ yield chunk
220
+
221
+ return StreamingResponse(
222
+ stream_from_worker(),
223
+ media_type="text/event-stream"
224
+ )
225
+ else:
226
+ # Non-streaming response
227
+ response = await client.post(
228
+ f"{worker_url}/chat",
229
+ json=request.dict()
230
+ )
231
+ return response.json()
232
+
233
+ except httpx.TimeoutException:
234
+ worker_health[worker_url]["healthy"] = False
235
+ raise HTTPException(
236
+ status_code=504,
237
+ detail="Worker timeout - request failed"
238
+ )
239
+ except Exception as e:
240
+ raise HTTPException(
241
+ status_code=500,
242
+ detail=f"Worker error: {str(e)}"
243
+ )
244
+
245
+ # ============================================================================
246
+ # Launch
247
+ # ============================================================================
248
+
249
+ if __name__ == "__main__":
250
+ import uvicorn
251
+ uvicorn.run(
252
+ app,
253
+ host="0.0.0.0",
254
+ port=7860,
255
+ log_level="info"
256
+ )