from fastapi import FastAPI, HTTPException, Depends, Request, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Optional, Union  # Union allows str or List[str] input
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
    handlers=[
        logging.FileHandler("embedding_service.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("embedding_service")

app = FastAPI()

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model mapping: OpenAI model name -> open-source model name
MODEL_MAPPING = {
    "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
    "text-embedding-3-large": "BAAI/bge-large-en-v1.5",
    "bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
    "bge-large-en-v1.5": "BAAI/bge-large-en-v1.5"
}

# Loaded models (lazy loading)
models = {}

def get_model(model_name: str):
    logger.info(f"Attempting to get model: {model_name}")
    # 1. All supported model names (mapped names plus directly supported names)
    supported_models = set(MODEL_MAPPING.keys())  # includes text-embedding-3-* and bge-*
    model_to_load = MODEL_MAPPING.get(model_name, model_name)
    # 2. Reject invalid models early: if the name is neither in the supported list
    #    nor prefixed by a known organization, return 400 immediately
    known_prefixes = ("BAAI/", "sentence-transformers/")  # allow models from known organizations
    if (model_name not in supported_models) and (not model_to_load.startswith(known_prefixes)):
        error_msg = f"Unsupported model: {model_name}"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    # 3. Load supported models (including models with a known organization prefix)
    if model_name not in models:
        try:
            hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
            models[model_name] = SentenceTransformer(
                model_to_load,
                use_auth_token=hf_token
            )
            logger.info(f"Model {model_name} loaded successfully")
        except Exception as e:
            # If a valid model fails to load (e.g. a network issue), return 500;
            # invalid models were already rejected above
            error_msg = f"Failed to load model {model_name}: {str(e)}"
            logger.error(error_msg)
            raise HTTPException(status_code=500, detail=error_msg)
    return models[model_name]
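
# Illustrative examples of how model names resolve (the middle model name is an
# assumed example, not one shipped with this service):
#   get_model("text-embedding-3-small")  -> loads "BAAI/bge-small-en-v1.5"
#   get_model("BAAI/bge-base-en-v1.5")   -> accepted via the "BAAI/" prefix
#   get_model("unknown-model")           -> raises HTTPException(status_code=400)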

# Verify the API key
def verify_api_key(authorization: Optional[str] = Header(None)):
    logger.info(f"Authorization header: {authorization}")
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="No valid API key provided")
    api_key = authorization[len("Bearer "):]
    if api_key != os.getenv("API_KEY"):
        raise HTTPException(status_code=401, detail="Invalid API key")
    logger.info("API key verified")
    return True
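
# Example of the expected header (illustrative; assumes API_KEY is set to "sk-local-test"):
#   Authorization: Bearer sk-local-test
# Note: the raw Authorization header is logged above; consider redacting it outside local testing.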

# Request body model
class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]  # accepts a single string or a list of strings
    model: str
    encoding_format: Optional[str] = "float"

# Response body models
class EmbeddingData(BaseModel):
    object: str = "embedding"
    embedding: List[float]
    index: int

class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: dict = {"prompt_tokens": 0, "total_tokens": 0}
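
# Example response shape (illustrative values), mirroring the OpenAI embeddings format:
#   {
#     "object": "list",
#     "data": [{"object": "embedding", "embedding": [0.01, ...], "index": 0}],
#     "model": "text-embedding-3-small",
#     "usage": {"prompt_tokens": 2, "total_tokens": 2}
#   }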

# Embeddings endpoint (route path assumed to follow the OpenAI-compatible /v1/embeddings convention)
@app.post("/v1/embeddings")
async def create_embedding(
    request: Request,
    req: EmbeddingRequest,
    _: bool = Depends(verify_api_key)
):
    # Log the full incoming request
    logger.info("\n===== Incoming request =====")
    logger.info(f"Method: {request.method}")
    logger.info(f"URL: {request.url}")
    logger.info("Headers:")
    for name, value in request.headers.items():
        logger.info(f"  {name}: {value}")
    logger.info(f"Body: {await request.body()}")
    logger.info("===============================\n")
    # Embedding generation
    logger.info(f"Received embedding request, model: {req.model}, input type: {type(req.input)}")
    try:
        model = get_model(req.model)
        inputs = [req.input] if isinstance(req.input, str) else req.input
        logger.info(f"Processing input, number of texts: {len(inputs)}")
        logger.info("Computing embeddings")
        embeddings = model.encode(inputs, normalize_embeddings=True)
        logger.info(f"Embeddings computed, shape: {embeddings.shape}")
        data = [
            EmbeddingData(embedding=embedding.tolist(), index=i)
            for i, embedding in enumerate(embeddings)
        ]
        prompt_tokens = sum(len(text.split()) for text in inputs)
        logger.info(f"Estimated token count: {prompt_tokens}")
        return EmbeddingResponse(
            data=data,
            model=req.model,
            usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
        )
    except HTTPException:
        # Propagate HTTP errors from get_model (e.g. 400 for unsupported models) unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error while processing embedding request: {str(e)}")

# Health check endpoint (route path assumed)
@app.get("/health")
async def health_check(request: Request):
    logger.info("\n===== Health check request =====")
    logger.info(f"Method: {request.method}")
    logger.info(f"URL: {request.url}")
    logger.info("Headers:")
    for name, value in request.headers.items():
        logger.info(f"  {name}: {value}")
    logger.info("===============================\n")
    return {"status": "healthy", "models": list(MODEL_MAPPING.keys()) + list(models.keys())}

if __name__ == "__main__":
    import uvicorn
    logger.info("Starting service")
    uvicorn.run(app, host="0.0.0.0", port=7860)
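
# Example request (illustrative), assuming the service runs locally on port 7860,
# API_KEY is set to "sk-local-test", and the embeddings route is /v1/embeddings as above:
#
#   curl http://localhost:7860/v1/embeddings \
#     -H "Authorization: Bearer sk-local-test" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "text-embedding-3-small", "input": ["hello world"]}'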