from fastapi import FastAPI, HTTPException, Depends, Request, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Optional, Union
import logging

# Logging: mirror everything to a file and to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
    handlers=[
        logging.FileHandler("embedding_service.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("embedding_service")

app = FastAPI()

# Allow cross-origin requests from any origin (wide-open CORS; this service
# is expected to sit behind its own API-key check, see verify_api_key).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model mapping: OpenAI-style model name -> open-source model name.
MODEL_MAPPING = {
    "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
    "text-embedding-3-large": "BAAI/bge-large-en-v1.5",
    "bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
    "bge-large-en-v1.5": "BAAI/bge-large-en-v1.5"
}

# Lazily populated model cache, keyed by the *resolved* model name so that
# aliases pointing at the same weights (e.g. "text-embedding-3-small" and
# "bge-small-en-v1.5") share a single loaded instance instead of loading twice.
models = {}


def get_model(model_name: str) -> SentenceTransformer:
    """Resolve ``model_name`` and return a (cached) SentenceTransformer.

    Raises:
        HTTPException(400): the model is neither in MODEL_MAPPING nor from a
            known organization prefix.
        HTTPException(500): a supported model failed to load (e.g. network
            or Hugging Face Hub error).
    """
    logger.info(f"尝试获取模型: {model_name}")
    # 1. All directly supported names (mapped aliases + native bge-* names).
    supported_models = set(MODEL_MAPPING.keys())
    model_to_load = MODEL_MAPPING.get(model_name, model_name)
    # 2. Reject unknown models early with 400: not in the mapping and not
    #    published under a trusted organization prefix.
    known_prefixes = ("BAAI/", "sentence-transformers/")
    if (model_name not in supported_models) and (not model_to_load.startswith(known_prefixes)):
        error_msg = f"不支持的模型: {model_name}"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    # 3. Load on first use; cache under the resolved name (fix: the original
    #    keyed the cache by the requested alias, loading duplicate copies of
    #    the same weights for different aliases).
    if model_to_load not in models:
        try:
            hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
            models[model_to_load] = SentenceTransformer(
                model_to_load,
                # NOTE(review): `use_auth_token` is deprecated in newer
                # sentence-transformers releases in favor of `token` — kept
                # for compatibility with the pinned version; confirm before
                # upgrading the dependency.
                use_auth_token=hf_token
            )
            logger.info(f"模型 {model_name} 加载成功")
        except Exception as e:
            # Invalid names were rejected above, so a failure here means a
            # legitimate model could not be fetched/initialized -> 500.
            error_msg = f"加载模型 {model_name} 失败: {str(e)}"
            logger.error(error_msg)
            raise HTTPException(status_code=500, detail=error_msg)
    return models[model_to_load]


def verify_api_key(authorization: Optional[str] = Header(None)) -> bool:
    """Validate the ``Authorization: Bearer <key>`` header against $API_KEY.

    Security fix: the raw Authorization header (which contains the client's
    API key in plain text) is no longer written to the log file.
    """
    logger.info(f"Authorization头部存在: {authorization is not None}")
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="未提供有效的API密钥")
    api_key = authorization[len("Bearer "):]
    # If API_KEY is unset, os.getenv returns None and every key is rejected.
    if api_key != os.getenv("API_KEY"):
        raise HTTPException(status_code=401, detail="无效的API密钥")
    logger.info("API密钥验证通过")
    return True


class EmbeddingRequest(BaseModel):
    """OpenAI-compatible embeddings request body."""
    input: Union[str, List[str]]            # single text or batch of texts
    model: str                              # OpenAI alias or bge-* name
    encoding_format: Optional[str] = "float"  # accepted but not acted upon


class EmbeddingData(BaseModel):
    """One embedding vector in the response list."""
    object: str = "embedding"
    embedding: List[float]
    index: int


class EmbeddingResponse(BaseModel):
    """OpenAI-compatible embeddings response envelope."""
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: dict = {"prompt_tokens": 0, "total_tokens": 0}


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embedding(
    request: Request,
    req: EmbeddingRequest,
    _: bool = Depends(verify_api_key)
):
    """Compute normalized embeddings for the request's input text(s)."""
    # Debug dump of the incoming request. The Authorization header is
    # redacted so the API key never reaches the log file.
    logger.info("\n===== 接收到的完整请求信息 =====")
    logger.info(f"请求方法: {request.method}")
    logger.info(f"请求URL: {request.url}")
    logger.info("请求头部:")
    for name, value in request.headers.items():
        if name.lower() == "authorization":
            value = "***"
        logger.info(f"  {name}: {value}")
    logger.info(f"请求体: {await request.body()}")
    logger.info("===============================\n")

    logger.info(f"收到嵌入请求,模型: {req.model}, 输入类型: {type(req.input)}")
    try:
        model = get_model(req.model)
        # Normalize to a batch: a single string becomes a one-element list.
        inputs = [req.input] if isinstance(req.input, str) else req.input
        logger.info(f"处理输入,文本数量: {len(inputs)}")
        logger.info("开始计算嵌入")
        embeddings = model.encode(inputs, normalize_embeddings=True)
        logger.info(f"嵌入计算完成,嵌入形状: {embeddings.shape}")
        data = [
            EmbeddingData(embedding=embedding.tolist(), index=i)
            for i, embedding in enumerate(embeddings)
        ]
        # Rough token estimate by whitespace split (no real tokenizer here).
        prompt_tokens = sum(len(text.split()) for text in inputs)
        logger.info(f"估算token数: {prompt_tokens}")
        return EmbeddingResponse(
            data=data,
            model=req.model,
            usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
        )
    except HTTPException:
        # Fix: re-raise HTTPExceptions (e.g. 400 "unsupported model" from
        # get_model) unchanged — the original broad handler masked them
        # as 500 Internal Server Error.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"处理嵌入请求时发生错误: {str(e)}")


@app.get("/health")
async def health_check(request: Request):
    """Liveness probe; reports supported model names and loaded models."""
    logger.info("\n===== 健康检查请求信息 =====")
    logger.info(f"请求方法: {request.method}")
    logger.info(f"请求URL: {request.url}")
    logger.info("请求头部:")
    for name, value in request.headers.items():
        logger.info(f"  {name}: {value}")
    logger.info("===============================\n")
    return {"status": "healthy", "models": list(MODEL_MAPPING.keys()) + list(models.keys())}


if __name__ == "__main__":
    import uvicorn
    logger.info("启动服务")
    uvicorn.run(app, host="0.0.0.0", port=7860)