from fastapi import FastAPI, HTTPException, Depends, Request, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
from sentence_transformers import SentenceTransformer
from typing import List, Optional, Union  # Union allows str-or-list input
import logging
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("embedding_service.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("embedding_service")
app = FastAPI()
# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model mapping: OpenAI model name → open-source model name
MODEL_MAPPING = {
    "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
    "text-embedding-3-large": "BAAI/bge-large-en-v1.5",
    "bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
    "bge-large-en-v1.5": "BAAI/bge-large-en-v1.5"
}
# Cache of loaded models (populated lazily)
models = {}
def get_model(model_name: str):
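    """Resolve an incoming model name and return a loaded SentenceTransformer.

    Known OpenAI-style names are mapped to open-source BGE models via
    MODEL_MAPPING; names under a trusted organization prefix are loaded
    directly. Anything else is rejected with a 400 before any download
    is attempted.
    """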
    logger.info(f"Attempting to get model: {model_name}")
    # 1. All supported model names (aliases plus directly supported names)
    supported_models = set(MODEL_MAPPING.keys())  # includes text-embedding-3-* and bge-*
    model_to_load = MODEL_MAPPING.get(model_name, model_name)
    # 2. Reject invalid models early: if the name is neither in the supported
    #    list nor under a known organization prefix, return 400 immediately
    known_prefixes = ("BAAI/", "sentence-transformers/")  # models from trusted orgs
    if (model_name not in supported_models) and (not model_to_load.startswith(known_prefixes)):
        error_msg = f"Unsupported model: {model_name}"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    # 3. Load the model. Cache by the resolved name so aliases that map to the
    #    same underlying model (e.g. "text-embedding-3-small" and
    #    "bge-small-en-v1.5") share a single instance instead of loading twice.
    if model_to_load not in models:
        try:
            hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
            models[model_to_load] = SentenceTransformer(
                model_to_load,
                use_auth_token=hf_token
            )
            logger.info(f"Model {model_name} loaded successfully")
        except Exception as e:
            # A valid model can still fail to load (e.g. network issues); that
            # is a 500. Invalid models were already rejected above.
            error_msg = f"Failed to load model {model_name}: {str(e)}"
            logger.error(error_msg)
            raise HTTPException(status_code=500, detail=error_msg)
    return models[model_to_load]
# Verify the API key
def verify_api_key(authorization: Optional[str] = Header(None)):
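    """Check the bearer token in the Authorization header against the API_KEY environment variable."""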
    # Avoid logging the raw Authorization header: it contains the API key
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="No valid API key provided")
    api_key = authorization[len("Bearer "):]
    if api_key != os.getenv("API_KEY"):
        raise HTTPException(status_code=401, detail="Invalid API key")
    logger.info("API key verified")
    return True
# Request body model
class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]  # accepts a single string or a list of strings
    model: str
    encoding_format: Optional[str] = "float"
# Response body models
class EmbeddingData(BaseModel):
    object: str = "embedding"
    embedding: List[float]
    index: int

class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: dict = {"prompt_tokens": 0, "total_tokens": 0}
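# Example response shape produced by the models above (illustrative values):
# {
#   "object": "list",
#   "data": [{"object": "embedding", "embedding": [0.01, ...], "index": 0}],
#   "model": "text-embedding-3-small",
#   "usage": {"prompt_tokens": 2, "total_tokens": 2}
# }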
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embedding(
    request: Request,
    req: EmbeddingRequest,
    _: bool = Depends(verify_api_key)
):
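    """OpenAI-compatible embeddings endpoint.

    Accepts a single string or a list of strings, encodes them with the
    resolved SentenceTransformer model, and returns normalized embeddings
    in the OpenAI response format.
    """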
    # Log the full incoming request (the Authorization header is redacted so
    # the API key never reaches the log file)
    logger.info("\n===== Incoming request details =====")
    logger.info(f"Method: {request.method}")
    logger.info(f"URL: {request.url}")
    logger.info("Headers:")
    for name, value in request.headers.items():
        if name.lower() == "authorization":
            value = "Bearer ***"
        logger.info(f"  {name}: {value}")
    logger.info(f"Body: {await request.body()}")
    logger.info("===============================\n")
    # Embedding generation
    logger.info(f"Embedding request received, model: {req.model}, input type: {type(req.input)}")
    try:
        model = get_model(req.model)
        inputs = [req.input] if isinstance(req.input, str) else req.input
        logger.info(f"Processing input, number of texts: {len(inputs)}")
        logger.info("Computing embeddings")
        embeddings = model.encode(inputs, normalize_embeddings=True)
        logger.info(f"Embeddings computed, shape: {embeddings.shape}")
        data = [
            EmbeddingData(embedding=embedding.tolist(), index=i)
            for i, embedding in enumerate(embeddings)
        ]
        # Rough estimate: whitespace-split word count, not a real tokenizer
        prompt_tokens = sum(len(text.split()) for text in inputs)
        logger.info(f"Estimated token count: {prompt_tokens}")
        return EmbeddingResponse(
            data=data,
            model=req.model,
            usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
        )
    except HTTPException:
        # Re-raise HTTP errors from get_model (400/500) with their original
        # status codes instead of collapsing them into a generic 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing embedding request: {str(e)}")
@app.get("/health")
async def health_check(request: Request):
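    """Liveness probe; also reports the supported and currently loaded models."""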
    logger.info("\n===== Health check request =====")
    logger.info(f"Method: {request.method}")
    logger.info(f"URL: {request.url}")
    logger.info("Headers:")
    for name, value in request.headers.items():
        logger.info(f"  {name}: {value}")
    logger.info("===============================\n")
    return {"status": "healthy", "models": list(MODEL_MAPPING.keys()) + list(models.keys())}
if __name__ == "__main__":
    import uvicorn
    logger.info("Starting service")
    uvicorn.run(app, host="0.0.0.0", port=7860)
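# Example request against a local instance (illustrative; assumes the service
# is running on port 7860 and the API_KEY environment variable is set to
# "my-secret-key"):
#
#   curl http://localhost:7860/v1/embeddings \
#     -H "Authorization: Bearer my-secret-key" \
#     -H "Content-Type: application/json" \
#     -d '{"input": ["hello world"], "model": "text-embedding-3-small"}'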