nomid2 committed on
Commit
a7ba54d
·
verified ·
1 Parent(s): 0eca9c1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +380 -0
app.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import json
import time
import uuid
import asyncio
import logging
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Union

import httpx
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
16
+
17
# Logging configuration: INFO level, module-scoped logger per stdlib convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
20
+
21
# FastAPI application instance; title/description/version feed the OpenAPI docs.
app = FastAPI(
    title="Replicate API Proxy",
    description="将 Replicate API 转换为 OpenAI 兼容格式的代理服务",
    version="1.0.0"
)
26
+
27
# CORS middleware: wide open (any origin/method/header) so browser clients can
# call the proxy directly.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
35
+
36
# Environment variables: the Replicate token is read once at import time.
# A missing token only logs a warning here; the chat endpoint later refuses
# requests with an HTTP 500 when the client was never constructed.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    logger.warning("REPLICATE_API_TOKEN 未设置,某些功能可能无法正常工作")
40
+
41
# OpenAI-compatible request models
class ChatMessage(BaseModel):
    """A single chat turn in OpenAI format."""
    # Restricted to the three OpenAI chat roles.
    role: Literal["system", "user", "assistant"]
    content: str
45
+
46
class ChatCompletionRequest(BaseModel):
    """OpenAI /v1/chat/completions request body (subset of the real schema)."""
    # Only one model is actually served; the value is echoed back in responses.
    model: str = "claude-3-5-sonnet"
    messages: List[ChatMessage]
    temperature: Optional[float] = Field(default=0.7, ge=0, le=2)
    max_tokens: Optional[int] = Field(default=1000, ge=1)
    stream: Optional[bool] = False
    top_p: Optional[float] = Field(default=1, ge=0, le=1)
53
+
54
# OpenAI-compatible response models
class ChatCompletionChoice(BaseModel):
    """One completion alternative in a non-streaming response."""
    index: int
    message: ChatMessage
    finish_reason: str
59
+
60
class ChatCompletionUsage(BaseModel):
    """Token accounting block (filled from the rough `calculate_tokens`
    heuristic, not a real tokenizer count)."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
64
+
65
class ChatCompletionResponse(BaseModel):
    """Top-level non-streaming chat completion response."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage
72
+
73
class ChatCompletionStreamChoice(BaseModel):
    """One choice inside a streaming chunk; `delta` carries partial content."""
    index: int
    delta: Dict[str, Any]
    finish_reason: Optional[str] = None
77
+
78
class ChatCompletionStreamResponse(BaseModel):
    """A single server-sent-events chunk in OpenAI streaming format."""
    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[ChatCompletionStreamChoice]
84
+
85
# Replicate API client
class ReplicateClient:
    """Minimal async client for the Replicate predictions API.

    Translates transport and HTTP failures into FastAPI ``HTTPException``s so
    route handlers can let them propagate unchanged.
    """

    def __init__(self, api_token: str):
        self.api_token = api_token
        self.base_url = "https://api.replicate.com/v1"
        self.headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json"
        }

    def format_messages_for_replicate(self, messages: List[ChatMessage]) -> str:
        """Convert OpenAI-style messages into a single Claude prompt string.

        Roles map to "System:" / "Human:" / "Assistant:" prefixes; a trailing
        "Assistant:" cue is appended unless the conversation already ends with
        an assistant turn.
        """
        formatted_messages = []

        for message in messages:
            if message.role == "system":
                formatted_messages.append(f"System: {message.content}")
            elif message.role == "user":
                formatted_messages.append(f"Human: {message.content}")
            elif message.role == "assistant":
                formatted_messages.append(f"Assistant: {message.content}")

        # Clearer equivalent of the original `any(... for msg in messages[-1:])`:
        # only skip the cue when the last message is from the assistant.
        if not (messages and messages[-1].role == "assistant"):
            formatted_messages.append("Assistant:")

        return "\n\n".join(formatted_messages)

    async def create_prediction(
        self,
        messages: List[ChatMessage],
        temperature: float = 0.7,
        max_tokens: int = 1000,
        top_p: float = 1.0
    ) -> Dict[str, Any]:
        """Create a prediction on Replicate and return its raw JSON.

        Raises HTTPException(502) on transport errors and mirrors the upstream
        status code on HTTP errors.
        """
        prompt = self.format_messages_for_replicate(messages)

        payload = {
            # NOTE(review): this version hash looks like a placeholder —
            # verify it against the real model version on Replicate.
            "version": "14e5e6719b5af8e6a0b4b1d73b48bb0f8e8b3a7a0b4b1d73b48bb0f8e8b3a7a0",  # Claude 3.5 Sonnet version ID
            "input": {
                "prompt": prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "system_prompt": "You are Claude, an AI assistant created by Anthropic."
            }
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            try:
                response = await client.post(
                    f"{self.base_url}/predictions",
                    headers=self.headers,
                    json=payload
                )
                response.raise_for_status()
                return response.json()
            except httpx.RequestError as e:
                logger.error(f"请求 Replicate API 失败: {e}")
                raise HTTPException(status_code=502, detail="上游服务请求失败")
            except httpx.HTTPStatusError as e:
                logger.error(f"Replicate API 返回错误: {e.response.status_code} - {e.response.text}")
                raise HTTPException(status_code=e.response.status_code, detail="上游服务错误")

    async def get_prediction(self, prediction_id: str) -> Dict[str, Any]:
        """Fetch the current state of a prediction by id."""
        async with httpx.AsyncClient(timeout=30.0) as client:
            try:
                response = await client.get(
                    f"{self.base_url}/predictions/{prediction_id}",
                    headers=self.headers
                )
                response.raise_for_status()
                return response.json()
            except httpx.RequestError as e:
                logger.error(f"获取预测结果失败: {e}")
                raise HTTPException(status_code=502, detail="获取结果失败")
            except httpx.HTTPStatusError as e:
                # Fix: raise_for_status() errors were previously uncaught here
                # (unlike create_prediction) and surfaced as generic 500s; map
                # them to 502 like the other upstream failures.
                logger.error(f"获取预测结果失败: {e.response.status_code} - {e.response.text}")
                raise HTTPException(status_code=502, detail="获取结果失败")

    async def wait_for_prediction(self, prediction_id: str, max_wait: int = 300) -> Dict[str, Any]:
        """Poll every 2s until the prediction succeeds, fails, is canceled,
        or ``max_wait`` seconds elapse (then HTTPException 504).
        """
        start_time = time.time()

        while time.time() - start_time < max_wait:
            prediction = await self.get_prediction(prediction_id)

            if prediction["status"] == "succeeded":
                return prediction
            elif prediction["status"] == "failed":
                error_msg = prediction.get("error", "预测失败")
                logger.error(f"Replicate 预测失败: {error_msg}")
                raise HTTPException(status_code=502, detail=f"预测失败: {error_msg}")
            elif prediction["status"] == "canceled":
                raise HTTPException(status_code=502, detail="预测被取消")

            # Poll interval between status checks.
            await asyncio.sleep(2)

        raise HTTPException(status_code=504, detail="预测超时")
186
+
187
# Initialize the Replicate client. It stays None when no token is configured;
# the chat completion endpoint checks this and responds with HTTP 500.
replicate_client = None
if REPLICATE_API_TOKEN:
    replicate_client = ReplicateClient(REPLICATE_API_TOKEN)
191
+
192
def calculate_tokens(text: str) -> int:
    """Rough token estimate: word count plus one token per four characters.

    A crude heuristic only — a real deployment should use the model's
    tokenizer for accurate accounting.
    """
    word_count = len(text.split())
    char_estimate = len(text) // 4
    return word_count + char_estimate
195
+
196
def create_openai_response(
    content: str,
    model: str,
    request_id: str,
    prompt_tokens: int,
    completion_tokens: int
) -> ChatCompletionResponse:
    """Assemble a non-streaming OpenAI-format chat completion response."""
    # Build the pieces as named intermediates, then assemble the envelope.
    assistant_message = ChatMessage(role="assistant", content=content)
    sole_choice = ChatCompletionChoice(
        index=0,
        message=assistant_message,
        finish_reason="stop"
    )
    usage = ChatCompletionUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens
    )
    return ChatCompletionResponse(
        id=request_id,
        created=int(time.time()),
        model=model,
        choices=[sole_choice],
        usage=usage
    )
221
+
222
async def create_openai_stream(
    content: str,
    model: str,
    request_id: str
) -> AsyncGenerator[str, None]:
    """Yield OpenAI-format SSE chunks that replay *content* incrementally.

    Emits: a role-announcing chunk, one content chunk per slice, a final chunk
    with finish_reason="stop", then the "[DONE]" sentinel.
    """
    # Opening chunk announces the assistant role with empty content.
    start_chunk = ChatCompletionStreamResponse(
        id=request_id,
        created=int(time.time()),
        model=model,
        choices=[
            ChatCompletionStreamChoice(
                index=0,
                delta={"role": "assistant", "content": ""}
            )
        ]
    )
    yield f"data: {start_chunk.model_dump_json()}\n\n"

    # Stream the content in fixed-size character slices. Fix: the original
    # split on whitespace and re-joined with single spaces, which corrupted
    # newlines and runs of spaces — the streamed text no longer matched the
    # non-streaming response. Slicing preserves the text byte-for-byte.
    chunk_size = 16
    for offset in range(0, len(content), chunk_size):
        piece = content[offset:offset + chunk_size]

        chunk = ChatCompletionStreamResponse(
            id=request_id,
            created=int(time.time()),
            model=model,
            choices=[
                ChatCompletionStreamChoice(
                    index=0,
                    delta={"content": piece}
                )
            ]
        )
        yield f"data: {chunk.model_dump_json()}\n\n"
        await asyncio.sleep(0.05)  # simulated streaming delay

    # Terminal chunk: empty delta with an explicit finish reason.
    end_chunk = ChatCompletionStreamResponse(
        id=request_id,
        created=int(time.time()),
        model=model,
        choices=[
            ChatCompletionStreamChoice(
                index=0,
                delta={},
                finish_reason="stop"
            )
        ]
    )
    yield f"data: {end_chunk.model_dump_json()}\n\n"
    yield "data: [DONE]\n\n"
277
+
278
@app.get("/")
async def root():
    """Service banner: name, version, status, and Replicate configuration."""
    info = {"message": "Replicate API Proxy", "version": "1.0.0"}
    info["status"] = "running"
    info["replicate_configured"] = REPLICATE_API_TOKEN is not None
    return info
287
+
288
@app.get("/v1/models")
async def list_models():
    """Return the model catalogue (a single Claude entry) in OpenAI list format."""
    claude_entry = {
        "id": "claude-3-5-sonnet",
        "object": "model",
        "created": int(time.time()),
        "owned_by": "anthropic"
    }
    return {"object": "list", "data": [claude_entry]}
302
+
303
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint.

    Creates a Replicate prediction, waits for it to finish, and returns the
    result either as a standard JSON body or as an SSE stream.
    """

    if not replicate_client:
        raise HTTPException(
            status_code=500,
            detail="Replicate API Token 未配置,请设置 REPLICATE_API_TOKEN 环境变量"
        )

    request_id = f"chatcmpl-{uuid.uuid4().hex}"

    try:
        # Kick off the Replicate prediction.
        prediction = await replicate_client.create_prediction(
            messages=request.messages,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p
        )

        # Block until the prediction completes (the client enforces a timeout).
        completed_prediction = await replicate_client.wait_for_prediction(
            prediction["id"]
        )

        # Replicate may return the output as a list of string fragments.
        output = completed_prediction.get("output", [])
        if isinstance(output, list):
            content = "".join(output)
        else:
            content = str(output)

        if request.stream:
            # Streamed SSE response; token usage is not reported in this mode.
            return StreamingResponse(
                create_openai_stream(content, request.model, request_id),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Access-Control-Allow-Origin": "*",
                }
            )

        # Token accounting only matters for the non-streaming response, so
        # compute the (rough) estimates here instead of unconditionally.
        prompt_text = " ".join(msg.content for msg in request.messages)
        prompt_tokens = calculate_tokens(prompt_text)
        completion_tokens = calculate_tokens(content)

        return create_openai_response(
            content=content,
            model=request.model,
            request_id=request_id,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )

    except HTTPException as e:
        # Fix: deliberate HTTP errors from the Replicate client now pass
        # through a dedicated clause, and bare `raise` preserves the original
        # traceback (the old `raise e` after an isinstance check did not).
        logger.error(f"处理聊天完成请求时出错: {e}")
        raise
    except Exception as e:
        logger.error(f"处理聊天完成请求时出错: {e}")
        raise HTTPException(status_code=500, detail=str(e))
368
+
369
@app.get("/health")
async def health_check():
    """Health check: report liveness and Replicate configuration state."""
    return {
        "status": "healthy",
        # Fix: datetime.utcnow() is deprecated (Python 3.12+) and returned a
        # naive timestamp; use an explicit timezone-aware UTC timestamp. The
        # ISO string now carries a "+00:00" offset.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "replicate_configured": REPLICATE_API_TOKEN is not None
    }
377
+
378
if __name__ == "__main__":
    # Local entry point: uvicorn is imported lazily so the module can be
    # served by an external ASGI runner without requiring it at import time.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)