nomid2 committed on
Commit
62be311
·
verified ·
1 Parent(s): 93eb401

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -44
app.py CHANGED
@@ -41,6 +41,40 @@ if not REPLICATE_API_TOKEN:
41
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
42
  DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # 全局异常处理器
45
  @app.exception_handler(Exception)
46
  async def global_exception_handler(request: Request, exc: Exception):
@@ -61,17 +95,17 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
61
  try:
62
  messages = openai_request.get("messages", [])
63
 
64
- # 提取system prompt
65
- system_prompt = "You are a helpful assistant"
66
  user_messages = []
67
 
68
  for message in messages:
69
  if message.get("role") == "system":
70
- system_prompt = message.get("content", "You are a helpful assistant")
71
  elif message.get("role") in ["user", "assistant"]:
72
  user_messages.append(message)
73
 
74
- # 构建prompt
75
  prompt_parts = []
76
  for msg in user_messages:
77
  role = msg.get("role", "")
@@ -95,24 +129,73 @@ def transform_openai_to_replicate(openai_request: Dict[str, Any], model_override
95
  "claude-3-sonnet": "anthropic/claude-3-sonnet",
96
  "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
97
  "claude-3-haiku": "anthropic/claude-3-haiku",
 
98
  }
99
 
100
  if model in model_mapping:
101
  model = model_mapping[model]
102
- elif not model.startswith("anthropic/"):
103
  model = "anthropic/claude-3.5-sonnet"
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  replicate_request = {
106
  "stream": openai_request.get("stream", False),
107
- "input": {
108
- "prompt": prompt,
109
- "system_prompt": system_prompt,
110
- "max_tokens": openai_request.get("max_tokens", 4000),
111
- "temperature": openai_request.get("temperature", 0.7)
112
- }
113
  }
114
 
115
  logger.info(f"Transformed request for model: {model}")
 
 
 
116
  return replicate_request, model
117
 
118
  except Exception as e:
@@ -129,6 +212,7 @@ async def create_replicate_prediction(session: aiohttp.ClientSession, model: str
129
  }
130
 
131
  logger.info(f"Creating prediction for model: {model}")
 
132
 
133
  async with session.post(url, headers=headers, json=data, timeout=30) as response:
134
  response_text = await response.text()
@@ -210,7 +294,8 @@ async def root():
210
  "message": "Replicate API Proxy for LobeChat",
211
  "status": "running",
212
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
213
- "version": "1.0.0"
 
214
  }
215
 
216
  @app.get("/health")
@@ -219,44 +304,22 @@ async def health():
219
  return {
220
  "status": "healthy",
221
  "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
222
- "timestamp": asyncio.get_event_loop().time()
 
223
  }
224
 
225
  @app.get("/v1/models")
226
  async def list_models():
227
  """列出可用模型(兼容OpenAI API)"""
228
- models = [
229
- {
230
- "id": "claude-4-sonnet",
 
231
  "object": "model",
232
  "created": 1677610602,
233
  "owned_by": "anthropic"
234
- },
235
- {
236
- "id": "claude-3.5-sonnet",
237
- "object": "model",
238
- "created": 1677610602,
239
- "owned_by": "anthropic"
240
- },
241
- {
242
- "id": "claude-3.5-haiku",
243
- "object": "model",
244
- "created": 1677610602,
245
- "owned_by": "anthropic"
246
- },
247
- {
248
- "id": "claude-3-sonnet",
249
- "object": "model",
250
- "created": 1677610602,
251
- "owned_by": "anthropic"
252
- },
253
- {
254
- "id": "claude-3-haiku",
255
- "object": "model",
256
- "created": 1677610602,
257
- "owned_by": "anthropic"
258
- }
259
- ]
260
  return {"object": "list", "data": models}
261
 
262
  @app.post("/v1/chat/completions")
@@ -269,6 +332,8 @@ async def chat_completions(request: Request):
269
  try:
270
  body = await request.json()
271
  logger.info(f"Received chat completion request")
 
 
272
 
273
  # 转换请求格式
274
  replicate_data, model = transform_openai_to_replicate(body)
@@ -332,8 +397,6 @@ async def chat_completions(request: Request):
332
  event_type = event.get('event')
333
  data = event.get('data', '')
334
 
335
- logger.info(f"Parsed SSE event: {event_type}, data: {data[:50]}...")
336
-
337
  if event_type == 'output' and data.strip():
338
  # 输出事件,包含实际内容
339
  yield create_openai_chunk(data, model, prediction_id)
 
41
  REPLICATE_BASE_URL = "https://api.replicate.com/v1"
42
  DEFAULT_MODEL = "anthropic/claude-3.5-sonnet"
43
 
44
+ # 模型配置信息
45
+ MODEL_CONFIGS = {
46
+ "anthropic/claude-4-sonnet": {
47
+ "min_max_tokens": 1024, # Replicate 要求的最小值
48
+ "default_max_tokens": 8192, # 如果客户端未指定时的默认值
49
+ "has_max_tokens_limit": True
50
+ },
51
+ "anthropic/claude-3.5-sonnet": {
52
+ "min_max_tokens": 1,
53
+ "default_max_tokens": 8192,
54
+ "has_max_tokens_limit": False
55
+ },
56
+ "anthropic/claude-3-sonnet": {
57
+ "min_max_tokens": 1,
58
+ "default_max_tokens": 4096,
59
+ "has_max_tokens_limit": False
60
+ },
61
+ "anthropic/claude-3.5-haiku": {
62
+ "min_max_tokens": 1,
63
+ "default_max_tokens": 4096,
64
+ "has_max_tokens_limit": False
65
+ },
66
+ "anthropic/claude-3-haiku": {
67
+ "min_max_tokens": 1,
68
+ "default_max_tokens": 4096,
69
+ "has_max_tokens_limit": False
70
+ },
71
+ "google/gemini-2.5-pro": { # 如果将来支持 Gemini
72
+ "min_max_tokens": 1,
73
+ "default_max_tokens": 8192,
74
+ "has_max_tokens_limit": False
75
+ }
76
+ }
77
+
78
  # 全局异常处理器
79
  @app.exception_handler(Exception)
80
  async def global_exception_handler(request: Request, exc: Exception):
 
95
  try:
96
  messages = openai_request.get("messages", [])
97
 
98
+ # 完全使用客户端提供的 system prompt,不设置默认值
99
+ system_prompt = None
100
  user_messages = []
101
 
102
  for message in messages:
103
  if message.get("role") == "system":
104
+ system_prompt = message.get("content", "")
105
  elif message.get("role") in ["user", "assistant"]:
106
  user_messages.append(message)
107
 
108
+ # 构建prompt - 包含完整的对话历史,不限制数量
109
  prompt_parts = []
110
  for msg in user_messages:
111
  role = msg.get("role", "")
 
129
  "claude-3-sonnet": "anthropic/claude-3-sonnet",
130
  "claude-3.5-haiku": "anthropic/claude-3.5-haiku",
131
  "claude-3-haiku": "anthropic/claude-3-haiku",
132
+ "gemini-2.5-pro": "google/gemini-2.5-pro", # 预留
133
  }
134
 
135
  if model in model_mapping:
136
  model = model_mapping[model]
137
+ elif not model.startswith(("anthropic/", "google/")):
138
  model = "anthropic/claude-3.5-sonnet"
139
 
140
+ # 获取模型配置
141
+ model_config = MODEL_CONFIGS.get(model, MODEL_CONFIGS["anthropic/claude-3.5-sonnet"])
142
+
143
+ # 处理 max_tokens - 完全根据客户端和模型配置
144
+ client_max_tokens = openai_request.get("max_tokens")
145
+
146
+ if client_max_tokens is not None:
147
+ # 客户端指定了 max_tokens,尊重客户端设置
148
+ max_tokens = client_max_tokens
149
+ # 只在低于模型最小要求时调整
150
+ if max_tokens < model_config["min_max_tokens"]:
151
+ logger.info(f"Adjusting max_tokens from {max_tokens} to {model_config['min_max_tokens']} (model minimum)")
152
+ max_tokens = model_config["min_max_tokens"]
153
+ else:
154
+ # 客户端未指定 max_tokens
155
+ if model_config["has_max_tokens_limit"]:
156
+ # 模型有强制要求,使用默认值
157
+ max_tokens = model_config["default_max_tokens"]
158
+ logger.info(f"Using default max_tokens {max_tokens} for model {model}")
159
+ else:
160
+ # 模型没有限制,不设置 max_tokens
161
+ max_tokens = None
162
+ logger.info(f"No max_tokens limit for model {model}, allowing unlimited")
163
+
164
+ # 构建 Replicate 请求的 input 参数
165
+ replicate_input = {
166
+ "prompt": prompt,
167
+ }
168
+
169
+ # 只在有 system_prompt 时才添加
170
+ if system_prompt:
171
+ replicate_input["system_prompt"] = system_prompt
172
+
173
+ # 只在有 max_tokens 时才添加
174
+ if max_tokens is not None:
175
+ replicate_input["max_tokens"] = max_tokens
176
+
177
+ # 处理其他参数 - 完全使用客户端设置
178
+ if "temperature" in openai_request:
179
+ replicate_input["temperature"] = openai_request["temperature"]
180
+
181
+ if "top_p" in openai_request:
182
+ replicate_input["top_p"] = openai_request["top_p"]
183
+
184
+ if "frequency_penalty" in openai_request:
185
+ replicate_input["frequency_penalty"] = openai_request["frequency_penalty"]
186
+
187
+ if "presence_penalty" in openai_request:
188
+ replicate_input["presence_penalty"] = openai_request["presence_penalty"]
189
+
190
  replicate_request = {
191
  "stream": openai_request.get("stream", False),
192
+ "input": replicate_input
 
 
 
 
 
193
  }
194
 
195
  logger.info(f"Transformed request for model: {model}")
196
+ logger.info(f"Message count: {len(messages)} (system: {1 if system_prompt else 0}, user/assistant: {len(user_messages)})")
197
+ logger.info(f"Parameters: max_tokens={max_tokens}, temperature={replicate_input.get('temperature', 'not set')}, top_p={replicate_input.get('top_p', 'not set')}")
198
+
199
  return replicate_request, model
200
 
201
  except Exception as e:
 
212
  }
213
 
214
  logger.info(f"Creating prediction for model: {model}")
215
+ logger.info(f"Request data: {json.dumps(data, indent=2)}")
216
 
217
  async with session.post(url, headers=headers, json=data, timeout=30) as response:
218
  response_text = await response.text()
 
294
  "message": "Replicate API Proxy for LobeChat",
295
  "status": "running",
296
  "replicate_token_configured": bool(REPLICATE_API_TOKEN),
297
+ "version": "1.0.0",
298
+ "supported_models": list(MODEL_CONFIGS.keys())
299
  }
300
 
301
  @app.get("/health")
 
304
  return {
305
  "status": "healthy",
306
  "replicate_token": "configured" if REPLICATE_API_TOKEN else "missing",
307
+ "timestamp": asyncio.get_event_loop().time(),
308
+ "model_configs": MODEL_CONFIGS
309
  }
310
 
311
  @app.get("/v1/models")
312
  async def list_models():
313
  """列出可用模型(兼容OpenAI API)"""
314
+ models = []
315
+ for model_id in ["claude-4-sonnet", "claude-3.5-sonnet", "claude-3.5-haiku", "claude-3-sonnet", "claude-3-haiku"]:
316
+ models.append({
317
+ "id": model_id,
318
  "object": "model",
319
  "created": 1677610602,
320
  "owned_by": "anthropic"
321
+ })
322
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  return {"object": "list", "data": models}
324
 
325
  @app.post("/v1/chat/completions")
 
332
  try:
333
  body = await request.json()
334
  logger.info(f"Received chat completion request")
335
+ logger.info(f"Client parameters: max_tokens={body.get('max_tokens', 'not set')}, temperature={body.get('temperature', 'not set')}, top_p={body.get('top_p', 'not set')}")
336
+ logger.info(f"Message count: {len(body.get('messages', []))}")
337
 
338
  # 转换请求格式
339
  replicate_data, model = transform_openai_to_replicate(body)
 
397
  event_type = event.get('event')
398
  data = event.get('data', '')
399
 
 
 
400
  if event_type == 'output' and data.strip():
401
  # 输出事件,包含实际内容
402
  yield create_openai_chunk(data, model, prediction_id)