# any-api/routers/v1.py
# Author: tanbushi · commit 8ec3cc4 ("111")
import json
import logging
import random
import string
import time
from typing import List, Optional

import requests
from fastapi import APIRouter, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
# Router for the v1 API. The endpoints in this module are registered on it;
# presumably it is included into the FastAPI app elsewhere (possibly with a
# path prefix) — confirm in the app setup.
router = APIRouter()
def generate_chatcmpl_id():
    """Build an OpenAI-style chat completion identifier.

    The id is the literal prefix ``chatcmpl-`` followed by 29 characters
    drawn from ASCII letters and digits. Characters come from the
    ``random`` module, so the id is unique-ish but not cryptographically
    unpredictable.
    """
    alphabet = string.ascii_letters + string.digits
    suffix_chars = [random.choice(alphabet) for _ in range(29)]
    return "chatcmpl-" + "".join(suffix_chars)
# Define data models for request and response
class ChatCompletionMessage(BaseModel):
    """A single chat message: its sender role and its text content.

    Used both for the messages of an incoming request and for the generated
    message inside each response choice.
    """

    # Sender of the message. The chat completions endpoint defaults it to
    # "assistant" for generated messages; presumably OpenAI-style roles such
    # as "system" / "user" / "assistant" — confirm with clients.
    role: str
    # Message text. Only a plain string is accepted.
    # NOTE(review): OpenAI-style multimodal content (a list of parts) or a
    # null content would fail validation here.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the chat completions endpoint (OpenAI-style subset).

    Only the fields below are modelled. With pydantic's default configuration,
    any other fields the client sends (e.g. ``stream``, ``top_p``) are ignored
    and therefore never forwarded upstream.
    """

    # Model name chosen by the client; it is echoed back as the response's ``model``.
    model: str
    # Conversation messages, forwarded upstream as given.
    messages: List[ChatCompletionMessage]
    # Optional generation limit; None when the client omits it.
    max_tokens: Optional[int] = None
    # Optional sampling temperature; None when the client omits it.
    temperature: Optional[float] = None
class ChatCompletionChoice(BaseModel):
    """One candidate completion inside a ChatCompletionResponse."""

    # Position of this choice in the response's ``choices`` list.
    index: int
    # The generated message.
    message: ChatCompletionMessage
    # Why generation stopped, as reported upstream; the endpoint defaults it to "stop".
    finish_reason: str
class ChatCompletionUsage(BaseModel):
    """Token accounting for one completion call.

    Each count is copied from the upstream response's ``usage`` object
    (0 when absent); nothing is computed locally.
    """

    # Tokens attributed to the prompt by the upstream API.
    prompt_tokens: int
    # Tokens attributed to the generated completion by the upstream API.
    completion_tokens: int
    # Total as reported upstream (not recomputed as prompt + completion).
    total_tokens: int
class ChatCompletionResponse(BaseModel):
    """Response body returned by the chat completions endpoint (OpenAI-style)."""

    # Locally generated id of the form "chatcmpl-" + 29 alphanumerics.
    id: str
    # Object type tag; the endpoint always sets "chat.completion".
    object: str
    # Creation time as an integer — presumably Unix seconds, as in the OpenAI API.
    created: int
    # Model name echoed from the request.
    model: str
    # Completion candidates, in upstream order.
    choices: List[ChatCompletionChoice]
    # Token usage reported for the call.
    usage: ChatCompletionUsage
logger = logging.getLogger(__name__)

# Gemini's OpenAI-compatible chat completions endpoint.
# TODO: verify this URL (the original code flagged it as needing verification).
GEMINI_CHAT_COMPLETIONS_URL = (
    "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"
)

# Seconds to wait for Gemini. Without a timeout, a stalled upstream connection
# would hang the request (and its worker thread) indefinitely.
UPSTREAM_TIMEOUT_SECONDS = 120

# Header names whose values must never be written to logs.
_SENSITIVE_HEADERS = frozenset({"authorization", "cookie", "x-goog-api-key"})


def _redact_headers(headers) -> dict:
    """Return a plain-dict copy of *headers* with credential values masked."""
    return {
        name: "***" if name.lower() in _SENSITIVE_HEADERS else value
        for name, value in headers.items()
    }


def _build_upstream_headers(request: Request) -> dict:
    """Return the headers for the Gemini call, forwarding the caller's Authorization."""
    headers = {"Content-Type": "application/json"}
    auth_header = request.headers.get("Authorization")
    if auth_header:
        headers["Authorization"] = auth_header
    return headers


def _to_chat_completion_response(gemini_response: dict, model: str) -> ChatCompletionResponse:
    """Convert Gemini's OpenAI-compatible JSON payload into a ChatCompletionResponse.

    Missing or null upstream fields fall back to defaults, so a sparse payload
    does not crash with AttributeError or fail model validation.

    Args:
        gemini_response: Decoded JSON body returned by Gemini.
        model: Model name from the client's request, echoed back unchanged.
    """
    # ``or {}`` / ``or []`` also cover keys that are present but null.
    usage = gemini_response.get("usage") or {}
    choices = []
    for index, choice in enumerate(gemini_response.get("choices") or []):
        message = choice.get("message") or {}
        content = message.get("content", "This is a dummy response from Gemini.")
        choices.append(
            ChatCompletionChoice(
                index=index,
                message=ChatCompletionMessage(
                    role=message.get("role") or "assistant",
                    # An explicit null content becomes "" rather than failing
                    # validation of the ``content: str`` field.
                    content="" if content is None else content,
                ),
                finish_reason=choice.get("finish_reason") or "stop",
            )
        )
    return ChatCompletionResponse(
        id=generate_chatcmpl_id(),
        object="chat.completion",
        # Fall back to the current time rather than a fixed, stale timestamp.
        created=gemini_response.get("created") or int(time.time()),
        model=model,
        choices=choices,
        usage=ChatCompletionUsage(
            prompt_tokens=usage.get("prompt_tokens") or 0,
            completion_tokens=usage.get("completion_tokens") or 0,
            total_tokens=usage.get("total_tokens") or 0,
        ),
    )


@router.post("/chat/completions")
async def chat_completions(chat_request: ChatCompletionRequest, request: Request):
    """Chat completions endpoint: proxy the request to the Gemini API.

    The validated request is forwarded to Gemini's OpenAI-compatible endpoint
    together with the caller's ``Authorization`` header. Gemini's reply is then
    re-wrapped as a ``ChatCompletionResponse``.

    Raises:
        HTTPException: carries Gemini's own status code and body when Gemini
            answers with an HTTP error (e.g. 401 for a bad key, 429 when rate
            limited). Uses 502 when Gemini cannot be reached, times out, or
            returns a body that is not JSON.
    """
    # Credentials are masked: request headers carry the caller's API key.
    logger.debug("Incoming request headers: %s", _redact_headers(request.headers))

    headers = _build_upstream_headers(request)
    # exclude_none: optional fields the client omitted (max_tokens,
    # temperature) are not sent upstream as explicit JSON nulls.
    payload = chat_request.model_dump(exclude_none=True)
    logger.debug("Upstream headers: %s", _redact_headers(headers))
    logger.debug("Payload sent to Gemini: %s", payload)

    try:
        # requests is synchronous. Calling it directly inside this coroutine
        # would block the event loop for the whole upstream round trip, so it
        # runs in a worker thread instead.
        response = await run_in_threadpool(
            requests.post,
            GEMINI_CHAT_COMPLETIONS_URL,
            headers=headers,
            json=payload,
            timeout=UPSTREAM_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        gemini_response = response.json()
    except requests.exceptions.HTTPError as e:
        # Surface Gemini's real status instead of masking it as a 500.
        logger.warning("Gemini returned an HTTP error: %s", e)
        upstream = e.response
        if upstream is None:
            raise HTTPException(status_code=502, detail=str(e)) from e
        raise HTTPException(status_code=upstream.status_code, detail=upstream.text) from e
    except (requests.exceptions.RequestException, ValueError) as e:
        # Network failure, timeout, or a non-JSON body from Gemini. ValueError
        # covers JSON decode errors on requests versions older than 2.27.
        logger.warning("Gemini request failed: %s", e)
        raise HTTPException(status_code=502, detail=str(e)) from e

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Gemini API response: %s",
            json.dumps(gemini_response, indent=4, ensure_ascii=False),
        )
    return _to_chat_completion_response(gemini_response, chat_request.model)
@router.get("/")
def greet_v1():
    """Return a static greeting that names this router's API version ("v1")."""
    version_tag = "v1"
    greeting = "Hello, World!"
    return {"version": version_tag, "message": greeting}