adowu committed
Commit accec81 · verified · 1 Parent(s): f7a796b

Update main.py

Files changed (1)
  1. main.py +206 -204
main.py CHANGED
@@ -1,6 +1,11 @@
  from __future__ import annotations
 
- import os, json, time, uuid, asyncio, logging
+ import os
+ import json
+ import time
+ import uuid
+ import asyncio
+ import logging
  from typing import Any, AsyncGenerator
  from contextlib import asynccontextmanager
 
@@ -16,12 +21,18 @@ load_dotenv()
  # ---------------------------------------------------------------------------
  # Config
  # ---------------------------------------------------------------------------
- API_KEY = os.getenv("API_KEY", "")
- HF_SPACE_URL = os.getenv("HF_SPACE_URL", "")
- MODEL_ID = os.getenv("MODEL_ID", "")
- DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
- DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
- DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "1024"))
+
+ API_KEY = os.getenv("API_KEY", "")
+ HF_SPACE_URL = os.getenv("HF_SPACE_URL", "")
+ MODEL_ID = os.getenv("MODEL_ID", "")
+
+ DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
+ DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
+ DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "32000"))
+
+ REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))
+ MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
+ RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.5"))
 
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
  log = logging.getLogger(__name__)
@@ -29,18 +40,21 @@ log = logging.getLogger(__name__)
  # ---------------------------------------------------------------------------
  # Gradio client (singleton)
  # ---------------------------------------------------------------------------
+
  _client: Client | None = None
 
+
  async def get_client() -> Client:
      global _client
      if _client is None:
          log.info("Connecting to %s", HF_SPACE_URL)
          _client = await asyncio.to_thread(Client, HF_SPACE_URL)
-         log.info("Connected.")
+         log.info("Connected to Space.")
      return _client
 
+
  # ---------------------------------------------------------------------------
- # Pydantic schemas
+ # Schemas
  # ---------------------------------------------------------------------------
 
  class Message(BaseModel):
@@ -48,6 +62,7 @@ class Message(BaseModel):
      content: str | list[dict] = ""
      name: str | None = None
 
+
  class ChatCompletionRequest(BaseModel):
      model: str = MODEL_ID
      messages: list[Message]
@@ -61,6 +76,7 @@ class ChatCompletionRequest(BaseModel):
      seed: int | None = None
      user: str | None = None
 
+
  # ---------------------------------------------------------------------------
  # Auth
  # ---------------------------------------------------------------------------
@@ -72,27 +88,26 @@ async def verify_key(request: Request) -> None:
      if not auth.startswith("Bearer ") or auth[7:] != API_KEY:
          raise HTTPException(status_code=401, detail="Invalid or missing API key")
 
+
  # ---------------------------------------------------------------------------
- # Lifespan context manager (modern FastAPI pattern)
+ # Lifespan
  # ---------------------------------------------------------------------------
 
  @asynccontextmanager
  async def lifespan(app: FastAPI):
-     # Startup
-     log.info("Starting up - connecting to Gradio client...")
+     log.info("Startup: connecting to Gradio client...")
      await get_client()
-     log.info("Startup complete.")
      yield
-     # Shutdown (if needed)
-     log.info("Shutting down.")
+     log.info("Shutdown.")
+
 
  # ---------------------------------------------------------------------------
  # App
  # ---------------------------------------------------------------------------
 
  app = FastAPI(
-     title="Falcon H1R API",
-     version="3.2.0",
+     title="FHR",
+     version="4.0.0",
      lifespan=lifespan,
  )
 
@@ -105,177 +120,167 @@ app.add_middleware(
  )
 
  # ---------------------------------------------------------------------------
- # Business logic
+ # Utilities
  # ---------------------------------------------------------------------------
 
+
  def _content_str(m: Message) -> str:
      if isinstance(m.content, str):
          return m.content
-     return "".join(p.get("text", "") for p in m.content if p.get("type") == "text")
+     return "".join(
+         p.get("text", "") or p.get("content", "")
+         for p in m.content
+         if isinstance(p, dict)
+     )
+
 
  def _build_prompt(messages: list[Message]) -> str:
-     """Flatten messages into a single prompt string."""
      system, parts = [], []
      for m in messages:
-         c = _content_str(m)
-         if m.role == "system": system.append(c)
-         elif m.role == "user": parts.append(c)
-         elif m.role == "assistant": parts.append(f"[ASSISTANT]\n{c}")
-     prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n" if system else ""
-     return prefix + "\n".join(parts)
+         c = _content_str(m).strip()
+         if not c:
+             continue
+
+         if m.role == "system":
+             system.append(c)
+         elif m.role == "assistant":
+             parts.append(f"[ASSISTANT]\n{c}")
+         else:
+             parts.append(c)
+
+     prefix = ""
+     if system:
+         prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n\n"
+
+     return prefix + "\n\n".join(parts)
+
+
+ # ---------------------------------------------------------------------------
+ # Robust Extraction
+ # ---------------------------------------------------------------------------
 
  def _extract_text(result: Any) -> str:
      """
-     Extract assistant reply from Gradio client.predict() result.
-
-     gradio_client returns either:
-     - tuple: (output1, output2, ...) where one element is the chatbot data
-     - object with .data attribute containing a list
-
-     We need to find the conversation list and extract the last message.
+     Robust extraction of assistant text from Gradio result.
+     Works with:
+     - tuple
+     - result.data
+     - dict["value"]
+     - direct list
      """
-     try:
-         # Handle both tuple and object with .data
-         if isinstance(result, tuple):
-             data = result
-         elif hasattr(result, 'data'):
-             data = result.data
-         else:
-             data = [result]
-
-         log.info("Raw result type: %s, length: %s", type(data).__name__, len(data) if hasattr(data, '__len__') else 'N/A')
-
-         # Search through all returned values for the conversation
-         conversation = None
-         for idx, item in enumerate(data):
-             log.debug("Item %d type: %s", idx, type(item).__name__)
-
-             # Check if this item is a dict with 'value' key (chatbot component)
-             if isinstance(item, dict) and "value" in item:
-                 val = item["value"]
-                 if isinstance(val, list) and val:
-                     conversation = val
-                     log.info("Found conversation in dict at index %d, length: %d", idx, len(val))
-                     break
-
-             # Check if item itself is a list (direct conversation)
-             elif isinstance(item, list) and item:
-                 # Verify it looks like a conversation (list of message dicts/tuples)
-                 first = item[0]
-                 if isinstance(first, (dict, list, tuple)):
-                     conversation = item
-                     log.info("Found conversation as list at index %d, length: %d", idx, len(item))
-                     break
-
-         if conversation is None:
-             raise ValueError(f"Cannot find conversation in result. Data structure: {json.dumps(str(data)[:500])}")
-
-         # Extract last message
-         last = conversation[-1]
-         log.info("Last message type: %s, value: %s", type(last).__name__, str(last)[:200])
-
-         # Handle different message formats
-         content = None
-
-         # Format 1: dict with 'content' key (Gradio 4.x)
+
+     if hasattr(result, "data"):
+         result = result.data
+
+     if isinstance(result, tuple):
+         result = list(result)
+
+     if isinstance(result, dict):
+         if "value" in result:
+             result = result["value"]
+
+     if isinstance(result, list) and result:
+         last = result[-1]
+
          if isinstance(last, dict):
-             content = last.get("content", "")
-
-         # Format 2: tuple/list [user_msg, assistant_msg] (Gradio 3.x)
-         elif isinstance(last, (list, tuple)) and len(last) >= 2:
-             content = last[1] or ""
-
-         # Format 3: plain string
-         elif isinstance(last, str):
-             content = last
-
-         if content is None:
-             raise ValueError(f"Cannot extract content from last message: {last}")
-
-         # If content is a list of content blocks, extract text
-         if isinstance(content, list):
-             text_parts = []
-             for block in content:
-                 if isinstance(block, dict):
-                     if block.get("type") == "text":
-                         text_parts.append(block.get("content", block.get("text", "")).strip())
-                 elif isinstance(block, str):
-                     text_parts.append(block.strip())
-             return "".join(text_parts)
-
-         return str(content).strip()
-
-     except Exception as e:
-         log.error("_extract_text failed: %s", e, exc_info=True)
-         log.error("Raw result dump: %s", str(result)[:1000])
-         raise ValueError(f"Failed to extract text: {e}") from e
+             if "content" in last:
+                 return str(last["content"]).strip()
+             if "value" in last:
+                 return str(last["value"]).strip()
 
- async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
-     """
-     Call Falcon H1R via Gradio client.
-     """
+         if isinstance(last, (list, tuple)) and len(last) >= 2:
+             return str(last[1]).strip()
+
+         if isinstance(last, str):
+             return last.strip()
+
+     if isinstance(result, str):
+         return result.strip()
+
+     raise ValueError(f"Cannot extract text from result: {type(result)}")
+
+
+ # ---------------------------------------------------------------------------
+ # Retry Wrapper
+ # ---------------------------------------------------------------------------
+
+ async def _call_with_retries(func, *args, **kwargs):
+     for attempt in range(1, MAX_RETRIES + 1):
+         try:
+             return await asyncio.wait_for(func(*args, **kwargs), timeout=REQUEST_TIMEOUT)
+         except Exception as e:
+             if attempt >= MAX_RETRIES:
+                 log.error("All retries failed.")
+                 raise
+
+             delay = RETRY_BASE_DELAY ** attempt
+             log.warning(
+                 "Attempt %d failed: %s | retrying in %.2fs",
+                 attempt,
+                 str(e),
+                 delay,
+             )
+             await asyncio.sleep(delay)
+
+
+ # ---------------------------------------------------------------------------
+ # Falcon Call
+ # ---------------------------------------------------------------------------
+
+ async def _call_falcon_once(prompt: str, req: ChatCompletionRequest) -> str:
      client = await get_client()
-
+
      settings = {
          "model": req.model,
          "temperature": req.temperature,
          "max_new_tokens": req.max_tokens,
          "top_p": req.top_p,
      }
-
-     # Step 1: Reset chat
-     log.info("Resetting chat session...")
-     await asyncio.to_thread(
-         client.predict,
-         api_name="/new_chat"
-     )
-
-     # Step 2: Send message
-     log.info("Sending message to Falcon...")
+
+     await asyncio.to_thread(client.predict, api_name="/new_chat")
+
      result = await asyncio.to_thread(
          client.predict,
          input_value=prompt,
          settings_form_value=settings,
-         api_name="/add_message"
+         api_name="/add_message",
      )
-
-     log.info("Received result, extracting text...")
+
      return _extract_text(result)
 
- def _make_response(text: str, req: ChatCompletionRequest) -> dict:
-     pt = sum(len(_content_str(m)) for m in req.messages) // 4
-     ct = len(text) // 4
-     return {
-         "id": f"chatcmpl-{uuid.uuid4().hex}",
-         "object": "chat.completion",
-         "created": int(time.time()),
+
+ async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
+     return await _call_with_retries(_call_falcon_once, prompt, req)
+
+
+ # ---------------------------------------------------------------------------
+ # Real Streaming (if Space supports /stream)
+ # ---------------------------------------------------------------------------
+
+ async def _stream_real(prompt: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
+     client = await get_client()
+
+     settings = {
          "model": req.model,
-         "system_fingerprint": f"fp_{uuid.uuid4().hex[:8]}",
-         "choices": [{
-             "index": 0,
-             "message": {
-                 "role": "assistant",
-                 "content": text,
-                 "tool_calls": None,
-                 "function_call": None,
-             },
-             "finish_reason": "stop",
-             "logprobs": None,
-         }],
-         "usage": {
-             "prompt_tokens": pt,
-             "completion_tokens": ct,
-             "total_tokens": pt + ct,
-         },
+         "temperature": req.temperature,
+         "max_new_tokens": req.max_tokens,
+         "top_p": req.top_p,
      }
 
- async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
-     """Simulate streaming by chunking the full response."""
+     await asyncio.to_thread(client.predict, api_name="/new_chat")
+
+     stream = await asyncio.to_thread(
+         client.submit,
+         input_value=prompt,
+         settings_form_value=settings,
+         api_name="/add_message",
+     )
+
      cid = f"chatcmpl-{uuid.uuid4().hex}"
      created = int(time.time())
-
-     # Stream in small chunks
-     for i in range(0, len(text), 6):
+
+     async for update in stream:
+         text = _extract_text(update)
          chunk = {
              "id": cid,
              "object": "chat.completion.chunk",
@@ -283,73 +288,70 @@ async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[s
              "model": req.model,
              "choices": [{
                  "index": 0,
-                 "delta": {"role": "assistant", "content": text[i:i+6]},
+                 "delta": {"content": text},
                  "finish_reason": None,
              }],
          }
          yield f"data: {json.dumps(chunk)}\n\n"
-         await asyncio.sleep(0.01)
-
-     # Final chunk
-     pt = sum(len(_content_str(m)) for m in req.messages) // 4
-     ct = len(text) // 4
-     final = {
-         "id": cid,
-         "object": "chat.completion.chunk",
-         "created": created,
-         "model": req.model,
-         "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
-         "usage": {"prompt_tokens": pt, "completion_tokens": ct, "total_tokens": pt + ct},
-     }
-     yield f"data: {json.dumps(final)}\n\n"
+
      yield "data: [DONE]\n\n"
 
+
  # ---------------------------------------------------------------------------
- # Routes
+ # OpenAI Response Builder
  # ---------------------------------------------------------------------------
 
- @app.get("/")
- async def root():
+ def _make_response(text: str, req: ChatCompletionRequest) -> dict:
+     pt = sum(len(_content_str(m)) for m in req.messages) // 4
+     ct = len(text) // 4
+
      return {
-         "service": "FOC API",
-         "version": "3.2.0",
-         "endpoints": {
-             "health": "/health",
-             "models": "/v1/models",
-             "chat": "/v1/chat/completions",
+         "id": f"chatcmpl-{uuid.uuid4().hex}",
+         "object": "chat.completion",
+         "created": int(time.time()),
+         "model": req.model,
+         "choices": [{
+             "index": 0,
+             "message": {"role": "assistant", "content": text},
+             "finish_reason": "stop",
+         }],
+         "usage": {
+             "prompt_tokens": pt,
+             "completion_tokens": ct,
+             "total_tokens": pt + ct,
          },
      }
 
- @app.get("/health")
- async def health():
-     return {"status": "ok", "model": MODEL_ID, "space": HF_SPACE_URL}
 
- @app.get("/v1/models")
- async def list_models(_: None = Depends(verify_key)):
-     return {"object": "list", "data": [{
-         "id": MODEL_ID,
-         "object": "model",
-         "created": 1710000000,
-         "owned_by": "tiiuae",
-     }]}
+ # ---------------------------------------------------------------------------
+ # Routes
+ # ---------------------------------------------------------------------------
 
  @app.post("/v1/chat/completions")
  async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)):
      prompt = _build_prompt(req.messages)
-     log.info("Request | model=%s temp=%.2f tokens=%d stream=%s",
-              req.model, req.temperature, req.max_tokens, req.stream)
-
+
      try:
+         if req.stream:
+             try:
+                 return StreamingResponse(
+                     _stream_real(prompt, req),
+                     media_type="text/event-stream",
+                 )
+             except Exception:
+                 log.warning("Real streaming failed, fallback to buffered.")
+                 text = await _call_falcon(prompt, req)
+                 return StreamingResponse(
+                     _fake_stream(text, req),
+                     media_type="text/event-stream",
+                 )
+
          text = await _call_falcon(prompt, req)
-     except Exception as exc:
-         log.exception("Falcon call failed")
-         raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc
-
-     if req.stream:
-         return StreamingResponse(
-             _stream_sse(text, req),
-             media_type="text/event-stream",
-             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
-         )
-
-     return JSONResponse(content=_make_response(text, req))
+         return JSONResponse(content=_make_response(text, req))
+
+     except Exception as e:
+         log.exception("Final failure after retries.")
+         raise HTTPException(
+             status_code=502,
+             detail="Model temporarily unavailable. Please try again.",
+         )
 
 
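A few review notes on this revision, with illustrative Python sketches.

The rewritten _build_prompt flattens an OpenAI-style message list into a single tagged prompt string: system messages are collected into a [SYSTEM]...[/SYSTEM] prefix, assistant turns are tagged [ASSISTANT], and everything else passes through. A quick illustration (the message values are made up; assumes the module is importable as main):

    from main import Message, _build_prompt

    msgs = [
        Message(role="system", content="You are terse."),
        Message(role="user", content="Hi"),
        Message(role="assistant", content="Hello."),
        Message(role="user", content="What is 2+2?"),
    ]
    print(_build_prompt(msgs))
    # [SYSTEM]
    # You are terse.
    # [/SYSTEM]
    #
    # Hi
    #
    # [ASSISTANT]
    # Hello.
    #
    # What is 2+2?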
 
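The new _extract_text replaces the old search-based extraction with a chain of shape checks. Some illustrative payloads it handles (the shapes are assumptions modeled on Gradio chat histories, not captured Space output):

    from main import _extract_text

    # Messages-format history: list of role/content dicts (Gradio 4.x style).
    history = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": " Hello there. "},
    ]
    assert _extract_text(history) == "Hello there."

    # Pair-format history: [user_msg, assistant_msg] (Gradio 3.x style).
    assert _extract_text([["Hi", "Hello there."]]) == "Hello there."

    # Bare string output.
    assert _extract_text("Hello there. ") == "Hello there."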
 
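One caveat in _stream_real: gradio_client's Client.submit returns a Job, which is a synchronous iterable, so "async for update in stream" would raise TypeError on the first iteration. Also, because the generator body only runs after StreamingResponse is returned, the route's except clause likely never sees errors raised mid-stream, and the buffered fallback may not trigger. A minimal bridge sketch for the first issue (assumes the Job yields intermediate outputs as the Space streams them):

    import asyncio

    _SENTINEL = object()

    async def iter_job_async(job):
        # Drain the sync Job in a worker thread and hand items to the event loop.
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def drain() -> None:
            try:
                for output in job:  # synchronous iteration over the Job
                    loop.call_soon_threadsafe(queue.put_nowait, output)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, _SENTINEL)

        task = asyncio.create_task(asyncio.to_thread(drain))
        while (item := await queue.get()) is not _SENTINEL:
            yield item
        await task  # surface any exception from the worker thread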
 
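A harder bug: the buffered fallback calls _fake_stream(text, req), but this commit deletes _stream_sse without defining _fake_stream anywhere, so that path would raise NameError at runtime. A minimal sketch reconstructed from the deleted _stream_sse (chunk size and pacing copied from the old code):

    import asyncio, json, time, uuid
    from typing import AsyncGenerator

    async def _fake_stream(text: str, req: "ChatCompletionRequest") -> AsyncGenerator[str, None]:
        # Chunk the buffered reply into SSE deltas so clients still see a stream.
        cid = f"chatcmpl-{uuid.uuid4().hex}"
        created = int(time.time())
        for i in range(0, len(text), 6):
            chunk = {
                "id": cid,
                "object": "chat.completion.chunk",
                "created": created,
                "model": req.model,
                "choices": [{
                    "index": 0,
                    "delta": {"role": "assistant", "content": text[i:i + 6]},
                    "finish_reason": None,
                }],
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            await asyncio.sleep(0.01)
        final = {
            "id": cid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": req.model,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        }
        yield f"data: {json.dumps(final)}\n\n"
        yield "data: [DONE]\n\n"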
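Finally, a client-side smoke test of the OpenAI-compatible endpoint (base URL and key are placeholders; the model field can be omitted because it defaults to MODEL_ID):

    import httpx

    resp = httpx.post(
        "http://localhost:8000/v1/chat/completions",
        headers={"Authorization": "Bearer <API_KEY>"},
        json={
            "messages": [{"role": "user", "content": "Say hello."}],
            "stream": False,
        },
        timeout=120,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])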