nomid2 committed on
Commit
fcdaffb
·
verified ·
1 Parent(s): e4c9bed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -68
app.py CHANGED
@@ -39,7 +39,7 @@ if not REPLICATE_API_TOKEN:
39
 
40
  # Replicate API配置
41
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
42
- DEFAULT_MODEL = "anthropic/claude-3.5-sonnet" # 使用实际存在的模型
43
 
44
  # 全局异常处理器
45
  @app.exception_handler(Exception)
@@ -85,22 +85,21 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
85
  if prompt_parts and not prompt.endswith("\n\nAssistant:"):
86
  prompt += "\n\nAssistant:"
87
 
88
- # 确定使用的模型 - 使用正确的 Replicate 模型名称
89
  model = model_override or openai_request.get("model", DEFAULT_MODEL)
90
 
91
- # 正确的模型名称映射(基于搜索结果)
92
  model_mapping = {
93
- "claude-4-sonnet": "anthropic/claude-4-sonnet", # 最新的 Claude 4
94
- "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet", # Claude 3.5 Sonnet
95
- "claude-3-sonnet": "anthropic/claude-3-sonnet", # Claude 3 Sonnet
96
- "claude-3.5-haiku": "anthropic/claude-3.5-haiku", # Claude 3.5 Haiku
97
- "claude-3-haiku": "anthropic/claude-3-haiku", # Claude 3 Haiku
98
  }
99
 
100
  if model in model_mapping:
101
  model = model_mapping[model]
102
  elif not model.startswith("anthropic/"):
103
- # 默认使用 claude-3.5-sonnet
104
  model = "anthropic/claude-3.5-sonnet"
105
 
106
  replicate_request = {
@@ -130,12 +129,10 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
130
  }
131
 
132
  logger.info(f"Creating prediction for model: {model}")
133
- logger.info(f"Request URL: {url}")
134
 
135
  async with session.post(url, headers=headers, json=data, timeout=30) as response:
136
  response_text = await response.text()
137
  logger.info(f"Replicate response status: {response.status}")
138
- logger.info(f"Replicate response: {response_text}")
139
 
140
  if response.status != 201:
141
  logger.error(f"Replicate API error: {response.status} - {response_text}")
@@ -153,31 +150,6 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
153
  logger.error(f"Error creating prediction: {str(e)}")
154
  raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}")
155
 
156
- async def stream_replicate_response(session: aiohttp.ClientSession, stream_url: str) -> AsyncGenerator[str, None]:
157
- """流式读取Replicate响应"""
158
- try:
159
- headers = {
160
- "Accept": "text/event-stream",
161
- "Cache-Control": "no-store"
162
- }
163
-
164
- logger.info(f"Starting stream from: {stream_url}")
165
-
166
- async with session.get(stream_url, headers=headers, timeout=300) as response:
167
- if response.status != 200:
168
- error_text = await response.text()
169
- logger.error(f"Stream error: {response.status} - {error_text}")
170
- raise HTTPException(status_code=response.status, detail=f"Stream error: {error_text}")
171
-
172
- async for line in response.content:
173
- line = line.decode('utf-8').strip()
174
- if line:
175
- yield line
176
-
177
- except Exception as e:
178
- logger.error(f"Stream error: {str(e)}")
179
- raise
180
-
181
  def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
182
  """将Replicate流式响应转换为OpenAI格式"""
183
  if not event_data.startswith("data: "):
@@ -290,32 +262,62 @@ async def chat_completions(request: Request):
290
  try:
291
  body = await request.json()
292
  logger.info(f"Received chat completion request")
293
- logger.info(f"Request body: {json.dumps(body, indent=2)}")
294
 
295
  # 转换请求格式
296
  replicate_data, model = transform_openai_to_replicate(body)
297
 
298
- async with aiohttp.ClientSession() as session:
299
- # 创建预测
300
- prediction = await create_replicate_prediction(session, model, replicate_data)
301
- prediction_id = prediction.get('id')
302
- logger.info(f"Created prediction: {prediction_id}")
303
-
304
- if body.get("stream", False):
305
- # 流式响应
306
- stream_url = prediction.get("urls", {}).get("stream")
307
- if not stream_url:
308
- raise HTTPException(status_code=500, detail="Stream URL not available")
309
-
310
- async def generate_stream():
311
  try:
312
- async for event in stream_replicate_response(session, stream_url):
313
- openai_event = transform_replicate_to_openai_stream(event, model)
314
- if openai_event:
315
- yield openai_event
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  except Exception as e:
317
  logger.error(f"Stream generation error: {e}")
318
- # 发送错误响应
319
  error_response = {
320
  "error": {
321
  "message": str(e),
@@ -323,23 +325,29 @@ async def chat_completions(request: Request):
323
  }
324
  }
325
  yield f"data: {json.dumps(error_response)}\n\n"
326
-
327
- return StreamingResponse(
328
- generate_stream(),
329
- media_type="text/event-stream",
330
- headers={
331
- "Cache-Control": "no-cache",
332
- "Connection": "keep-alive",
333
- "Access-Control-Allow-Origin": "*",
334
- }
335
- )
336
 
337
- else:
338
- # 非流式响应 - 等待预测完成
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
340
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}
341
 
342
- # 轮询等待结果
343
  max_attempts = 60 # 最多等待60秒
344
  attempt = 0
345
 
 
39
 
40
  # Replicate API配置
41
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
42
+ DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
43
 
44
  # 全局异常处理器
45
  @app.exception_handler(Exception)
 
85
  if prompt_parts and not prompt.endswith("\n\nAssistant:"):
86
  prompt += "\n\nAssistant:"
87
 
88
+ # 确定使用的模型
89
  model = model_override or openai_request.get("model", DEFAULT_MODEL)
90
 
91
+ # 正确的模型名称映射
92
  model_mapping = {
93
+ "claude-4-sonnet": "anthropic/claude-4-sonnet",
94
+ "claude-3.5-sonnet": "anthropic/claude-3.5-sonnet",
95
+ "claude-3-sonnet": "anthropic/claude-3-sonnet",
96
+ "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
97
+ "claude-3-haiku": "anthropic/claude-3-haiku",
98
  }
99
 
100
  if model in model_mapping:
101
  model = model_mapping[model]
102
  elif not model.startswith("anthropic/"):
 
103
  model = "anthropic/claude-3.5-sonnet"
104
 
105
  replicate_request = {
 
129
  }
130
 
131
  logger.info(f"Creating prediction for model: {model}")
 
132
 
133
  async with session.post(url, headers=headers, json=data, timeout=30) as response:
134
  response_text = await response.text()
135
  logger.info(f"Replicate response status: {response.status}")
 
136
 
137
  if response.status != 201:
138
  logger.error(f"Replicate API error: {response.status} - {response_text}")
 
150
  logger.error(f"Error creating prediction: {str(e)}")
151
  raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}")
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
154
  """将Replicate流式响应转换为OpenAI格式"""
155
  if not event_data.startswith("data: "):
 
262
  try:
263
  body = await request.json()
264
  logger.info(f"Received chat completion request")
 
265
 
266
  # 转换请求格式
267
  replicate_data, model = transform_openai_to_replicate(body)
268
 
269
+ if body.get("stream", False):
270
+ # 流式响应 - 修复会话管理问题
271
+ async def generate_stream():
272
+ async with aiohttp.ClientSession() as session:
 
 
 
 
 
 
 
 
 
273
  try:
274
+ # 创建预测
275
+ prediction = await create_replicate_prediction(session, model, replicate_data)
276
+ prediction_id = prediction.get('id')
277
+ logger.info(f"Created prediction: {prediction_id}")
278
+
279
+ # 获取流式URL
280
+ stream_url = prediction.get("urls", {}).get("stream")
281
+ if not stream_url:
282
+ error_response = {
283
+ "error": {
284
+ "message": "Stream URL not available",
285
+ "type": "stream_error"
286
+ }
287
+ }
288
+ yield f"data: {json.dumps(error_response)}\n\n"
289
+ return
290
+
291
+ logger.info(f"Starting stream from: {stream_url}")
292
+
293
+ # 流式读取响应
294
+ headers = {
295
+ "Accept": "text/event-stream",
296
+ "Cache-Control": "no-store"
297
+ }
298
+
299
+ async with session.get(stream_url, headers=headers, timeout=300) as response:
300
+ if response.status != 200:
301
+ error_text = await response.text()
302
+ logger.error(f"Stream error: {response.status} - {error_text}")
303
+ error_response = {
304
+ "error": {
305
+ "message": f"Stream error: {error_text}",
306
+ "type": "stream_error"
307
+ }
308
+ }
309
+ yield f"data: {json.dumps(error_response)}\n\n"
310
+ return
311
+
312
+ async for line in response.content:
313
+ line = line.decode('utf-8').strip()
314
+ if line:
315
+ openai_event = transform_replicate_to_openai_stream(line, model)
316
+ if openai_event:
317
+ yield openai_event
318
+
319
  except Exception as e:
320
  logger.error(f"Stream generation error: {e}")
 
321
  error_response = {
322
  "error": {
323
  "message": str(e),
 
325
  }
326
  }
327
  yield f"data: {json.dumps(error_response)}\n\n"
 
 
 
 
 
 
 
 
 
 
328
 
329
+ return StreamingResponse(
330
+ generate_stream(),
331
+ media_type="text/event-stream",
332
+ headers={
333
+ "Cache-Control": "no-cache",
334
+ "Connection": "keep-alive",
335
+ "Access-Control-Allow-Origin": "*",
336
+ }
337
+ )
338
+
339
+ else:
340
+ # 非流式响应
341
+ async with aiohttp.ClientSession() as session:
342
+ # 创建预测
343
+ prediction = await create_replicate_prediction(session, model, replicate_data)
344
+ prediction_id = prediction.get('id')
345
+ logger.info(f"Created prediction: {prediction_id}")
346
+
347
+ # 轮询等待结果
348
  prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
349
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}
350
 
 
351
  max_attempts = 60 # 最多等待60秒
352
  attempt = 0
353