nomid2 committed
Commit 93eb401 · verified · 1 Parent(s): 8d2c197

Update app.py

Files changed (1):
  app.py +83 -99
app.py CHANGED
@@ -150,83 +150,58 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
         logger.error(f"Error creating prediction: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}")
 
-def transform_replicate_to_openai_stream(event_data: str, model: str, prediction_id: str) -> str:
-    """Convert a Replicate streaming response into OpenAI format"""
-
-    # Handle the different event formats
-    if event_data.startswith("data: "):
-        content = event_data[6:]  # strip the "data: " prefix
-
-        # Check for the end-of-stream marker
-        if content.strip() in ["[DONE]", ""]:
-            # Send the end marker
-            openai_response = {
-                "id": f"chatcmpl-{prediction_id}",
-                "object": "chat.completion.chunk",
-                "created": int(asyncio.get_event_loop().time()),
-                "model": model,
-                "choices": [{
-                    "index": 0,
-                    "delta": {},
-                    "finish_reason": "stop"
-                }]
-            }
-            return f"data: {json.dumps(openai_response)}\n\ndata: [DONE]\n\n"
-
-        # Try to parse as JSON (for other event types)
-        try:
-            data = json.loads(content)
-            if data.get("event") == "output":
-                openai_response = {
-                    "id": f"chatcmpl-{prediction_id}",
-                    "object": "chat.completion.chunk",
-                    "created": int(asyncio.get_event_loop().time()),
-                    "model": model,
-                    "choices": [{
-                        "index": 0,
-                        "delta": {
-                            "content": data.get("data", "")
-                        },
-                        "finish_reason": None
-                    }]
-                }
-                return f"data: {json.dumps(openai_response)}\n\n"
-            elif data.get("event") == "done":
-                openai_response = {
-                    "id": f"chatcmpl-{prediction_id}",
-                    "object": "chat.completion.chunk",
-                    "created": int(asyncio.get_event_loop().time()),
-                    "model": model,
-                    "choices": [{
-                        "index": 0,
-                        "delta": {},
-                        "finish_reason": "stop"
-                    }]
-                }
-                return f"data: {json.dumps(openai_response)}\n\ndata: [DONE]\n\n"
-        except json.JSONDecodeError:
-            # Not JSON; treat it as plain text content
-            if content.strip():  # only send non-empty content
-                openai_response = {
-                    "id": f"chatcmpl-{prediction_id}",
-                    "object": "chat.completion.chunk",
-                    "created": int(asyncio.get_event_loop().time()),
-                    "model": model,
-                    "choices": [{
-                        "index": 0,
-                        "delta": {
-                            "content": content
-                        },
-                        "finish_reason": None
-                    }]
-                }
-                return f"data: {json.dumps(openai_response)}\n\n"
-
-    # Handle other event types (e.g. "event: output" lines)
-    elif event_data.startswith("event: "):
-        return ""  # ignore event-type lines
-
-    return ""
+class SSEParser:
+    """Parser for Server-Sent Events"""
+    def __init__(self):
+        self.event_type = None
+        self.event_id = None
+        self.data_buffer = []
+
+    def parse_line(self, line: str):
+        """Parse one line in SSE format"""
+        if line.startswith('event: '):
+            self.event_type = line[7:].strip()
+        elif line.startswith('id: '):
+            self.event_id = line[4:].strip()
+        elif line.startswith('data: '):
+            self.data_buffer.append(line[6:])
+        elif line.startswith(': '):
+            # Comment line; ignore
+            pass
+        elif line == '':
+            # A blank line marks the end of an event
+            if self.data_buffer or self.event_type:
+                data = '\n'.join(self.data_buffer)
+                event = {
+                    'event': self.event_type,
+                    'id': self.event_id,
+                    'data': data
+                }
+                # Reset the buffers
+                self.event_type = None
+                self.event_id = None
+                self.data_buffer = []
+                return event
+        return None
+
+def create_openai_chunk(content: str, model: str, prediction_id: str, finish_reason=None):
+    """Build an OpenAI-format streaming response chunk"""
+    chunk = {
+        "id": f"chatcmpl-{prediction_id}",
+        "object": "chat.completion.chunk",
+        "created": int(asyncio.get_event_loop().time()),
+        "model": model,
+        "choices": [{
+            "index": 0,
+            "delta": {},
+            "finish_reason": finish_reason
+        }]
+    }
+
+    if content and not finish_reason:
+        chunk["choices"][0]["delta"]["content"] = content
+
+    return f"data: {json.dumps(chunk)}\n\n"
 
 @app.get("/")
 async def root():
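How the two new pieces fit together: SSEParser buffers `event:`, `id:`, and `data:` lines until a blank line closes the event, and create_openai_chunk wraps each payload as an OpenAI-style SSE frame. A minimal sketch of that flow, assuming app.py imports cleanly without starting the server; the model and prediction ids below are made up:

import asyncio

# Assumes app.py imports without side effects; ids are illustrative only.
from app import SSEParser, create_openai_chunk

REPLICATE_FRAMES = [
    "event: output",
    "data: Hello",
    "",            # blank line closes the event; parse_line returns it here
    "event: done",
    "data: ",
    "",
]

async def demo() -> None:
    parser = SSEParser()
    for raw in REPLICATE_FRAMES:
        event = parser.parse_line(raw)
        if event is None:
            continue
        if event["event"] == "output":
            # -> data: {"id": "chatcmpl-pred-123", ..., "delta": {"content": "Hello"}, ...}
            print(create_openai_chunk(event["data"], "example-model", "pred-123"), end="")
        elif event["event"] == "done":
            print(create_openai_chunk("", "example-model", "pred-123", "stop"), end="")
            print("data: [DONE]\n")

asyncio.run(demo())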
@@ -328,9 +303,9 @@ async def chat_completions(request: Request):
                     "Cache-Control": "no-store"
                 }
 
-                stream_finished = False
+                sse_parser = SSEParser()
 
-                async with session.get(stream_url, headers=headers, timeout=300) as response:
+                async with session.get(stream_url, headers=headers, timeout=120) as response:
                     if response.status != 200:
                         error_text = await response.text()
                         logger.error(f"Stream error: {response.status} - {error_text}")
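Note on the timeout change (300 s to 120 s): passing a bare number to `session.get` sets aiohttp's total request budget. If the phases ever need separate bounds, aiohttp's ClientTimeout supports that; a sketch under that assumption, not part of this commit:

import aiohttp

# Roughly equivalent to timeout=120, but with separate caps for the TCP
# connect and for gaps between reads of the streamed body.
STREAM_TIMEOUT = aiohttp.ClientTimeout(total=120, connect=10, sock_read=60)

async def read_stream(session: aiohttp.ClientSession, stream_url: str, headers: dict):
    async with session.get(stream_url, headers=headers, timeout=STREAM_TIMEOUT) as response:
        async for raw in response.content:
            yield raw.decode("utf-8").rstrip("\r\n")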
@@ -344,32 +319,40 @@ async def chat_completions(request: Request):
                         return
 
                     async for line in response.content:
-                        line = line.decode('utf-8').strip()
-                        if line:
-                            logger.info(f"Received stream line: {line}")
-                            openai_event = transform_replicate_to_openai_stream(line, model, prediction_id)
-                            if openai_event:
-                                yield openai_event
-                                # Check whether this is the end event
-                                if "[DONE]" in openai_event:
-                                    stream_finished = True
-                                    break
+                        line = line.decode('utf-8').rstrip('\r\n')
+
+                        # Skip timeout or error messages
+                        if '408' in line or 'timeout' in line.lower():
+                            logger.info(f"Ignoring timeout message: {line}")
+                            continue
+
+                        # Parse the SSE event
+                        event = sse_parser.parse_line(line)
+                        if event:
+                            event_type = event.get('event')
+                            data = event.get('data', '')
+
+                            logger.info(f"Parsed SSE event: {event_type}, data: {data[:50]}...")
+
+                            if event_type == 'output' and data.strip():
+                                # Output event containing the actual content
+                                yield create_openai_chunk(data, model, prediction_id)
+                            elif event_type == 'done':
+                                # Completion event
+                                logger.info("Stream completed with done event")
+                                yield create_openai_chunk("", model, prediction_id, "stop")
+                                yield "data: [DONE]\n\n"
+                                return
 
-                    # If the stream did not end normally, send the end marker
-                    if not stream_finished:
-                        final_response = {
-                            "id": f"chatcmpl-{prediction_id}",
-                            "object": "chat.completion.chunk",
-                            "created": int(asyncio.get_event_loop().time()),
-                            "model": model,
-                            "choices": [{
-                                "index": 0,
-                                "delta": {},
-                                "finish_reason": "stop"
-                            }]
-                        }
-                        yield f"data: {json.dumps(final_response)}\n\ndata: [DONE]\n\n"
+                    # If no done event was received, send the completion manually
+                    logger.info("Stream ended without done event, sending manual completion")
+                    yield create_openai_chunk("", model, prediction_id, "stop")
+                    yield "data: [DONE]\n\n"
 
+        except asyncio.TimeoutError:
+            logger.error("Stream timeout")
+            yield create_openai_chunk("", model, prediction_id or "unknown", "stop")
+            yield "data: [DONE]\n\n"
         except Exception as e:
             logger.error(f"Stream generation error: {e}")
             error_response = {
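One caveat in the loop above: `'timeout' in line.lower()` drops any line that merely contains the word "timeout", including legitimate model output. A narrower filter (a sketch, not part of this commit) would match only transport-style error frames such as "408 Request Timeout":

import re

# Matches lines shaped like HTTP error frames ("408 Request Timeout"),
# not ordinary content that happens to mention a timeout.
_TRANSPORT_ERROR = re.compile(r"^\s*(?:HTTP/\d\.\d\s+)?4\d\d\b.*timeout", re.IGNORECASE)

def is_transport_error(line: str) -> bool:
    return bool(_TRANSPORT_ERROR.match(line))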
@@ -387,6 +370,7 @@ async def chat_completions(request: Request):
                 "Cache-Control": "no-cache",
                 "Connection": "keep-alive",
                 "Access-Control-Allow-Origin": "*",
+                "X-Accel-Buffering": "no",  # disable Nginx buffering
             }
         )
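For context, `X-Accel-Buffering: no` asks an Nginx reverse proxy not to buffer the response, so SSE chunks reach the client as soon as they are yielded. The surrounding response construction presumably looks like this (a reconstruction; the generator name and media type are assumed):

from fastapi.responses import StreamingResponse

def sse_response(generate):
    # Header set mirrors the hunk above; "text/event-stream" is assumed.
    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "X-Accel-Buffering": "no",  # disable Nginx buffering
        },
    )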
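End to end, the endpoint should now be consumable with the standard OpenAI client; a hedged smoke test, assuming the app serves /v1/chat/completions on localhost:8000 (path and port assumed) and takes a Replicate model id:

from openai import OpenAI

# base_url, port, and model id are assumptions for illustration only.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="meta/meta-llama-3-8b-instruct",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)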