import asyncio
import base64
import json
import logging
import os
import tempfile
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from gemini_webapi import GeminiClient, set_log_level
from gemini_webapi.constants import Model

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
set_log_level("INFO")

app = FastAPI(title="Gemini API FastAPI Server")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global client
gemini_client = None

# Authentication credentials
SECURE_1PSID = os.environ.get("SECURE_1PSID", "")
SECURE_1PSIDTS = os.environ.get("SECURE_1PSIDTS", "")

# Print debug info at startup
if not SECURE_1PSID or not SECURE_1PSIDTS:
    logger.warning("⚠️ Gemini API credentials are not set or empty! Please check your environment variables.")
    logger.warning("Make sure SECURE_1PSID and SECURE_1PSIDTS are correctly set in your .env file or environment.")
    logger.warning("If using Docker, ensure the .env file is correctly mounted and formatted.")
    logger.warning("Example format in .env file (no quotes):")
    logger.warning("SECURE_1PSID=your_secure_1psid_value_here")
    logger.warning("SECURE_1PSIDTS=your_secure_1psidts_value_here")
else:
    # Only log the first few characters for security
    logger.info(f"Credentials found. SECURE_1PSID starts with: {SECURE_1PSID[:5]}...")
    logger.info(f"Credentials found. SECURE_1PSIDTS starts with: {SECURE_1PSIDTS[:5]}...")

# Pydantic models for API requests and responses
class ContentItem(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[Dict[str, str]] = None


class Message(BaseModel):
    role: str
    content: Union[str, List[ContentItem]]
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    user: Optional[str] = None


class Choice(BaseModel):
    index: int
    message: Message
    finish_reason: str


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Usage


class ModelData(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "google"


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelData]
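
# Illustrative request body matching ChatCompletionRequest (field values are
# examples only):
# {
#   "model": "gemini-2.0-flash",
#   "messages": [{"role": "user", "content": "Hello"}],
#   "stream": false
# }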

# Simple error handler middleware
@app.middleware("http")
async def error_handling(request: Request, call_next):
    try:
        return await call_next(request)
    except Exception as e:
        logger.error(f"Request failed: {str(e)}")
        return JSONResponse(status_code=500, content={"error": {"message": str(e), "type": "internal_server_error"}})

# Get list of available models (route path follows the OpenAI convention)
@app.get("/v1/models")
async def list_models():
    """Return the list of models declared by gemini_webapi."""
    now = int(datetime.now(tz=timezone.utc).timestamp())
    data = [
        {
            "id": m.model_name,  # e.g. "gemini-2.0-flash"
            "object": "model",
            "created": now,
            "owned_by": "google-gemini-web",
        }
        for m in Model
    ]
    logger.debug(f"Model list: {data}")
    return {"object": "list", "data": data}

# Helper to convert between Gemini and OpenAI model names
def map_model_name(openai_model_name: str) -> Model:
    """Find the Model enum value matching a model-name string."""
    # Log all available models to aid debugging
    all_models = [m.model_name if hasattr(m, "model_name") else str(m) for m in Model]
    logger.info(f"Available models: {all_models}")

    # First try a direct substring match on the model name
    for m in Model:
        model_name = m.model_name if hasattr(m, "model_name") else str(m)
        if openai_model_name.lower() in model_name.lower():
            return m

    # If no direct match is found, fall back to a default keyword mapping
    model_keywords = {
        "gemini-pro": ["pro", "2.0"],
        "gemini-pro-vision": ["vision", "pro"],
        "gemini-flash": ["flash", "2.0"],
        "gemini-1.5-pro": ["1.5", "pro"],
        "gemini-1.5-flash": ["1.5", "flash"],
    }

    # Match by keywords; default to the "pro" family
    keywords = model_keywords.get(openai_model_name, ["pro"])
    for m in Model:
        model_name = m.model_name if hasattr(m, "model_name") else str(m)
        if all(kw.lower() in model_name.lower() for kw in keywords):
            return m

    # As a last resort, return the first model in the enum
    return next(iter(Model))
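
# Illustrative behavior (actual results depend on the Model enum shipped with
# the installed gemini_webapi version):
#   map_model_name("gemini-1.5-flash")  -> first Model whose name contains "gemini-1.5-flash"
#   map_model_name("totally-unknown")   -> first Model whose name contains "pro"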

# Prepare conversation history from OpenAI messages format
def prepare_conversation(messages: List[Message]) -> tuple:
    conversation = ""
    temp_files = []
    for msg in messages:
        if isinstance(msg.content, str):
            # String content handling
            if msg.role == "system":
                conversation += f"System: {msg.content}\n\n"
            elif msg.role == "user":
                conversation += f"Human: {msg.content}\n\n"
            elif msg.role == "assistant":
                conversation += f"Assistant: {msg.content}\n\n"
        else:
            # Mixed content handling
            if msg.role == "user":
                conversation += "Human: "
            elif msg.role == "system":
                conversation += "System: "
            elif msg.role == "assistant":
                conversation += "Assistant: "
            for item in msg.content:
                if item.type == "text":
                    conversation += item.text or ""
                elif item.type == "image_url" and item.image_url:
                    # Handle image
                    image_url = item.image_url.get("url", "")
                    if image_url.startswith("data:image/"):
                        # Process base64 encoded image
                        try:
                            # Extract the base64 part
                            base64_data = image_url.split(",")[1]
                            image_data = base64.b64decode(base64_data)
                            # Create temporary file to hold the image
                            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                                tmp.write(image_data)
                                temp_files.append(tmp.name)
                        except Exception as e:
                            logger.error(f"Error processing base64 image: {str(e)}")
            conversation += "\n\n"

    # Add a final prompt for the assistant to respond to
    conversation += "Assistant: "
    return conversation, temp_files
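
# Illustrative output for messages [{"role": "system", "content": "Be brief."},
# {"role": "user", "content": "Hi"}]:
#   "System: Be brief.\n\nHuman: Hi\n\nAssistant: "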

# Dependency to get the initialized Gemini client
async def get_gemini_client():
    global gemini_client
    if gemini_client is None:
        try:
            gemini_client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await gemini_client.init(timeout=300)
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to initialize Gemini client: {str(e)}")
    return gemini_client
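
# Usage sketch (illustrative): this could be wired into an endpoint via
# FastAPI's dependency injection, e.g.
#   async def handler(client: GeminiClient = Depends(get_gemini_client)): ...
# (Depends would need to be imported from fastapi; the handler below manages
# the global client directly instead.)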

# Chat completions endpoint (route path follows the OpenAI convention)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    try:
        # Ensure the client is initialized
        global gemini_client
        if gemini_client is None:
            gemini_client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await gemini_client.init(timeout=300)
            logger.info("Gemini client initialized successfully")

        # Convert the messages into a flat conversation transcript
        conversation, temp_files = prepare_conversation(request.messages)
        logger.info(f"Prepared conversation: {conversation}")
        logger.info(f"Temp files: {temp_files}")

        # Pick the appropriate model
        model = map_model_name(request.model)
        logger.info(f"Using model: {model}")

        # Generate the response
        logger.info("Sending request to Gemini...")
        if temp_files:
            # With files
            response = await gemini_client.generate_content(conversation, files=temp_files, model=model)
        else:
            # Text only
            response = await gemini_client.generate_content(conversation, model=model)

        # Clean up temporary files
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except Exception as e:
                logger.warning(f"Failed to delete temp file {temp_file}: {str(e)}")

        # Extract the text reply
        if hasattr(response, "text"):
            reply_text = response.text
        else:
            reply_text = str(response)
        logger.info(f"Response: {reply_text}")
        if not reply_text or reply_text.strip() == "":
            logger.warning("Empty response received from Gemini")
            reply_text = "The server returned an empty response. Please check that your Gemini API credentials are valid."

        # Build the response object
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created_time = int(time.time())

        # Check whether the client requested a streaming response
        if request.stream:
            # Stream the reply as server-sent events (SSE)
            async def generate_stream():
                # Send the initial chunk carrying the assistant role
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
                }
                yield f"data: {json.dumps(data)}\n\n"

                # Simulate streaming: emit the text one character at a time
                for char in reply_text:
                    data = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created_time,
                        "model": request.model,
                        "choices": [{"index": 0, "delta": {"content": char}, "finish_reason": None}],
                    }
                    yield f"data: {json.dumps(data)}\n\n"
                    # Optional: small delay to mimic real streaming output
                    await asyncio.sleep(0.01)

                # Send the final chunk and the DONE sentinel
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(data)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate_stream(), media_type="text/event-stream")
        else:
            # Non-streaming response; token counts are a rough whitespace-based estimate
            result = {
                "id": completion_id,
                "object": "chat.completion",
                "created": created_time,
                "model": request.model,
                "choices": [{"index": 0, "message": {"role": "assistant", "content": reply_text}, "finish_reason": "stop"}],
                "usage": {
                    "prompt_tokens": len(conversation.split()),
                    "completion_tokens": len(reply_text.split()),
                    "total_tokens": len(conversation.split()) + len(reply_text.split()),
                },
            }
            logger.info(f"Returning response: {result}")
            return result
    except Exception as e:
        logger.error(f"Error generating completion: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}")
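
# Example request (illustrative; assumes the server is running on localhost:8000):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gemini-2.0-flash", "messages": [{"role": "user", "content": "Hello"}]}'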

@app.get("/")
async def root():
    return {"status": "online", "message": "Gemini API FastAPI Server is running"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info")