deeme committed on
Commit
c78beab
·
verified ·
1 Parent(s): 97f2698

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +274 -257
main.py CHANGED
@@ -2,13 +2,15 @@ import asyncio
2
  import json
3
  from datetime import datetime, timezone
4
  import os
 
 
5
 
6
  from fastapi import FastAPI, HTTPException, Request
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import JSONResponse
9
  from fastapi.responses import StreamingResponse
10
  from pydantic import BaseModel
11
- from typing import List, Optional
12
  import time
13
  import uuid
14
  import logging
@@ -25,11 +27,11 @@ app = FastAPI(title="Gemini API FastAPI Server")
25
 
26
  # Add CORS middleware
27
  app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"],
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
  )
34
 
35
  # Global client
@@ -41,309 +43,324 @@ SECURE_1PSIDTS = os.environ.get("SECURE_1PSIDTS", "")
41
 
42
  # Print debug info at startup
43
  if not SECURE_1PSID or not SECURE_1PSIDTS:
44
- logger.warning("⚠️ Gemini API credentials are not set or empty! Please check your environment variables.")
 
 
 
 
 
45
  else:
46
- # Only log the first few characters for security
47
- logger.info(f"Credentials found. SECURE_1PSID starts with: {SECURE_1PSID[:5]}...")
48
- logger.info(f"Credentials found. SECURE_1PSIDTS starts with: {SECURE_1PSIDTS[:5]}...")
 
49
 
50
  # Pydantic models for API requests and responses
 
 
 
 
 
 
51
  class Message(BaseModel):
52
- role: str
53
- content: str
54
- name: Optional[str] = None
55
 
56
 
57
  class ChatCompletionRequest(BaseModel):
58
- model: str
59
- messages: List[Message]
60
- temperature: Optional[float] = 0.7
61
- top_p: Optional[float] = 1.0
62
- n: Optional[int] = 1
63
- stream: Optional[bool] = False
64
- max_tokens: Optional[int] = None
65
- presence_penalty: Optional[float] = 0
66
- frequency_penalty: Optional[float] = 0
67
- user: Optional[str] = None
68
 
69
 
70
  class Choice(BaseModel):
71
- index: int
72
- message: Message
73
- finish_reason: str
74
 
75
 
76
  class Usage(BaseModel):
77
- prompt_tokens: int
78
- completion_tokens: int
79
- total_tokens: int
80
 
81
 
82
  class ChatCompletionResponse(BaseModel):
83
- id: str
84
- object: str = "chat.completion"
85
- created: int
86
- model: str
87
- choices: List[Choice]
88
- usage: Usage
89
 
90
 
91
  class ModelData(BaseModel):
92
- id: str
93
- object: str = "model"
94
- created: int
95
- owned_by: str = "google"
96
 
97
 
98
  class ModelList(BaseModel):
99
- object: str = "list"
100
- data: List[ModelData]
101
 
102
 
103
  # Simple error handler middleware
104
  @app.middleware("http")
105
  async def error_handling(request: Request, call_next):
106
- try:
107
- return await call_next(request)
108
- except Exception as e:
109
- logger.error(f"Request failed: {str(e)}")
110
- return JSONResponse(
111
- status_code=500,
112
- content={ "error": { "message": str(e), "type": "internal_server_error" } }
113
- )
114
 
115
 
116
  # Get list of available models
117
  @app.get("/v1/models")
118
  async def list_models():
119
- """返回 gemini_webapi 中声明的模型列表"""
120
- now = int(datetime.now(tz=timezone.utc).timestamp())
121
- data = [
122
- {
123
- "id": m.model_name, # 如 "gemini-2.0-flash"
124
- "object": "model",
125
- "created": now,
126
- "owned_by": "google-gemini-web"
127
- }
128
- for m in Model
129
- ]
130
- print(data)
131
- return {"object": "list", "data": data}
132
 
133
 
134
  # Helper to convert between Gemini and OpenAI model names
135
  def map_model_name(openai_model_name: str) -> Model:
136
- """根据模型名称字符串查找匹配的 Model 枚举值"""
137
- # 打印所有可用模型以便调试
138
- all_models = [m.model_name if hasattr(m, "model_name") else str(m) for m in Model]
139
- logger.info(f"Available models: {all_models}")
140
-
141
- # 首先尝试直接查找匹配的模型名称
142
- for m in Model:
143
- model_name = m.model_name if hasattr(m, "model_name") else str(m)
144
- if openai_model_name.lower() in model_name.lower():
145
- return m
146
-
147
- # 如果找不到匹配项,使用默认映射
148
- model_keywords = {
149
- "gemini-pro": ["pro", "2.0"],
150
- "gemini-pro-vision": ["vision", "pro"],
151
- "gemini-flash": ["flash", "2.0"],
152
- "gemini-1.5-pro": ["1.5", "pro"],
153
- "gemini-1.5-flash": ["1.5", "flash"],
154
- }
155
-
156
- # 根据关键词匹配
157
- keywords = model_keywords.get(openai_model_name, ["pro"]) # 默认使用pro模型
158
-
159
- for m in Model:
160
- model_name = m.model_name if hasattr(m, "model_name") else str(m)
161
- if all(kw.lower() in model_name.lower() for kw in keywords):
162
- return m
163
-
164
- # 如果还是找不到,返回第一个模型
165
- return next(iter(Model))
166
 
167
 
168
  # Prepare conversation history from OpenAI messages format
169
- def prepare_conversation(messages: List[Message]) -> str:
170
- conversation = ""
171
-
172
- for msg in messages:
173
- if msg.role == "system":
174
- conversation += f"System: {msg.content}\n\n"
175
- elif msg.role == "user":
176
- conversation += f"Human: {msg.content}\n\n"
177
- elif msg.role == "assistant":
178
- conversation += f"Assistant: {msg.content}\n\n"
179
-
180
- # Add a final prompt for the assistant to respond to
181
- conversation += "Assistant: "
182
-
183
- return conversation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
 
186
  # Dependency to get the initialized Gemini client
187
  async def get_gemini_client():
188
- global gemini_client
189
- if gemini_client is None:
190
- try:
191
- gemini_client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
192
- await gemini_client.init(timeout=300)
193
- except Exception as e:
194
- logger.error(f"Failed to initialize Gemini client: {str(e)}")
195
- raise HTTPException(
196
- status_code=500,
197
- detail=f"Failed to initialize Gemini client: {str(e)}"
198
- )
199
- return gemini_client
200
 
201
 
202
  @app.post("/v1/chat/completions")
203
  async def create_chat_completion(request: ChatCompletionRequest):
204
- try:
205
- # 确保客户端已初始化
206
- global gemini_client
207
- if gemini_client is None:
208
- gemini_client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
209
- await gemini_client.init(timeout=300)
210
- logger.info("Gemini client initialized successfully")
211
-
212
- # 转换消息为对话格式
213
- conversation = prepare_conversation(request.messages)
214
- logger.info(f"Prepared conversation: {conversation}")
215
-
216
- # 获取适当的模型
217
- model = map_model_name(request.model)
218
- logger.info(f"Using model: {model}")
219
-
220
- # 生成响应
221
- logger.info("Sending request to Gemini...")
222
- response = await gemini_client.generate_content(conversation, model=model)
223
-
224
- # 提取文本响应
225
- reply_text = ""
226
- if hasattr(response, "text"):
227
- reply_text = response.text
228
- else:
229
- reply_text = str(response)
230
-
231
- logger.info(f"Response: {reply_text}")
232
-
233
- if not reply_text or reply_text.strip() == "":
234
- logger.warning("Empty response received from Gemini")
235
- reply_text = "服务器返回了空响应。请检查 Gemini API 凭据是否有效。"
236
-
237
- # 创建响应对象
238
- completion_id = f"chatcmpl-{uuid.uuid4()}"
239
- created_time = int(time.time())
240
-
241
- # 检查客户端是否请求流式响应
242
- if request.stream:
243
- # 实现流式响应
244
- async def generate_stream():
245
- # 创建 SSE 格式的流式响应
246
- # 先发送开始事件
247
- data = {
248
- "id": completion_id,
249
- "object": "chat.completion.chunk",
250
- "created": created_time,
251
- "model": request.model,
252
- "choices": [
253
- {
254
- "index": 0,
255
- "delta": {
256
- "role": "assistant"
257
- },
258
- "finish_reason": None
259
- }
260
- ]
261
- }
262
- yield f"data: {json.dumps(data)}\n\n"
263
-
264
- # 模拟流式输出 - 将文本按字符分割发送
265
- for char in reply_text:
266
- data = {
267
- "id": completion_id,
268
- "object": "chat.completion.chunk",
269
- "created": created_time,
270
- "model": request.model,
271
- "choices": [
272
- {
273
- "index": 0,
274
- "delta": {
275
- "content": char
276
- },
277
- "finish_reason": None
278
- }
279
- ]
280
- }
281
- yield f"data: {json.dumps(data)}\n\n"
282
- # 可选:添加短暂延迟以模拟真实的流式输出
283
- await asyncio.sleep(0.01)
284
-
285
- # 发送结束事件
286
- data = {
287
- "id": completion_id,
288
- "object": "chat.completion.chunk",
289
- "created": created_time,
290
- "model": request.model,
291
- "choices": [
292
- {
293
- "index": 0,
294
- "delta": { },
295
- "finish_reason": "stop"
296
- }
297
- ]
298
- }
299
- yield f"data: {json.dumps(data)}\n\n"
300
- yield "data: [DONE]\n\n"
301
-
302
- return StreamingResponse(
303
- generate_stream(),
304
- media_type="text/event-stream"
305
- )
306
- else:
307
- # 非流式响应(原来的逻辑)
308
- result = {
309
- "id": completion_id,
310
- "object": "chat.completion",
311
- "created": created_time,
312
- "model": request.model,
313
- "choices": [
314
- {
315
- "index": 0,
316
- "message": {
317
- "role": "assistant",
318
- "content": reply_text
319
- },
320
- "finish_reason": "stop"
321
- }
322
- ],
323
- "usage": {
324
- "prompt_tokens": len(conversation.split()),
325
- "completion_tokens": len(reply_text.split()),
326
- "total_tokens": len(conversation.split()) + len(reply_text.split())
327
- }
328
- }
329
-
330
- logger.info(f"Returning response: {result}")
331
- return result
332
-
333
- except Exception as e:
334
- logger.error(f"Error generating completion: {str(e)}", exc_info=True)
335
- raise HTTPException(
336
- status_code=500,
337
- detail=f"Error generating completion: {str(e)}"
338
- )
339
 
340
 
341
  @app.get("/")
342
  async def root():
343
- return { "status": "online", "message": "Gemini API FastAPI Server is running" }
344
 
345
 
346
  if __name__ == "__main__":
347
- import uvicorn
348
 
349
- uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
 
2
  import json
3
  from datetime import datetime, timezone
4
  import os
5
+ import base64
6
+ import tempfile
7
 
8
  from fastapi import FastAPI, HTTPException, Request
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse
11
  from fastapi.responses import StreamingResponse
12
  from pydantic import BaseModel
13
+ from typing import List, Optional, Dict, Any, Union
14
  import time
15
  import uuid
16
  import logging
 
27
 
28
  # Add CORS middleware
29
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # rejected by browsers for credentialed requests — confirm this is intended
    # before deploying to production.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
36
 
37
  # Global client
 
43
 
44
  # Print debug info at startup
45
# Startup diagnostics: warn loudly at import time when the Gemini session
# cookies are missing or empty, with hints for the common .env/Docker setups.
if not SECURE_1PSID or not SECURE_1PSIDTS:
    logger.warning("⚠️ Gemini API credentials are not set or empty! Please check your environment variables.")
    logger.warning("Make sure SECURE_1PSID and SECURE_1PSIDTS are correctly set in your .env file or environment.")
    logger.warning("If using Docker, ensure the .env file is correctly mounted and formatted.")
    logger.warning("Example format in .env file (no quotes):")
    logger.warning("SECURE_1PSID=your_secure_1psid_value_here")
    logger.warning("SECURE_1PSIDTS=your_secure_1psidts_value_here")
else:
    # Only log the first few characters for security
    logger.info(f"Credentials found. SECURE_1PSID starts with: {SECURE_1PSID[:5]}...")
    logger.info(f"Credentials found. SECURE_1PSIDTS starts with: {SECURE_1PSIDTS[:5]}...")
56
+
57
 
58
  # Pydantic models for API requests and responses
59
class ContentItem(BaseModel):
    """One part of a multimodal message body (OpenAI content-array format)."""

    type: str  # "text" or "image_url"
    text: Optional[str] = None  # populated when type == "text"
    image_url: Optional[Dict[str, str]] = None  # populated when type == "image_url"; e.g. {"url": "data:image/png;base64,..."}
63
+
64
+
65
class Message(BaseModel):
    """A single chat message; content is plain text or a list of content parts."""

    role: str  # "system", "user", or "assistant"
    content: Union[str, List[ContentItem]]
    name: Optional[str] = None  # optional sender name (not used when building the prompt)
69
 
70
 
71
class ChatCompletionRequest(BaseModel):
    """Request body for /v1/chat/completions (OpenAI-compatible).

    Only ``model``, ``messages`` and ``stream`` affect behavior here; the
    sampling fields are accepted for API compatibility but are not forwarded
    to Gemini.
    """

    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False  # when True, the reply is sent as SSE chunks
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    user: Optional[str] = None
82
 
83
 
84
class Choice(BaseModel):
    """One completion choice within a chat completion response."""

    index: int
    message: Message
    finish_reason: str
88
 
89
 
90
class Usage(BaseModel):
    """Token accounting for a completion (computed here as whitespace-split word counts)."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
94
 
95
 
96
class ChatCompletionResponse(BaseModel):
    """Top-level response shape for /v1/chat/completions."""

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Usage
103
 
104
 
105
class ModelData(BaseModel):
    """A single entry in the /v1/models listing."""

    id: str
    object: str = "model"
    created: int
    owned_by: str = "google"
110
 
111
 
112
class ModelList(BaseModel):
    """Response shape for /v1/models."""

    object: str = "list"
    data: List[ModelData]
115
 
116
 
117
  # Simple error handler middleware
118
@app.middleware("http")
async def error_handling(request: Request, call_next):
    """Catch-all middleware: convert any unhandled exception into a JSON 500."""
    try:
        response = await call_next(request)
    except Exception as exc:
        logger.error(f"Request failed: {str(exc)}")
        payload = {"error": {"message": str(exc), "type": "internal_server_error"}}
        return JSONResponse(status_code=500, content=payload)
    return response
 
 
 
125
 
126
 
127
  # Get list of available models
128
@app.get("/v1/models")
async def list_models():
    """Return the models declared by gemini_webapi in OpenAI /v1/models format."""
    now = int(datetime.now(tz=timezone.utc).timestamp())
    data = [
        {
            "id": m.model_name,  # e.g. "gemini-2.0-flash"
            "object": "model",
            "created": now,
            "owned_by": "google-gemini-web",
        }
        for m in Model
    ]
    # Use the module logger instead of a stray print() so the output respects
    # the configured log level and handlers.
    logger.debug("Model list: %s", data)
    return {"object": "list", "data": data}
143
 
144
 
145
  # Helper to convert between Gemini and OpenAI model names
146
def map_model_name(openai_model_name: str) -> Model:
    """Resolve an OpenAI-style model name string to a Model enum member.

    Tries a direct substring match first, then a keyword-based mapping for
    well-known aliases, and finally falls back to the first declared model.
    """

    def enum_name(member) -> str:
        # Some enum members expose .model_name; fall back to str() otherwise.
        return member.model_name if hasattr(member, "model_name") else str(member)

    # Log every available model to aid debugging.
    logger.info(f"Available models: {[enum_name(m) for m in Model]}")

    wanted = openai_model_name.lower()

    # Pass 1: direct case-insensitive substring match.
    for member in Model:
        if wanted in enum_name(member).lower():
            return member

    # Pass 2: keyword mapping for common OpenAI-style aliases.
    keyword_map = {
        "gemini-pro": ["pro", "2.0"],
        "gemini-pro-vision": ["vision", "pro"],
        "gemini-flash": ["flash", "2.0"],
        "gemini-1.5-pro": ["1.5", "pro"],
        "gemini-1.5-flash": ["1.5", "flash"],
    }
    keywords = keyword_map.get(openai_model_name, ["pro"])  # default to a "pro" model

    for member in Model:
        candidate = enum_name(member).lower()
        if all(kw.lower() in candidate for kw in keywords):
            return member

    # Last resort: the first declared model.
    return next(iter(Model))
177
 
178
 
179
  # Prepare conversation history from OpenAI messages format
180
def prepare_conversation(messages: List[Message]) -> tuple:
    """Flatten OpenAI-style chat messages into one prompt string for Gemini.

    Returns a ``(conversation, temp_files)`` tuple: the rendered prompt text
    (ending with a trailing ``"Assistant: "`` cue) and the paths of temporary
    PNG files created from base64 data-URL images found in the messages.
    The caller is responsible for deleting the returned temp files.
    """
    role_prefix = {"system": "System: ", "user": "Human: ", "assistant": "Assistant: "}
    pieces = []
    temp_files = []

    for msg in messages:
        if isinstance(msg.content, str):
            # Plain-text message: rendered only for the known roles.
            if msg.role in role_prefix:
                pieces.append(f"{role_prefix[msg.role]}{msg.content}\n\n")
            continue

        # Multimodal message: optional role prefix, then each content item.
        pieces.append(role_prefix.get(msg.role, ""))
        for item in msg.content:
            if item.type == "text":
                pieces.append(item.text or "")
            elif item.type == "image_url" and item.image_url:
                url = item.image_url.get("url", "")
                # Only inline base64 data URLs are supported; other URL
                # schemes are silently skipped.
                if url.startswith("data:image/"):
                    try:
                        # Decode the payload after the comma and stash it in a
                        # temp file that can be uploaded to Gemini.
                        raw = base64.b64decode(url.split(",")[1])
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                            tmp.write(raw)
                            temp_files.append(tmp.name)
                    except Exception as e:
                        logger.error(f"Error processing base64 image: {str(e)}")
        pieces.append("\n\n")

    # Cue the model to answer as the assistant.
    pieces.append("Assistant: ")
    return "".join(pieces), temp_files
228
 
229
 
230
  # Dependency to get the initialized Gemini client
231
async def get_gemini_client():
    """Return the shared Gemini client, initializing it lazily on first use.

    Raises HTTPException(500) when the client cannot be initialized.
    """
    global gemini_client
    if gemini_client is not None:
        return gemini_client
    try:
        gemini_client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
        await gemini_client.init(timeout=300)
    except Exception as e:
        logger.error(f"Failed to initialize Gemini client: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to initialize Gemini client: {str(e)}")
    return gemini_client
 
 
 
241
 
242
 
243
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint backed by the Gemini web client.

    Supports plain-text and multimodal (base64 data-URL image) messages, plus
    optional SSE streaming when ``request.stream`` is true. Any failure is
    surfaced as HTTPException(500).
    """
    try:
        # Ensure the shared client is initialized.
        global gemini_client
        if gemini_client is None:
            gemini_client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await gemini_client.init(timeout=300)
            logger.info("Gemini client initialized successfully")

        # Convert the OpenAI message list into a single prompt string plus any
        # temporary image files extracted from data URLs.
        conversation, temp_files = prepare_conversation(request.messages)
        logger.info(f"Prepared conversation: {conversation}")
        logger.info(f"Temp files: {temp_files}")

        # Pick the Gemini model matching the requested model name.
        model = map_model_name(request.model)
        logger.info(f"Using model: {model}")

        logger.info("Sending request to Gemini...")
        try:
            if temp_files:
                # With image attachments.
                response = await gemini_client.generate_content(conversation, files=temp_files, model=model)
            else:
                # Text only.
                response = await gemini_client.generate_content(conversation, model=model)
        finally:
            # BUGFIX: clean up in a finally block so temp image files are
            # removed even when generate_content raises; previously a failed
            # request leaked the files in the temp directory.
            for temp_file in temp_files:
                try:
                    os.unlink(temp_file)
                except Exception as e:
                    logger.warning(f"Failed to delete temp file {temp_file}: {str(e)}")

        # Extract the reply text from the Gemini response object.
        reply_text = response.text if hasattr(response, "text") else str(response)
        logger.info(f"Response: {reply_text}")

        if not reply_text or reply_text.strip() == "":
            logger.warning("Empty response received from Gemini")
            reply_text = "服务器返回了空响应。请检查 Gemini API 凭据是否有效。"

        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created_time = int(time.time())

        if request.stream:
            # Stream the reply back as OpenAI-style SSE chunks.
            async def generate_stream():
                # Opening chunk: announces the assistant role.
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
                }
                yield f"data: {json.dumps(data)}\n\n"

                # The Gemini reply arrives whole; emit it character by
                # character to simulate incremental streaming.
                for char in reply_text:
                    data = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created_time,
                        "model": request.model,
                        "choices": [{"index": 0, "delta": {"content": char}, "finish_reason": None}],
                    }
                    yield f"data: {json.dumps(data)}\n\n"
                    # Small delay so clients perceive a live stream.
                    await asyncio.sleep(0.01)

                # Closing chunk followed by the OpenAI [DONE] sentinel.
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(data)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate_stream(), media_type="text/event-stream")
        else:
            # Non-streaming: one JSON body. "Usage" is approximated with
            # whitespace-split word counts (hoisted so each is computed once).
            prompt_tokens = len(conversation.split())
            completion_tokens = len(reply_text.split())
            result = {
                "id": completion_id,
                "object": "chat.completion",
                "created": created_time,
                "model": request.model,
                "choices": [{"index": 0, "message": {"role": "assistant", "content": reply_text}, "finish_reason": "stop"}],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
            }

            logger.info(f"Returning response: {result}")
            return result

    except Exception as e:
        logger.error(f"Error generating completion: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
 
358
@app.get("/")
async def root():
    """Health-check endpoint confirming the server is up."""
    status = {"status": "online", "message": "Gemini API FastAPI Server is running"}
    return status
361
 
362
 
363
if __name__ == "__main__":
    import uvicorn

    # Bind to 0.0.0.0 so the server is reachable from outside a container.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info")