adowu committed on
Commit
eda6854
·
verified ·
1 Parent(s): 13045e2

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +101 -108
main.py CHANGED
@@ -30,9 +30,9 @@ DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
30
  DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
31
  DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "32000"))
32
 
33
- REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "180"))
34
  MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
35
- RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.7"))
36
 
37
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
38
  log = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ async def get_client() -> Client:
49
  if _client is None:
50
  log.info("Connecting to %s", HF_SPACE_URL)
51
  _client = await asyncio.to_thread(Client, HF_SPACE_URL)
52
- log.info("Connected to Space.")
53
  return _client
54
 
55
 
@@ -106,8 +106,8 @@ async def lifespan(app: FastAPI):
106
  # ---------------------------------------------------------------------------
107
 
108
  app = FastAPI(
109
- title="FHR",
110
- version="4.1.0",
111
  lifespan=lifespan,
112
  )
113
 
@@ -124,17 +124,27 @@ app.add_middleware(
124
  # ---------------------------------------------------------------------------
125
 
126
  def _content_str(m: Message) -> str:
 
 
 
 
127
  if isinstance(m.content, str):
128
  return m.content
129
- return "".join(
130
- p.get("text", "") or p.get("content", "")
131
- for p in m.content
132
- if isinstance(p, dict)
133
- )
 
 
134
 
135
 
136
  def _build_prompt(messages: list[Message]) -> str:
 
 
 
137
  system, parts = [], []
 
138
  for m in messages:
139
  c = _content_str(m).strip()
140
  if not c:
@@ -144,77 +154,83 @@ def _build_prompt(messages: list[Message]) -> str:
144
  system.append(c)
145
  elif m.role == "assistant":
146
  parts.append(f"[ASSISTANT]\n{c}")
147
- else:
148
  parts.append(c)
149
 
150
- prefix = ""
151
- if system:
152
- prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n\n"
153
-
154
- return prefix + "\n\n".join(parts)
155
 
156
 
157
  # ---------------------------------------------------------------------------
158
- # Robust Extraction
159
  # ---------------------------------------------------------------------------
160
 
161
  def _extract_text(result: Any) -> str:
162
- if hasattr(result, "data"):
163
- result = result.data
164
-
165
  if isinstance(result, tuple):
166
- result = list(result)
 
 
 
 
167
 
168
- if isinstance(result, dict) and "value" in result:
169
- result = result["value"]
170
 
171
- if isinstance(result, list) and result:
172
- last = result[-1]
 
 
 
 
 
 
173
 
174
- if isinstance(last, dict):
175
- if "content" in last:
176
- return str(last["content"]).strip()
177
- if "value" in last:
178
- return str(last["value"]).strip()
179
 
180
- if isinstance(last, (list, tuple)) and len(last) >= 2:
181
- return str(last[1]).strip()
182
 
183
- if isinstance(last, str):
184
- return last.strip()
 
 
 
 
185
 
186
- if isinstance(result, str):
187
- return result.strip()
 
 
 
 
188
 
189
- raise ValueError("Unable to extract model response.")
190
 
191
 
192
  # ---------------------------------------------------------------------------
193
- # Retry Wrapper
194
  # ---------------------------------------------------------------------------
195
 
196
- async def _call_with_retries(func, *args, **kwargs):
 
 
197
  for attempt in range(1, MAX_RETRIES + 1):
198
  try:
199
- return await asyncio.wait_for(func(*args, **kwargs), timeout=REQUEST_TIMEOUT)
 
 
 
200
  except Exception as e:
201
- if attempt >= MAX_RETRIES:
202
- log.error("All retries failed.")
203
- raise
204
 
205
  delay = RETRY_BASE_DELAY ** attempt
206
- log.warning(
207
- "Attempt %d failed: %s | retrying in %.2fs",
208
- attempt,
209
- str(e),
210
- delay,
211
- )
212
  await asyncio.sleep(delay)
213
 
 
214
 
215
- # ---------------------------------------------------------------------------
216
- # Falcon Call
217
- # ---------------------------------------------------------------------------
218
 
219
  async def _call_falcon_once(prompt: str, req: ChatCompletionRequest) -> str:
220
  client = await get_client()
@@ -238,65 +254,42 @@ async def _call_falcon_once(prompt: str, req: ChatCompletionRequest) -> str:
238
  return _extract_text(result)
239
 
240
 
241
- async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
242
- return await _call_with_retries(_call_falcon_once, prompt, req)
243
-
244
-
245
  # ---------------------------------------------------------------------------
246
- # SAFE STREAMING (HF Spaces stable)
247
  # ---------------------------------------------------------------------------
248
 
249
- async def _safe_stream(prompt: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
250
- """
251
- Stable streaming for HF Spaces:
252
- 1. Generate full response with retries
253
- 2. Stream chunks safely
254
- """
255
-
256
- text = await _call_falcon(prompt, req)
257
-
258
  cid = f"chatcmpl-{uuid.uuid4().hex}"
259
  created = int(time.time())
260
 
261
- try:
262
- for i in range(0, len(text), 16):
263
- chunk = {
264
- "id": cid,
265
- "object": "chat.completion.chunk",
266
- "created": created,
267
- "model": req.model,
268
- "choices": [{
269
- "index": 0,
270
- "delta": {"content": text[i:i+16]},
271
- "finish_reason": None,
272
- }],
273
- }
274
-
275
- yield f"data: {json.dumps(chunk)}\n\n"
276
- await asyncio.sleep(0.02)
277
-
278
- final = {
279
  "id": cid,
280
  "object": "chat.completion.chunk",
281
  "created": created,
282
  "model": req.model,
283
  "choices": [{
284
  "index": 0,
285
- "delta": {},
286
- "finish_reason": "stop",
287
  }],
288
  }
 
 
289
 
290
- yield f"data: {json.dumps(final)}\n\n"
291
- yield "data: [DONE]\n\n"
 
 
 
 
 
292
 
293
- except Exception:
294
- log.exception("Streaming crashed unexpectedly.")
295
- yield "data: [DONE]\n\n"
296
 
297
 
298
  # ---------------------------------------------------------------------------
299
- # OpenAI Response Builder
300
  # ---------------------------------------------------------------------------
301
 
302
  def _make_response(text: str, req: ChatCompletionRequest) -> dict:
@@ -330,23 +323,23 @@ async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_
330
  prompt = _build_prompt(req.messages)
331
 
332
  try:
333
- if req.stream:
334
- return StreamingResponse(
335
- _safe_stream(prompt, req),
336
- media_type="text/event-stream",
337
- headers={
338
- "Cache-Control": "no-cache",
339
- "Connection": "keep-alive",
340
- "X-Accel-Buffering": "no",
341
- },
342
- )
343
-
344
- text = await _call_falcon(prompt, req)
345
- return JSONResponse(content=_make_response(text, req))
346
-
347
  except Exception:
348
- log.exception("Final failure after retries.")
349
  raise HTTPException(
350
  status_code=502,
351
  detail="Model temporarily unavailable. Please try again.",
352
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
31
  DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "32000"))
32
 
33
+ REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))
34
  MAX_RETRIES = int(os.getenv("MAX_RETRIES", "3"))
35
+ RETRY_BASE_DELAY = float(os.getenv("RETRY_BASE_DELAY", "1.5"))
36
 
37
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
38
  log = logging.getLogger(__name__)
 
49
  if _client is None:
50
  log.info("Connecting to %s", HF_SPACE_URL)
51
  _client = await asyncio.to_thread(Client, HF_SPACE_URL)
52
+ log.info("Connected.")
53
  return _client
54
 
55
 
 
106
  # ---------------------------------------------------------------------------
107
 
108
  app = FastAPI(
109
+ title="Foc",
110
+ version="5.0.0",
111
  lifespan=lifespan,
112
  )
113
 
 
124
  # ---------------------------------------------------------------------------
125
 
126
  def _content_str(m: Message) -> str:
127
+ """
128
+ Extract ONLY text blocks.
129
+ This preserves Dyad compatibility and filters UI noise.
130
+ """
131
  if isinstance(m.content, str):
132
  return m.content
133
+
134
+ text_parts = []
135
+ for p in m.content:
136
+ if isinstance(p, dict) and p.get("type") == "text":
137
+ text_parts.append(p.get("text", "").strip())
138
+
139
+ return "".join(text_parts)
140
 
141
 
142
  def _build_prompt(messages: list[Message]) -> str:
143
+ """
144
+ Preserve original F alignment.
145
+ """
146
  system, parts = [], []
147
+
148
  for m in messages:
149
  c = _content_str(m).strip()
150
  if not c:
 
154
  system.append(c)
155
  elif m.role == "assistant":
156
  parts.append(f"[ASSISTANT]\n{c}")
157
+ elif m.role == "user":
158
  parts.append(c)
159
 
160
+ prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n" if system else ""
161
+ return prefix + "\n".join(parts)
 
 
 
162
 
163
 
164
  # ---------------------------------------------------------------------------
165
+ # Robust extraction
166
  # ---------------------------------------------------------------------------
167
 
168
  def _extract_text(result: Any) -> str:
 
 
 
169
  if isinstance(result, tuple):
170
+ data = result
171
+ elif hasattr(result, "data"):
172
+ data = result.data
173
+ else:
174
+ data = [result]
175
 
176
+ conversation = None
 
177
 
178
+ for item in data:
179
+ if isinstance(item, dict) and "value" in item:
180
+ if isinstance(item["value"], list):
181
+ conversation = item["value"]
182
+ break
183
+ elif isinstance(item, list):
184
+ conversation = item
185
+ break
186
 
187
+ if not conversation:
188
+ raise ValueError("Cannot extract conversation from result")
 
 
 
189
 
190
+ last = conversation[-1]
 
191
 
192
+ if isinstance(last, dict):
193
+ content = last.get("content", "")
194
+ elif isinstance(last, (list, tuple)) and len(last) >= 2:
195
+ content = last[1] or ""
196
+ else:
197
+ content = str(last)
198
 
199
+ if isinstance(content, list):
200
+ parts = []
201
+ for block in content:
202
+ if isinstance(block, dict) and block.get("type") == "text":
203
+ parts.append(block.get("content", block.get("text", "")))
204
+ return "".join(parts).strip()
205
 
206
+ return str(content).strip()
207
 
208
 
209
  # ---------------------------------------------------------------------------
210
+ # Retry wrapper
211
  # ---------------------------------------------------------------------------
212
 
213
async def _call_with_retries(prompt: str, req: ChatCompletionRequest) -> str:
    """
    Invoke the Falcon Space with bounded retries and exponential backoff.

    Each attempt is capped at ``REQUEST_TIMEOUT`` seconds via
    ``asyncio.wait_for``; after ``MAX_RETRIES`` failures the last
    exception is re-raised unchanged.

    Args:
        prompt: Fully built prompt string.
        req:    Original chat-completion request (forwarded as-is).

    Returns:
        The extracted model reply text.
    """
    failure = None

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            return await asyncio.wait_for(
                _call_falcon_once(prompt, req),
                timeout=REQUEST_TIMEOUT,
            )
        except Exception as exc:
            failure = exc
            if attempt == MAX_RETRIES:
                break

            # Exponential backoff: RETRY_BASE_DELAY ** attempt seconds.
            backoff = RETRY_BASE_DELAY ** attempt
            log.warning("Attempt %d failed: %s | retrying in %.2fs",
                        attempt, str(exc), backoff)
            await asyncio.sleep(backoff)

    raise failure
233
 
 
 
 
234
 
235
  async def _call_falcon_once(prompt: str, req: ChatCompletionRequest) -> str:
236
  client = await get_client()
 
254
  return _extract_text(result)
255
 
256
 
 
 
 
 
257
  # ---------------------------------------------------------------------------
258
+ # Streaming (buffered safe streaming)
259
  # ---------------------------------------------------------------------------
260
 
261
+ async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
 
 
 
 
 
 
 
 
262
  cid = f"chatcmpl-{uuid.uuid4().hex}"
263
  created = int(time.time())
264
 
265
+ for i in range(0, len(text), 8):
266
+ chunk = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  "id": cid,
268
  "object": "chat.completion.chunk",
269
  "created": created,
270
  "model": req.model,
271
  "choices": [{
272
  "index": 0,
273
+ "delta": {"content": text[i:i+8]},
274
+ "finish_reason": None,
275
  }],
276
  }
277
+ yield f"data: {json.dumps(chunk)}\n\n"
278
+ await asyncio.sleep(0.01)
279
 
280
+ yield f"data: {json.dumps({
281
+ 'id': cid,
282
+ 'object': 'chat.completion.chunk',
283
+ 'created': created,
284
+ 'model': req.model,
285
+ 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}],
286
+ })}\n\n"
287
 
288
+ yield "data: [DONE]\n\n"
 
 
289
 
290
 
291
  # ---------------------------------------------------------------------------
292
+ # OpenAI response builder
293
  # ---------------------------------------------------------------------------
294
 
295
  def _make_response(text: str, req: ChatCompletionRequest) -> dict:
 
323
  prompt = _build_prompt(req.messages)
324
 
325
  try:
326
+ text = await _call_with_retries(prompt, req)
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  except Exception:
328
+ log.exception("Falcon failed after retries")
329
  raise HTTPException(
330
  status_code=502,
331
  detail="Model temporarily unavailable. Please try again.",
332
+ )
333
+
334
+ if req.stream:
335
+ return StreamingResponse(
336
+ _stream_sse(text, req),
337
+ media_type="text/event-stream",
338
+ headers={
339
+ "Cache-Control": "no-cache",
340
+ "X-Accel-Buffering": "no",
341
+ "Connection": "keep-alive",
342
+ },
343
+ )
344
+
345
+ return JSONResponse(content=_make_response(text, req))