|
|
import asyncio |
|
|
import json |
|
|
from datetime import datetime, timezone |
|
|
import os |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from fastapi.responses import JSONResponse |
|
|
from fastapi.responses import StreamingResponse |
|
|
from pydantic import BaseModel |
|
|
from typing import List, Optional |
|
|
import time |
|
|
import uuid |
|
|
import logging |
|
|
|
|
|
from gemini_webapi import GeminiClient, set_log_level |
|
|
from gemini_webapi.constants import Model |
|
|
|
|
|
|
|
|
# Configure root logging for the server process.
logging.basicConfig(level=logging.INFO)

# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Align gemini_webapi's internal log verbosity with ours.
set_log_level("INFO")
|
|
|
|
|
app = FastAPI(title="Gemini API FastAPI Server")

# Allow any origin/method/header so browser-based OpenAI clients can call us.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# not honored by browsers per the CORS spec (a wildcard origin cannot be used
# with credentialed requests) — if credentialed cross-site calls are actually
# needed, list explicit origins instead; confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# Lazily-initialized GeminiClient singleton (see get_gemini_client()).
gemini_client = None

# Gemini web-session cookies, supplied via environment variables.
SECURE_1PSID = os.environ.get("SECURE_1PSID", "")
SECURE_1PSIDTS = os.environ.get("SECURE_1PSIDTS", "")

if not SECURE_1PSID or not SECURE_1PSIDTS:
    logger.warning("⚠️ Gemini API credentials are not set or empty! Please check your environment variables.")
else:
    # Security fix: the previous code logged the first 5 characters of each
    # secret cookie. Log only presence and length — never secret material —
    # so credentials cannot leak through log aggregation.
    logger.info(f"Credentials found. SECURE_1PSID length: {len(SECURE_1PSID)}")
    logger.info(f"Credentials found. SECURE_1PSIDTS length: {len(SECURE_1PSIDTS)}")
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """A single chat message in OpenAI chat-completions format."""

    role: str  # "system", "user", or "assistant"; other roles are ignored by prepare_conversation()
    content: str  # message text
    name: Optional[str] = None  # optional participant name (accepted for API compatibility; unused here)
|
|
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-compatible).

    Only ``model``, ``messages``, and ``stream`` affect behavior in this
    server; the remaining sampling fields are accepted for wire
    compatibility but are not forwarded to Gemini.
    """

    model: str  # OpenAI-style model name, resolved via map_model_name()
    messages: List[Message]  # conversation history, oldest first
    temperature: Optional[float] = 0.7  # accepted, not forwarded
    top_p: Optional[float] = 1.0  # accepted, not forwarded
    n: Optional[int] = 1  # accepted, not forwarded
    stream: Optional[bool] = False  # True -> SSE streaming response
    max_tokens: Optional[int] = None  # accepted, not forwarded
    presence_penalty: Optional[float] = 0  # accepted, not forwarded
    frequency_penalty: Optional[float] = 0  # accepted, not forwarded
    user: Optional[str] = None  # accepted, not forwarded
|
|
|
|
|
|
|
|
class Choice(BaseModel):
    """One completion choice in an OpenAI-style response.

    Declared for schema documentation; the endpoints below currently
    build and return plain dicts instead of this model.
    """

    index: int  # position of this choice in the choices array
    message: Message  # the assistant's reply
    finish_reason: str  # e.g. "stop"
|
|
|
|
|
|
|
|
class Usage(BaseModel):
    """Token accounting in OpenAI format.

    Declared for schema documentation; the endpoints below currently
    build and return plain dicts instead of this model.
    """

    prompt_tokens: int  # approximated by whitespace word count in this server
    completion_tokens: int  # approximated by whitespace word count in this server
    total_tokens: int  # prompt_tokens + completion_tokens
|
|
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Full OpenAI-style chat completion response.

    Declared for schema documentation; the endpoints below currently
    build and return plain dicts instead of this model.
    """

    id: str  # "chatcmpl-<uuid4>"
    object: str = "chat.completion"
    created: int  # Unix timestamp
    model: str  # echoed request model name
    choices: List[Choice]
    usage: Usage
|
|
|
|
|
|
|
|
class ModelData(BaseModel):
    """A single entry in the OpenAI /v1/models listing.

    Declared for schema documentation; list_models() currently returns
    plain dicts instead of this model.
    """

    id: str  # model identifier string
    object: str = "model"
    created: int  # Unix timestamp
    owned_by: str = "google"
|
|
|
|
|
|
|
|
class ModelList(BaseModel):
    """OpenAI-style model list envelope.

    Declared for schema documentation; list_models() currently returns
    a plain dict instead of this model.
    """

    object: str = "list"
    data: List[ModelData]
|
|
|
|
|
|
|
|
|
|
|
@app.middleware("http")
async def error_handling(request: Request, call_next):
    """Catch-all middleware: convert any unhandled exception into an
    OpenAI-style JSON error payload with HTTP 500.

    Exceptions already turned into responses by inner handlers (e.g.
    HTTPException) never reach this except clause.
    """
    try:
        return await call_next(request)
    except Exception as e:
        # logger.exception records the full traceback; the previous
        # logger.error(f"...") discarded it, making 500s hard to debug.
        # Lazy %-args also avoid f-string formatting on the log path.
        logger.exception("Request failed: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": {"message": str(e), "type": "internal_server_error"}},
        )
|
|
|
|
|
|
|
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """Return the model list declared by gemini_webapi, in OpenAI /v1/models format."""
    # One shared timestamp for every entry in this response.
    now = int(datetime.now(tz=timezone.utc).timestamp())
    data = [
        {
            "id": m.model_name,
            "object": "model",
            "created": now,
            "owned_by": "google-gemini-web"
        }
        for m in Model
    ]
    # Fix: stray debug print(data) replaced with structured logging at
    # debug level so stdout stays clean in production.
    logger.debug("Models listed: %s", data)
    return {"object": "list", "data": data}
|
|
|
|
|
|
|
|
|
|
|
def map_model_name(openai_model_name: str) -> Model:
    """Resolve an OpenAI-style model name to a Model enum member.

    Resolution order: direct substring match against the advertised model
    names, then a keyword-based fallback for well-known aliases, then the
    first declared model as a last resort.
    """

    def display(member) -> str:
        # Some enum members expose .model_name; fall back to str() otherwise.
        return member.model_name if hasattr(member, "model_name") else str(member)

    names = [display(m) for m in Model]
    logger.info(f"Available models: {names}")

    wanted = openai_model_name.lower()

    # 1) Direct substring match.
    direct = next((m for m in Model if wanted in display(m).lower()), None)
    if direct is not None:
        return direct

    # 2) Keyword fallback: every keyword must appear in the model's name.
    model_keywords = {
        "gemini-pro": ["pro", "2.0"],
        "gemini-pro-vision": ["vision", "pro"],
        "gemini-1.5-pro": ["1.5", "pro"],
        "gemini-flash": ["flash", "2.0"],
        "gemini-1.5-flash": ["1.5", "flash"],
    }
    keywords = model_keywords.get(openai_model_name, ["pro"])

    keyword_match = next(
        (
            m
            for m in Model
            if all(kw.lower() in display(m).lower() for kw in keywords)
        ),
        None,
    )
    if keyword_match is not None:
        return keyword_match

    # 3) Nothing matched — default to the first declared model.
    return next(iter(Model))
|
|
|
|
|
|
|
|
|
|
|
def prepare_conversation(messages: List[Message]) -> str:
    """Flatten an OpenAI-style message list into a single prompt string.

    Each recognized role is rendered as "<Label>: <content>\n\n"; messages
    with any other role are silently skipped. The prompt always ends with
    a trailing "Assistant: " cue for the model to complete.
    """
    # Map wire-format roles to the labels used in the flattened prompt.
    role_labels = {
        "system": "System",
        "user": "Human",
        "assistant": "Assistant",
    }

    parts = []
    for msg in messages:
        label = role_labels.get(msg.role)
        if label is not None:
            parts.append(f"{label}: {msg.content}\n\n")

    parts.append("Assistant: ")
    return "".join(parts)
|
|
|
|
|
|
|
|
|
|
|
async def get_gemini_client():
    """Return the shared GeminiClient, initializing it on first use.

    Returns:
        The module-level GeminiClient singleton, fully initialized.

    Raises:
        HTTPException: 500 if the client cannot be constructed or initialized.
    """
    global gemini_client
    if gemini_client is not None:
        return gemini_client
    try:
        client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
        await client.init(timeout=30)
    except Exception as e:
        logger.error(f"Failed to initialize Gemini client: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize Gemini client: {str(e)}"
        )
    # Bug fix: publish the client only AFTER init() succeeds. The previous
    # code assigned the global before init, so a failed init left a broken,
    # uninitialized client cached — every later call returned it without
    # ever retrying initialization.
    gemini_client = client
    return gemini_client
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint backed by the Gemini web API.

    Supports both a plain JSON response and SSE streaming (request.stream).
    Sampling parameters (temperature, top_p, ...) are accepted for wire
    compatibility but are not forwarded to Gemini.

    Raises:
        HTTPException: 500 on client-init failure or any generation error.
    """
    try:
        # Consistency fix: reuse the lazy singleton helper instead of
        # duplicating the GeminiClient construction/init logic inline.
        client = await get_gemini_client()

        # Flatten the OpenAI message list into a single prompt string.
        conversation = prepare_conversation(request.messages)
        logger.info(f"Prepared conversation: {conversation}")

        model = map_model_name(request.model)
        logger.info(f"Using model: {model}")

        logger.info("Sending request to Gemini...")
        response = await client.generate_content(conversation, model=model)

        # gemini_webapi responses normally expose .text; fall back to str().
        reply_text = response.text if hasattr(response, "text") else str(response)
        logger.info(f"Response: {reply_text}")

        if not reply_text or reply_text.strip() == "":
            logger.warning("Empty response received from Gemini")
            reply_text = "服务器返回了空响应。请检查 Gemini API 凭据是否有效。"

        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created_time = int(time.time())

        if request.stream:
            def build_chunk(delta: dict, finish_reason: Optional[str]) -> dict:
                # One chunk in OpenAI "chat.completion.chunk" wire format.
                return {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": request.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": delta,
                            "finish_reason": finish_reason
                        }
                    ]
                }

            async def generate_stream():
                # Role announcement, then one chunk per character, then stop.
                yield f"data: {json.dumps(build_chunk({'role': 'assistant'}, None))}\n\n"
                for char in reply_text:
                    yield f"data: {json.dumps(build_chunk({'content': char}, None))}\n\n"
                    # Small delay so clients render the stream progressively.
                    await asyncio.sleep(0.01)
                yield f"data: {json.dumps(build_chunk({}, 'stop'))}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream"
            )

        # Non-streaming JSON response. Token counts are whitespace-word
        # approximations — Gemini web does not report real token usage.
        prompt_tokens = len(conversation.split())
        completion_tokens = len(reply_text.split())
        result = {
            "id": completion_id,
            "object": "chat.completion",
            "created": created_time,
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": reply_text
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
        logger.info(f"Returning response: {result}")
        return result

    except HTTPException:
        # Fix: let deliberate HTTP errors (e.g. from get_gemini_client)
        # propagate with their own status/detail instead of being re-wrapped
        # as a generic 500 below.
        raise
    except Exception as e:
        logger.error(f"Error generating completion: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error generating completion: {str(e)}"
        )
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Health-check endpoint confirming the server is up."""
    return {
        "status": "online",
        "message": "Gemini API FastAPI Server is running",
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily so merely importing this module never pulls in uvicorn.
    import uvicorn

    # Serve on all interfaces for local/container use.
    uvicorn.run("main:app", port=7890, host="0.0.0.0", log_level="info")