import asyncio
import base64
import json
import logging
import os
import tempfile
import time
import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional, Union

from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from gemini_webapi import GeminiClient, set_log_level
from gemini_webapi.constants import Model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
set_log_level("INFO")

app = FastAPI(title="Gemini API FastAPI Server")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Shared Gemini client, lazily initialized on first use.
gemini_client = None

SECURE_1PSID = os.environ.get("SECURE_1PSID", "")
SECURE_1PSIDTS = os.environ.get("SECURE_1PSIDTS", "")
API_KEY = os.environ.get("API_KEY", "")

if not SECURE_1PSID or not SECURE_1PSIDTS:
    logger.warning("⚠️ Gemini API credentials are not set or empty! Please check your environment variables.")
    logger.warning("Make sure SECURE_1PSID and SECURE_1PSIDTS are correctly set in your .env file or environment.")
    logger.warning("If using Docker, ensure the .env file is correctly mounted and formatted.")
    logger.warning("Example format in .env file (no quotes):")
    logger.warning("SECURE_1PSID=your_secure_1psid_value_here")
    logger.warning("SECURE_1PSIDTS=your_secure_1psidts_value_here")
else:
    # Log only a short prefix so full credentials never land in the logs.
    logger.info(f"Credentials found. SECURE_1PSID starts with: {SECURE_1PSID[:5]}...")
    logger.info(f"Credentials found. SECURE_1PSIDTS starts with: {SECURE_1PSIDTS[:5]}...")

if not API_KEY:
    logger.warning("⚠️ API_KEY is not set or empty! API authentication will not work.")
    logger.warning("Make sure API_KEY is correctly set in your .env file or environment.")
else:
    logger.info(f"API_KEY found. API_KEY starts with: {API_KEY[:5]}...")


# Pydantic models mirroring the OpenAI chat completion schema.
class ContentItem(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[Dict[str, str]] = None


class Message(BaseModel):
    role: str
    content: Union[str, List[ContentItem]]
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    user: Optional[str] = None


class Choice(BaseModel):
    index: int
    message: Message
    finish_reason: str


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Choice]
    usage: Usage


class ModelData(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "google"


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelData]


async def verify_api_key(authorization: str = Header(None)):
    """Check the Bearer token in the Authorization header against API_KEY."""
    if not API_KEY:
        # No key configured: authentication is effectively disabled.
        logger.warning("API key validation skipped - no API_KEY set in environment")
        return

    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    try:
        scheme, token = authorization.split()
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid authorization format. Use 'Bearer YOUR_API_KEY'")

    if scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid authentication scheme. Use Bearer token")
    if token != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")

    return token


@app.middleware("http")
async def error_handling(request: Request, call_next):
    try:
        return await call_next(request)
    except Exception as e:
        logger.error(f"Request failed: {str(e)}")
        return JSONResponse(status_code=500, content={"error": {"message": str(e), "type": "internal_server_error"}})


@app.get("/v1/models")
async def list_models():
    """Return the models declared by gemini_webapi in OpenAI list format."""
    now = int(datetime.now(tz=timezone.utc).timestamp())
    data = [
        {
            "id": m.model_name,
            "object": "model",
            "created": now,
            "owned_by": "google-gemini-web",
        }
        for m in Model
    ]
    return {"object": "list", "data": data}


def map_model_name(openai_model_name: str) -> Model:
    """Find the Model enum value matching the given model name string."""
    all_models = [m.model_name if hasattr(m, "model_name") else str(m) for m in Model]
    logger.info(f"Available models: {all_models}")

    # First try a direct substring match against the enum's model names.
    for m in Model:
        model_name = m.model_name if hasattr(m, "model_name") else str(m)
        if openai_model_name.lower() in model_name.lower():
            return m

    # Otherwise fall back to keyword matching for common OpenAI-style aliases.
    model_keywords = {
        "gemini-pro": ["pro", "2.0"],
        "gemini-pro-vision": ["vision", "pro"],
        "gemini-flash": ["flash", "2.0"],
        "gemini-1.5-pro": ["1.5", "pro"],
        "gemini-1.5-flash": ["1.5", "flash"],
    }

    keywords = model_keywords.get(openai_model_name, ["pro"])

    for m in Model:
        model_name = m.model_name if hasattr(m, "model_name") else str(m)
        if all(kw.lower() in model_name.lower() for kw in keywords):
            return m

    # Last resort: the first model defined in the enum.
    return next(iter(Model))
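
# For example, map_model_name("gemini-1.5-pro") returns the first enum member
# whose model_name contains that string; the exact names available depend on
# the installed gemini_webapi version. Unknown names fall back to the keyword
# table above, then to the first enum member.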


def prepare_conversation(messages: List[Message]) -> tuple:
    """Flatten an OpenAI-style message list into a single prompt string.

    Base64 data-URL images are decoded into temporary files so they can be
    uploaded with the prompt; the file paths are returned for later cleanup.
    """
    conversation = ""
    temp_files = []

    for msg in messages:
        if isinstance(msg.content, str):
            if msg.role == "system":
                conversation += f"System: {msg.content}\n\n"
            elif msg.role == "user":
                conversation += f"Human: {msg.content}\n\n"
            elif msg.role == "assistant":
                conversation += f"Assistant: {msg.content}\n\n"
        else:
            # Multimodal content: a list of text and image_url items.
            if msg.role == "user":
                conversation += "Human: "
            elif msg.role == "system":
                conversation += "System: "
            elif msg.role == "assistant":
                conversation += "Assistant: "

            for item in msg.content:
                if item.type == "text":
                    conversation += item.text or ""
                elif item.type == "image_url" and item.image_url:
                    image_url = item.image_url.get("url", "")
                    if image_url.startswith("data:image/"):
                        try:
                            # Decode the base64 payload into a temporary PNG file.
                            base64_data = image_url.split(",")[1]
                            image_data = base64.b64decode(base64_data)
                            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
                                tmp.write(image_data)
                                temp_files.append(tmp.name)
                        except Exception as e:
                            logger.error(f"Error processing base64 image: {str(e)}")

            conversation += "\n\n"

    # Cue the model to answer as the assistant.
    conversation += "Assistant: "

    return conversation, temp_files
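
# For instance, [system "You are terse.", user "Hi"] flattens to:
#   "System: You are terse.\n\nHuman: Hi\n\nAssistant: "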


async def get_gemini_client():
    """Return the shared GeminiClient, initializing it on first use."""
    global gemini_client
    if gemini_client is None:
        try:
            gemini_client = GeminiClient(SECURE_1PSID, SECURE_1PSIDTS)
            await gemini_client.init(timeout=300)
        except Exception as e:
            logger.error(f"Failed to initialize Gemini client: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Failed to initialize Gemini client: {str(e)}")
    return gemini_client
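
# The cookie-based login in init() can be slow (hence the 300 s timeout), so it
# runs once here and the client is shared across requests.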


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest, api_key: str = Depends(verify_api_key)):
    try:
        # Reuse the shared, lazily initialized client.
        client = await get_gemini_client()

        conversation, temp_files = prepare_conversation(request.messages)
        logger.info(f"Prepared conversation: {conversation}")
        logger.info(f"Temp files: {temp_files}")

        model = map_model_name(request.model)
        logger.info(f"Using model: {model}")

        logger.info("Sending request to Gemini...")
        try:
            if temp_files:
                response = await client.generate_content(conversation, files=temp_files, model=model)
            else:
                response = await client.generate_content(conversation, model=model)
        finally:
            # Always remove temp image files, even if the request failed.
            for temp_file in temp_files:
                try:
                    os.unlink(temp_file)
                except Exception as e:
                    logger.warning(f"Failed to delete temp file {temp_file}: {str(e)}")

        reply_text = response.text if hasattr(response, "text") else str(response)
        logger.info(f"Response: {reply_text}")

        if not reply_text or reply_text.strip() == "":
            logger.warning("Empty response received from Gemini")
            reply_text = "The server returned an empty response. Please check that your Gemini API credentials are valid."

        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created_time = int(time.time())

        if request.stream:

            async def generate_stream():
                # First chunk announces the assistant role, per the OpenAI streaming format.
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
                }
                yield f"data: {json.dumps(data)}\n\n"

                # The full reply is already available; emit it character by
                # character with a small delay to simulate streaming.
                for char in reply_text:
                    data = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created_time,
                        "model": request.model,
                        "choices": [{"index": 0, "delta": {"content": char}, "finish_reason": None}],
                    }
                    yield f"data: {json.dumps(data)}\n\n"
                    await asyncio.sleep(0.01)

                # Final chunk carries the stop reason, followed by the [DONE] sentinel.
                data = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created_time,
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                }
                yield f"data: {json.dumps(data)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate_stream(), media_type="text/event-stream")
        else:
            result = {
                "id": completion_id,
                "object": "chat.completion",
                "created": created_time,
                "model": request.model,
                "choices": [{"index": 0, "message": {"role": "assistant", "content": reply_text}, "finish_reason": "stop"}],
                # Rough usage numbers: whitespace word counts, not tokenizer tokens.
                "usage": {
                    "prompt_tokens": len(conversation.split()),
                    "completion_tokens": len(reply_text.split()),
                    "total_tokens": len(conversation.split()) + len(reply_text.split()),
                },
            }

            logger.info(f"Returning response: {result}")
            return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error generating completion: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Error generating completion: {str(e)}")


@app.get("/")
async def root():
    return {"status": "online", "message": "Gemini API FastAPI Server is running"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, log_level="info")
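
# Launch directly with `python main.py`, or (assuming this file is saved as
# main.py) with: uvicorn main:app --host 0.0.0.0 --port 8000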