# gpt-text-api / app.py
# Uploaded by fiewolf1000 — "Update app.py" (commit 75911cc, verified)
import logging
import os
import secrets
from typing import List, Optional, Union

import numpy as np
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Logging configuration: INFO level, mirrored to a log file and to stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
    handlers=[
        logging.FileHandler("embedding_service.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("embedding_service")
app = FastAPI()
# Allow cross-origin requests (wide open: any origin, method, and header).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model mapping: OpenAI model name -> open-source model name.
# The bge-* entries map to themselves so clients may use either alias.
MODEL_MAPPING = {
    "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
    "text-embedding-3-large": "BAAI/bge-large-en-v1.5",
    "bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
    "bge-large-en-v1.5": "BAAI/bge-large-en-v1.5"
}
# Cache of loaded models (lazy-loaded by get_model), keyed by requested name.
models = {}
def get_model(model_name: str):
    """Resolve *model_name* and return a lazily loaded, cached SentenceTransformer.

    Names that are neither a known alias nor prefixed by a trusted
    organization are rejected with HTTP 400; load failures of otherwise
    valid names surface as HTTP 500. Loaded models are cached in the
    module-level ``models`` dict under the requested (unresolved) name.
    """
    logger.info(f"尝试获取模型: {model_name}")
    # Known aliases cover both the OpenAI-style names and the direct bge-* names.
    known_aliases = set(MODEL_MAPPING)
    resolved = MODEL_MAPPING.get(model_name, model_name)
    # Early rejection: unknown alias AND not from a whitelisted organization.
    trusted_prefixes = ("BAAI/", "sentence-transformers/")
    if model_name not in known_aliases and not resolved.startswith(trusted_prefixes):
        error_msg = f"不支持的模型: {model_name}"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    # Serve straight from the cache when this name was loaded before.
    if model_name in models:
        return models[model_name]
    try:
        hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
        models[model_name] = SentenceTransformer(
            resolved,
            use_auth_token=hf_token
        )
        logger.info(f"模型 {model_name} 加载成功")
    except Exception as e:
        # Valid names can still fail to load (network, disk, auth); report as 500.
        error_msg = f"加载模型 {model_name} 失败: {str(e)}"
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)
    return models[model_name]
# 验证API密钥
def verify_api_key(authorization: Optional[str] = Header(None)):
    """FastAPI dependency that validates the Bearer token against the API_KEY env var.

    Returns:
        True when the presented key matches.

    Raises:
        HTTPException(401): when the Authorization header is missing or
        malformed, the server-side API_KEY is unset, or the key mismatches.
    """
    # SECURITY: do not log the raw Authorization header — it contains the secret key.
    logger.info("收到Authorization头部,开始验证API密钥")
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="未提供有效的API密钥")
    api_key = authorization[len("Bearer "):]
    expected = os.getenv("API_KEY")
    # Constant-time comparison avoids timing side channels; an unset API_KEY
    # rejects every request (same observable behavior as the plain != check).
    if expected is None or not secrets.compare_digest(api_key, expected):
        raise HTTPException(status_code=401, detail="无效的API密钥")
    logger.info("API密钥验证通过")
    return True
# Request body model (OpenAI /v1/embeddings compatible).
class EmbeddingRequest(BaseModel):
    """Payload accepted by the embeddings endpoint."""
    input: Union[str, List[str]]  # a single text or a batch of texts
    model: str  # requested model name; resolved via MODEL_MAPPING
    encoding_format: Optional[str] = "float"  # accepted but not acted upon here
# Response body models (OpenAI-compatible shapes).
class EmbeddingData(BaseModel):
    """One embedding vector plus its position in the input batch."""
    object: str = "embedding"
    embedding: List[float]
    index: int  # position of the source text in the request's input list
class EmbeddingResponse(BaseModel):
    """Top-level OpenAI-style embeddings response."""
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    # Token counts are whitespace-split estimates, not real tokenizer counts.
    # (Pydantic deep-copies mutable defaults, so this dict is safe per-instance.)
    usage: dict = {"prompt_tokens": 0, "total_tokens": 0}
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embedding(
    request: Request,
    req: EmbeddingRequest,
    _: bool = Depends(verify_api_key)
):
    """OpenAI-compatible embeddings endpoint.

    Normalizes the input to a list, encodes it with the resolved model
    (L2-normalized vectors), and returns OpenAI-shaped response data with
    whitespace-based token estimates.

    Raises:
        HTTPException(400): unsupported model (propagated from get_model).
        HTTPException(500): model load failure or any other processing error.
    """
    # Log the full inbound request for debugging.
    logger.info("\n===== 接收到的完整请求信息 =====")
    logger.info(f"请求方法: {request.method}")
    logger.info(f"请求URL: {request.url}")
    logger.info("请求头部:")
    for name, value in request.headers.items():
        logger.info(f" {name}: {value}")
    logger.info(f"请求体: {await request.body()}")
    logger.info("===============================\n")
    # Embedding generation.
    logger.info(f"收到嵌入请求,模型: {req.model}, 输入类型: {type(req.input)}")
    try:
        model = get_model(req.model)
        # Accept either a single string or a batch; always encode a list.
        inputs = [req.input] if isinstance(req.input, str) else req.input
        logger.info(f"处理输入,文本数量: {len(inputs)}")
        logger.info("开始计算嵌入")
        embeddings = model.encode(inputs, normalize_embeddings=True)
        logger.info(f"嵌入计算完成,嵌入形状: {embeddings.shape}")
        data = [
            EmbeddingData(embedding=embedding.tolist(), index=i)
            for i, embedding in enumerate(embeddings)
        ]
        # Rough token estimate: whitespace word count (no real tokenizer here).
        prompt_tokens = sum(len(text.split()) for text in inputs)
        logger.info(f"估算token数: {prompt_tokens}")
        return EmbeddingResponse(
            data=data,
            model=req.model,
            usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
        )
    except HTTPException:
        # BUGFIX: let get_model's 400/500 propagate unchanged instead of
        # being re-wrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        # Log with traceback before mapping unexpected failures to 500.
        logger.exception("处理嵌入请求时发生错误")
        raise HTTPException(status_code=500, detail=f"处理嵌入请求时发生错误: {str(e)}")
@app.get("/health")
async def health_check(request: Request):
    """Liveness probe: logs the incoming request and lists advertised models.

    The model list is the union of all mapping aliases and whatever has
    already been loaded into the runtime cache.
    """
    logger.info("\n===== 健康检查请求信息 =====")
    logger.info(f"请求方法: {request.method}")
    logger.info(f"请求URL: {request.url}")
    logger.info("请求头部:")
    for header_name, header_value in request.headers.items():
        logger.info(f" {header_name}: {header_value}")
    logger.info("===============================\n")
    advertised = list(MODEL_MAPPING.keys()) + list(models.keys())
    return {"status": "healthy", "models": advertised}
# Script entry point: run the ASGI app directly.
if __name__ == "__main__":
    import uvicorn
    logger.info("启动服务")
    # Port 7860 is the Hugging Face Spaces default; bind all interfaces.
    uvicorn.run(app, host="0.0.0.0", port=7860)