nomid2 committed on
Commit
4c6bca9
·
verified ·
1 Parent(s): 8355b22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -90
app.py CHANGED
@@ -2,15 +2,19 @@ import os
2
  import json
3
  import asyncio
4
  import aiohttp
 
5
  from fastapi import FastAPI, Request, HTTPException
6
- from fastapi.responses import StreamingResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
  import uvicorn
9
  from typing import Dict, Any, AsyncGenerator
10
  import logging
11
 
12
- # 配置日志
13
- logging.basicConfig(level=logging.INFO)
 
 
 
14
  logger = logging.getLogger(__name__)
15
 
16
  app = FastAPI(
@@ -31,90 +35,145 @@ app.add_middleware(
31
  # 从环境变量获取配置
32
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
33
  if not REPLICATE_API_TOKEN:
34
- logger.warning("REPLICATE_API_TOKEN not found in environment variables")
35
 
36
  # Replicate API配置
37
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
38
- DEFAULT_MODEL = "anthropic/claude-4-sonnet"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> Dict[str, Any]:
41
  """将OpenAI格式的请求转换为Replicate格式"""
42
- messages = openai_request.get("messages", [])
43
-
44
- # 提取system prompt
45
- system_prompt = ""
46
- user_messages = []
47
-
48
- for message in messages:
49
- if message.get("role") == "system":
50
- system_prompt = message.get("content", "")
51
- elif message.get("role") in ["user", "assistant"]:
52
- user_messages.append(message)
53
-
54
- # 构建prompt
55
- prompt_parts = []
56
- for msg in user_messages:
57
- role = msg.get("role", "")
58
- content = msg.get("content", "")
59
- if role == "user":
60
- prompt_parts.append(f"User: {content}")
61
- elif role == "assistant":
62
- prompt_parts.append(f"Assistant: {content}")
63
-
64
- prompt = "\n\n".join(prompt_parts)
65
- if prompt_parts and not prompt.endswith("\n\nAssistant:"):
66
- prompt += "\n\nAssistant:"
67
-
68
- # 确定使用的模型
69
- model = model_override or openai_request.get("model", DEFAULT_MODEL)
70
- if not model.startswith("anthropic/"):
71
- model = f"anthropic/{model}" if "/" not in model else model
72
-
73
- replicate_request = {
74
- "stream": openai_request.get("stream", False),
75
- "input": {
76
- "prompt": prompt,
77
- "system_prompt": system_prompt or "You are a helpful assistant",
78
- "max_tokens": openai_request.get("max_tokens", 1000),
79
- "temperature": openai_request.get("temperature", 0.7)
80
  }
81
- }
82
-
83
- return replicate_request, model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  async def create_replicate_prediction(session: aiohttp.ClientSession, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
86
  """创建Replicate预测"""
87
- url = f"{REPLICATE_BASE_URL}/models/{model}/predictions"
88
- headers = {
89
- "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
90
- "Content-Type": "application/json"
91
- }
92
-
93
- async with session.post(url, headers=headers, json=data) as response:
94
- if response.status != 201:
95
- error_text = await response.text()
96
- logger.error(f"Replicate API error: {response.status} - {error_text}")
97
- raise HTTPException(status_code=response.status, detail=f"Replicate API error: {error_text}")
98
 
99
- return await response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  async def stream_replicate_response(session: aiohttp.ClientSession, stream_url: str) -> AsyncGenerator[str, None]:
102
  """流式读取Replicate响应"""
103
- headers = {
104
- "Accept": "text/event-stream",
105
- "Cache-Control": "no-store"
106
- }
107
-
108
- async with session.get(stream_url, headers=headers) as response:
109
- if response.status != 200:
110
- error_text = await response.text()
111
- logger.error(f"Stream error: {response.status} - {error_text}")
112
- raise HTTPException(status_code=response.status, detail=f"Stream error: {error_text}")
113
 
114
- async for line in response.content:
115
- line = line.decode('utf-8').strip()
116
- if line:
117
- yield line
 
 
 
 
 
 
 
 
 
 
118
 
119
  def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
120
  """将Replicate流式响应转换为OpenAI格式"""
@@ -158,8 +217,8 @@ def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
158
 
159
  return ""
160
 
161
- except json.JSONDecodeError:
162
- logger.warning(f"Failed to parse event data: {event_data}")
163
  return ""
164
 
165
  @app.get("/")
@@ -168,7 +227,17 @@ async def root():
168
  return {
169
  "message": "Replicate API Proxy for LobeChat",
170
  "status": "running",
171
- "replicate_token_configured": bool(REPLICATE_API_TOKEN)
 
 
 
 
 
 
 
 
 
 
172
  }
173
 
174
  @app.get("/v1/models")
@@ -200,20 +269,22 @@ async def list_models():
200
  async def chat_completions(request: Request):
201
  """处理聊天完成请求(兼容OpenAI API)"""
202
  if not REPLICATE_API_TOKEN:
 
203
  raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not configured")
204
 
205
  try:
206
  body = await request.json()
207
- logger.info(f"Received request: {json.dumps(body, indent=2)}")
 
208
 
209
  # 转换请求格式
210
  replicate_data, model = transform_openai_to_replicate(body)
211
- logger.info(f"Transformed to Replicate format: {json.dumps(replicate_data, indent=2)}")
212
 
213
  async with aiohttp.ClientSession() as session:
214
  # 创建预测
215
  prediction = await create_replicate_prediction(session, model, replicate_data)
216
- logger.info(f"Created prediction: {prediction.get('id')}")
 
217
 
218
  if body.get("stream", False):
219
  # 流式响应
@@ -250,18 +321,26 @@ async def chat_completions(request: Request):
250
 
251
  else:
252
  # 非流式响应 - 等待预测完成
253
- prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction['id']}"
254
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}
255
 
256
  # 轮询等待结果
257
- while True:
 
 
 
258
  async with session.get(prediction_url, headers=headers) as response:
259
  result = await response.json()
 
260
 
261
- if result.get("status") == "succeeded":
262
- content = "".join(result.get("output", []))
 
 
 
 
263
  openai_response = {
264
- "id": f"chatcmpl-{result['id']}",
265
  "object": "chat.completion",
266
  "created": int(asyncio.get_event_loop().time()),
267
  "model": model,
@@ -275,22 +354,34 @@ async def chat_completions(request: Request):
275
  }],
276
  "usage": {
277
  "prompt_tokens": 0,
278
- "completion_tokens": 0,
279
- "total_tokens": 0
280
  }
281
  }
282
  return openai_response
283
 
284
- elif result.get("status") == "failed":
285
- raise HTTPException(status_code=500, detail=f"Prediction failed: {result.get('error')}")
 
 
 
 
 
286
 
287
  # 等待一秒后重试
288
  await asyncio.sleep(1)
 
 
 
289
 
 
 
290
  except Exception as e:
291
- logger.error(f"Error processing request: {e}")
292
- raise HTTPException(status_code=500, detail=str(e))
 
293
 
294
  if __name__ == "__main__":
295
  port = int(os.getenv("PORT", 7860))
296
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
2
  import json
3
  import asyncio
4
  import aiohttp
5
+ import traceback
6
  from fastapi import FastAPI, Request, HTTPException
7
+ from fastapi.responses import StreamingResponse, JSONResponse
8
  from fastapi.middleware.cors import CORSMiddleware
9
  import uvicorn
10
  from typing import Dict, Any, AsyncGenerator
11
  import logging
12
 
13
+ # 配置更详细的日志
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
+ )
18
  logger = logging.getLogger(__name__)
19
 
20
  app = FastAPI(
 
35
  # 从环境变量获取配置
36
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
37
  if not REPLICATE_API_TOKEN:
38
+ logger.error("REPLICATE_API_TOKEN not found in environment variables")
39
 
40
  # Replicate API配置
41
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
42
+ DEFAULT_MODEL = "anthropic/claude-3-5-sonnet"
43
+
44
+ # 全局异常处理器
45
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log any otherwise-unhandled exception and answer with a JSON 500 body.

    Args:
        request: The request being processed when the exception escaped.
        exc: The exception that was not handled anywhere else.

    Returns:
        A ``JSONResponse`` with status 500 whose body is
        ``{"error": {"message": ..., "type": "internal_error"}}``.
    """
    logger.error(f"Global exception: {str(exc)}")
    # Build the traceback from the exception object itself instead of
    # traceback.format_exc(), which only works while an exception is
    # actively being handled in the current context.
    formatted_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    logger.error(f"Traceback: {formatted_traceback}")
    return JSONResponse(
        status_code=500,
        content={
            "error": {
                "message": f"Internal server error: {str(exc)}",
                "type": "internal_error"
            }
        }
    )
58
 
59
def _flatten_message_content(content: Any) -> str:
    """Return the plain text carried by an OpenAI message ``content`` value.

    ``content`` may be a string, ``None``, or a list of content parts
    (strings or ``{"type": "text", "text": ...}`` dicts). Non-text parts
    are dropped because the prompt built here is text-only.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                texts.append(str(part.get("text", "")))
        return "\n".join(texts)
    return str(content)


def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override: str = None) -> "tuple[Dict[str, Any], str]":
    """Convert an OpenAI-format chat request into a Replicate request.

    Args:
        openai_request: Parsed OpenAI chat-completion request body. The keys
            read are ``messages``, ``model``, ``stream``, ``max_tokens`` and
            ``temperature``.
        model_override: Model name that takes precedence over the request's
            own ``model`` value when given.

    Returns:
        A ``(replicate_request, model)`` tuple: the JSON body for the
        prediction call and the resolved model identifier.

    Raises:
        HTTPException: Status 400 if the request cannot be transformed.
    """
    try:
        default_system_prompt = "You are a helpful assistant"
        messages = openai_request.get("messages") or []

        # Split out the system prompt (the last system message wins) and
        # render the user/assistant turns as a Human/Assistant transcript.
        system_prompt = default_system_prompt
        prompt_parts = []
        for message in messages:
            role = message.get("role")
            text = _flatten_message_content(message.get("content", ""))
            if role == "system":
                system_prompt = text or default_system_prompt
            elif role == "user":
                prompt_parts.append(f"Human: {text}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {text}")

        prompt = "\n\n".join(prompt_parts)
        # End the transcript with an open assistant turn for the model to fill.
        if prompt_parts and not prompt.endswith("\n\nAssistant:"):
            prompt += "\n\nAssistant:"

        # Resolve the model; an explicit null "model" falls back to the default.
        model = model_override or openai_request.get("model") or DEFAULT_MODEL

        # Short model names mapped to their full identifiers.
        model_mapping = {
            "claude-4-sonnet": "anthropic/claude-3-5-sonnet",
            "claude-3-sonnet": "anthropic/claude-3-sonnet-20240229",
            "claude-3-haiku": "anthropic/claude-3-haiku-20240307"
        }

        if model in model_mapping:
            model = model_mapping[model]
        elif "/" not in model:
            # Only bare names get the default owner; a name already in
            # "owner/name" form is passed through unchanged.
            model = f"anthropic/{model}"

        replicate_request = {
            "stream": openai_request.get("stream", False),
            "input": {
                "prompt": prompt,
                "system_prompt": system_prompt,
                "max_tokens": openai_request.get("max_tokens", 4000),
                "temperature": openai_request.get("temperature", 0.7)
            }
        }

        logger.info(f"Transformed request for model: {model}")
        return replicate_request, model

    except Exception as e:
        logger.error(f"Error transforming request: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Request transformation error: {str(e)}") from e
119
 
120
async def create_replicate_prediction(session: aiohttp.ClientSession, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a Replicate prediction for ``model`` and return the parsed reply.

    Args:
        session: Open aiohttp client session used for the POST.
        model: Model identifier inserted into the predictions URL.
        data: JSON body sent as the prediction request.

    Returns:
        The decoded JSON body of the 201 response.

    Raises:
        HTTPException: With the upstream status code when the API does not
            answer 201; 504 when the request times out; 500 for any other
            failure (transport errors, undecodable response body).
    """
    url = f"{REPLICATE_BASE_URL}/models/{model}/predictions"
    headers = {
        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
        "Content-Type": "application/json"
    }

    logger.info(f"Creating prediction for model: {model}")
    logger.info(f"Request URL: {url}")

    # Only the network exchange sits inside the try block, so the
    # HTTPException raised below for a non-201 status is not swallowed by
    # the catch-all and rewritten as a generic 500.
    try:
        async with session.post(
            url,
            headers=headers,
            json=data,
            timeout=aiohttp.ClientTimeout(total=30),
        ) as response:
            status_code = response.status
            response_text = await response.text()
    except asyncio.TimeoutError as e:
        logger.error("Timeout creating Replicate prediction")
        raise HTTPException(status_code=504, detail="Timeout creating prediction") from e
    except Exception as e:
        logger.error(f"Error creating prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}") from e

    logger.info(f"Replicate response status: {status_code}")
    logger.info(f"Replicate response: {response_text}")

    if status_code != 201:
        logger.error(f"Replicate API error: {status_code} - {response_text}")
        raise HTTPException(
            status_code=status_code,
            detail=f"Replicate API error: {response_text}"
        )

    try:
        return json.loads(response_text)
    except json.JSONDecodeError as e:
        logger.error(f"Error creating prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction creation error: {str(e)}") from e
152
 
153
async def stream_replicate_response(session: aiohttp.ClientSession, stream_url: str) -> AsyncGenerator[str, None]:
    """Yield the non-empty lines of a Replicate event stream.

    Args:
        session: Open aiohttp client session used for the GET.
        stream_url: URL of the prediction's event stream.

    Yields:
        Each non-blank line of the response body, decoded as UTF-8 and
        stripped of surrounding whitespace.

    Raises:
        HTTPException: With the upstream status code when the stream request
            does not answer 200.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    headers = {
        "Accept": "text/event-stream",
        "Cache-Control": "no-store"
    }

    logger.info(f"Starting stream from: {stream_url}")

    try:
        async with session.get(
            stream_url,
            headers=headers,
            timeout=aiohttp.ClientTimeout(total=300),
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"Stream error: {response.status} - {error_text}")
                raise HTTPException(status_code=response.status, detail=f"Stream error: {error_text}")

            async for raw_line in response.content:
                # Replace undecodable bytes rather than aborting the whole
                # stream on a single malformed sequence.
                line = raw_line.decode('utf-8', errors='replace').strip()
                if line:
                    yield line

    except HTTPException:
        # Already logged above with the upstream status; do not log twice.
        raise
    except Exception as e:
        logger.error(f"Stream error: {str(e)}")
        raise
177
 
178
  def transform_replicate_to_openai_stream(event_data: str, model: str) -> str:
179
  """将Replicate流式响应转换为OpenAI格式"""
 
217
 
218
  return ""
219
 
220
+ except json.JSONDecodeError as e:
221
+ logger.warning(f"Failed to parse event data: {event_data}, error: {e}")
222
  return ""
223
 
224
  @app.get("/")
 
227
  return {
228
  "message": "Replicate API Proxy for LobeChat",
229
  "status": "running",
230
+ "replicate_token_configured": bool(REPLICATE_API_TOKEN),
231
+ "version": "1.0.0"
232
+ }
233
+
234
@app.get("/health")
async def health():
    """Return a detailed health-check payload.

    Returns:
        A dict with a fixed ``status`` of ``"healthy"``, whether the
        Replicate token is configured, and the current Unix timestamp.
    """
    # Imported locally so this handler does not depend on a module-level
    # import that the rest of the file does not otherwise need.
    import time

    return {
        "status": "healthy",
        "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
        # Wall-clock time; the event loop's time() is a monotonic clock with
        # an arbitrary origin and is not meaningful as a timestamp.
        "timestamp": time.time()
    }
242
 
243
  @app.get("/v1/models")
 
269
  async def chat_completions(request: Request):
270
  """处理聊天完成请求(兼容OpenAI API)"""
271
  if not REPLICATE_API_TOKEN:
272
+ logger.error("REPLICATE_API_TOKEN not configured")
273
  raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not configured")
274
 
275
  try:
276
  body = await request.json()
277
+ logger.info(f"Received chat completion request")
278
+ logger.info(f"Request body: {json.dumps(body, indent=2)}")
279
 
280
  # 转换请求格式
281
  replicate_data, model = transform_openai_to_replicate(body)
 
282
 
283
  async with aiohttp.ClientSession() as session:
284
  # 创建预测
285
  prediction = await create_replicate_prediction(session, model, replicate_data)
286
+ prediction_id = prediction.get('id')
287
+ logger.info(f"Created prediction: {prediction_id}")
288
 
289
  if body.get("stream", False):
290
  # 流式响应
 
321
 
322
  else:
323
  # 非流式响应 - 等待预测完成
324
+ prediction_url = f"{REPLICATE_BASE_URL}/predictions/{prediction_id}"
325
  headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}"}
326
 
327
  # 轮询等待结果
328
+ max_attempts = 60 # 最多等待60秒
329
+ attempt = 0
330
+
331
+ while attempt < max_attempts:
332
  async with session.get(prediction_url, headers=headers) as response:
333
  result = await response.json()
334
+ status = result.get("status")
335
 
336
+ logger.info(f"Prediction {prediction_id} status: {status}")
337
+
338
+ if status == "succeeded":
339
+ output = result.get("output", [])
340
+ content = "".join(output) if isinstance(output, list) else str(output)
341
+
342
  openai_response = {
343
+ "id": f"chatcmpl-{prediction_id}",
344
  "object": "chat.completion",
345
  "created": int(asyncio.get_event_loop().time()),
346
  "model": model,
 
354
  }],
355
  "usage": {
356
  "prompt_tokens": 0,
357
+ "completion_tokens": len(content.split()),
358
+ "total_tokens": len(content.split())
359
  }
360
  }
361
  return openai_response
362
 
363
+ elif status == "failed":
364
+ error_msg = result.get('error', 'Unknown error')
365
+ logger.error(f"Prediction failed: {error_msg}")
366
+ raise HTTPException(status_code=500, detail=f"Prediction failed: {error_msg}")
367
+
368
+ elif status in ["canceled", "cancelled"]:
369
+ raise HTTPException(status_code=500, detail="Prediction was canceled")
370
 
371
  # 等待一秒后重试
372
  await asyncio.sleep(1)
373
+ attempt += 1
374
+
375
+ raise HTTPException(status_code=504, detail="Prediction timeout")
376
 
377
+ except HTTPException:
378
+ raise
379
  except Exception as e:
380
+ logger.error(f"Unexpected error processing request: {str(e)}")
381
+ logger.error(f"Traceback: {traceback.format_exc()}")
382
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
383
 
384
  if __name__ == "__main__":
385
  port = int(os.getenv("PORT", 7860))
386
+ logger.info(f"Starting server on port {port}")
387
+ uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")