adowu commited on
Commit
13045e2
·
verified ·
1 Parent(s): accec81

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +50 -55
main.py CHANGED
@@ -30,9 +30,9 @@ DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
30
  DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
31
  DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "32000"))
32
 
33
- REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))
34
  MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
35
- RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.5"))
36
 
37
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
38
  log = logging.getLogger(__name__)
@@ -107,7 +107,7 @@ async def lifespan(app: FastAPI):
107
 
108
  app = FastAPI(
109
  title="FHR",
110
- version="4.0.0",
111
  lifespan=lifespan,
112
  )
113
 
@@ -123,7 +123,6 @@ app.add_middleware(
123
  # Utilities
124
  # ---------------------------------------------------------------------------
125
 
126
-
127
  def _content_str(m: Message) -> str:
128
  if isinstance(m.content, str):
129
  return m.content
@@ -160,24 +159,14 @@ def _build_prompt(messages: list[Message]) -> str:
160
  # ---------------------------------------------------------------------------
161
 
162
  def _extract_text(result: Any) -> str:
163
- """
164
- Robust extraction of assistant text from Gradio result.
165
- Works with:
166
- - tuple
167
- - result.data
168
- - dict["value"]
169
- - direct list
170
- """
171
-
172
  if hasattr(result, "data"):
173
  result = result.data
174
 
175
  if isinstance(result, tuple):
176
  result = list(result)
177
 
178
- if isinstance(result, dict):
179
- if "value" in result:
180
- result = result["value"]
181
 
182
  if isinstance(result, list) and result:
183
  last = result[-1]
@@ -197,7 +186,7 @@ def _extract_text(result: Any) -> str:
197
  if isinstance(result, str):
198
  return result.strip()
199
 
200
- raise ValueError(f"Cannot extract text from result: {type(result)}")
201
 
202
 
203
  # ---------------------------------------------------------------------------
@@ -254,47 +243,56 @@ async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
254
 
255
 
256
  # ---------------------------------------------------------------------------
257
- # Real Streaming (if Space supports /stream)
258
  # ---------------------------------------------------------------------------
259
 
260
- async def _stream_real(prompt: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
261
- client = await get_client()
262
-
263
- settings = {
264
- "model": req.model,
265
- "temperature": req.temperature,
266
- "max_new_tokens": req.max_tokens,
267
- "top_p": req.top_p,
268
- }
269
 
270
- await asyncio.to_thread(client.predict, api_name="/new_chat")
271
-
272
- stream = await asyncio.to_thread(
273
- client.submit,
274
- input_value=prompt,
275
- settings_form_value=settings,
276
- api_name="/add_message",
277
- )
278
 
279
  cid = f"chatcmpl-{uuid.uuid4().hex}"
280
  created = int(time.time())
281
 
282
- async for update in stream:
283
- text = _extract_text(update)
284
- chunk = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  "id": cid,
286
  "object": "chat.completion.chunk",
287
  "created": created,
288
  "model": req.model,
289
  "choices": [{
290
  "index": 0,
291
- "delta": {"content": text},
292
- "finish_reason": None,
293
  }],
294
  }
295
- yield f"data: {json.dumps(chunk)}\n\n"
296
 
297
- yield "data: [DONE]\n\n"
 
 
 
 
 
298
 
299
 
300
  # ---------------------------------------------------------------------------
@@ -333,23 +331,20 @@ async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_
333
 
334
  try:
335
  if req.stream:
336
- try:
337
- return StreamingResponse(
338
- _stream_real(prompt, req),
339
- media_type="text/event-stream",
340
- )
341
- except Exception:
342
- log.warning("Real streaming failed, fallback to buffered.")
343
- text = await _call_falcon(prompt, req)
344
- return StreamingResponse(
345
- _fake_stream(text, req),
346
- media_type="text/event-stream",
347
- )
348
 
349
  text = await _call_falcon(prompt, req)
350
  return JSONResponse(content=_make_response(text, req))
351
 
352
- except Exception as e:
353
  log.exception("Final failure after retries.")
354
  raise HTTPException(
355
  status_code=502,
 
30
  DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
31
  DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "32000"))
32
 
33
+ REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "180"))
34
  MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
35
+ RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.7"))
36
 
37
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
38
  log = logging.getLogger(__name__)
 
107
 
108
  app = FastAPI(
109
  title="FHR",
110
+ version="4.1.0",
111
  lifespan=lifespan,
112
  )
113
 
 
123
  # Utilities
124
  # ---------------------------------------------------------------------------
125
 
 
126
  def _content_str(m: Message) -> str:
127
  if isinstance(m.content, str):
128
  return m.content
 
159
  # ---------------------------------------------------------------------------
160
 
161
  def _extract_text(result: Any) -> str:
 
 
 
 
 
 
 
 
 
162
  if hasattr(result, "data"):
163
  result = result.data
164
 
165
  if isinstance(result, tuple):
166
  result = list(result)
167
 
168
+ if isinstance(result, dict) and "value" in result:
169
+ result = result["value"]
 
170
 
171
  if isinstance(result, list) and result:
172
  last = result[-1]
 
186
  if isinstance(result, str):
187
  return result.strip()
188
 
189
+ raise ValueError("Unable to extract model response.")
190
 
191
 
192
  # ---------------------------------------------------------------------------
 
243
 
244
 
245
  # ---------------------------------------------------------------------------
246
+ # SAFE STREAMING (HF Spaces stable)
247
  # ---------------------------------------------------------------------------
248
 
249
+ async def _safe_stream(prompt: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
250
+ """
251
+ Stable streaming for HF Spaces:
252
+ 1. Generate full response with retries
253
+ 2. Stream chunks safely
254
+ """
 
 
 
255
 
256
+ text = await _call_falcon(prompt, req)
 
 
 
 
 
 
 
257
 
258
  cid = f"chatcmpl-{uuid.uuid4().hex}"
259
  created = int(time.time())
260
 
261
+ try:
262
+ for i in range(0, len(text), 16):
263
+ chunk = {
264
+ "id": cid,
265
+ "object": "chat.completion.chunk",
266
+ "created": created,
267
+ "model": req.model,
268
+ "choices": [{
269
+ "index": 0,
270
+ "delta": {"content": text[i:i+16]},
271
+ "finish_reason": None,
272
+ }],
273
+ }
274
+
275
+ yield f"data: {json.dumps(chunk)}\n\n"
276
+ await asyncio.sleep(0.02)
277
+
278
+ final = {
279
  "id": cid,
280
  "object": "chat.completion.chunk",
281
  "created": created,
282
  "model": req.model,
283
  "choices": [{
284
  "index": 0,
285
+ "delta": {},
286
+ "finish_reason": "stop",
287
  }],
288
  }
 
289
 
290
+ yield f"data: {json.dumps(final)}\n\n"
291
+ yield "data: [DONE]\n\n"
292
+
293
+ except Exception:
294
+ log.exception("Streaming crashed unexpectedly.")
295
+ yield "data: [DONE]\n\n"
296
 
297
 
298
  # ---------------------------------------------------------------------------
 
331
 
332
  try:
333
  if req.stream:
334
+ return StreamingResponse(
335
+ _safe_stream(prompt, req),
336
+ media_type="text/event-stream",
337
+ headers={
338
+ "Cache-Control": "no-cache",
339
+ "Connection": "keep-alive",
340
+ "X-Accel-Buffering": "no",
341
+ },
342
+ )
 
 
 
343
 
344
  text = await _call_falcon(prompt, req)
345
  return JSONResponse(content=_make_response(text, req))
346
 
347
+ except Exception:
348
  log.exception("Final failure after retries.")
349
  raise HTTPException(
350
  status_code=502,