File size: 5,851 Bytes
2a769ef
9493f68
 
c6533bc
9493f68
 
4eb2455
1e5bdfa
 
 
 
 
 
 
4eb2455
 
1e5bdfa
 
 
d7a364f
9493f68
c6533bc
9493f68
 
 
 
 
 
 
 
c6533bc
9493f68
 
 
66a70c6
 
 
c6533bc
 
4eb2455
9493f68
 
 
1e5bdfa
75911cc
 
b532619
75911cc
 
 
 
 
 
 
 
 
9493f68
1e5bdfa
b532619
75911cc
 
 
 
1e5bdfa
 
75911cc
 
 
 
9493f68
 
b532619
 
9493f68
2a769ef
c8c2946
 
9493f68
c8c2946
 
9493f68
1e5bdfa
9493f68
 
2a769ef
9493f68
4eb2455
9493f68
2a769ef
9493f68
2a769ef
9493f68
 
 
 
c6533bc
9493f68
 
 
 
 
c6533bc
2a769ef
9493f68
2a769ef
4eb2455
 
9493f68
4eb2455
6a7f479
 
 
 
 
 
2a769ef
6a7f479
2a769ef
4eb2455
65318bb
9493f68
65318bb
 
1e5bdfa
9493f68
1e5bdfa
9493f68
1e5bdfa
9493f68
 
 
 
 
 
 
1e5bdfa
9493f68
 
 
65318bb
9493f68
 
 
4eb2455
c6533bc
c8c2946
6a7f479
 
 
 
 
 
 
 
4eb2455
d7a364f
9493f68
 
1e5bdfa
4eb2455
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import logging
import os
import secrets
from typing import List, Optional, Union  # Union: request input may be str or List[str]

import numpy as np  # NOTE(review): imported but not used in this file; kept for compatibility
from fastapi import FastAPI, HTTPException, Depends, Request, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s-%(name)s-%(levelname)s-%(message)s",
    handlers=[
        logging.FileHandler("embedding_service.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("embedding_service")

app = FastAPI()

# 允许跨域请求
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 模型映射:OpenAI模型名 → 开源模型名
MODEL_MAPPING = {
    "text-embedding-3-small": "BAAI/bge-small-en-v1.5",
    "text-embedding-3-large": "BAAI/bge-large-en-v1.5",
    "bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
    "bge-large-en-v1.5": "BAAI/bge-large-en-v1.5"
}

# 加载模型(懒加载)
models = {}

def get_model(model_name: str):
    logger.info(f"尝试获取模型: {model_name}")
    # 1. 定义所有支持的模型(映射名 + 直接支持的模型名)
    supported_models = set(MODEL_MAPPING.keys())  # 包含text-embedding-3-*和bge-*
    model_to_load = MODEL_MAPPING.get(model_name, model_name)

    # 2. 提前拦截无效模型:若不在支持列表且非已知机构前缀,直接返回400
    known_prefixes = ("BAAI/", "sentence-transformers/")  # 允许合法机构的模型
    if (model_name not in supported_models) and (not model_to_load.startswith(known_prefixes)):
        error_msg = f"不支持的模型: {model_name}"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)

    # 3. 加载支持的模型(含合法机构前缀的模型)
    if model_name not in models:
        try:
            hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
            models[model_name] = SentenceTransformer(
                model_to_load,
                use_auth_token=hf_token
            )
            logger.info(f"模型 {model_name} 加载成功")
        except Exception as e:
            # 若合法模型加载失败(如网络问题),返回500;无效模型已提前拦截
            error_msg = f"加载模型 {model_name} 失败: {str(e)}"
            logger.error(error_msg)
            raise HTTPException(status_code=500, detail=error_msg)
    return models[model_name]



# 验证API密钥
def verify_api_key(authorization: Optional[str] = Header(None)):
    logger.info(f"Authorization头部内容: {authorization}")
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="未提供有效的API密钥")
    api_key = authorization[len("Bearer "):]
    if api_key != os.getenv("API_KEY"):
        raise HTTPException(status_code=401, detail="无效的API密钥")
    logger.info("API密钥验证通过")
    return True

# 请求体模型
class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]  # 支持str或List[str]
    model: str
    encoding_format: Optional[str] = "float"

# 响应体模型
class EmbeddingData(BaseModel):
    object: str = "embedding"
    embedding: List[float]
    index: int

class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: dict = {"prompt_tokens": 0, "total_tokens": 0}

@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embedding(
    request: Request,
    req: EmbeddingRequest,
    _: bool = Depends(verify_api_key)
):
    # 打印请求信息
    logger.info("\n===== 接收到的完整请求信息 =====")
    logger.info(f"请求方法: {request.method}")
    logger.info(f"请求URL: {request.url}")
    logger.info("请求头部:")
    for name, value in request.headers.items():
        logger.info(f"  {name}: {value}")
    logger.info(f"请求体: {await request.body()}")
    logger.info("===============================\n")
    
    # 嵌入生成逻辑
    logger.info(f"收到嵌入请求,模型: {req.model}, 输入类型: {type(req.input)}")
    try:
        model = get_model(req.model)
        inputs = [req.input] if isinstance(req.input, str) else req.input
        logger.info(f"处理输入,文本数量: {len(inputs)}")
        
        logger.info("开始计算嵌入")
        embeddings = model.encode(inputs, normalize_embeddings=True)
        logger.info(f"嵌入计算完成,嵌入形状: {embeddings.shape}")
        
        data = [
            EmbeddingData(embedding=embedding.tolist(), index=i)
            for i, embedding in enumerate(embeddings)
        ]
        
        prompt_tokens = sum(len(text.split()) for text in inputs)
        logger.info(f"估算token数: {prompt_tokens}")
        
        return EmbeddingResponse(
            data=data,
            model=req.model,
            usage={"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"处理嵌入请求时发生错误: {str(e)}")

@app.get("/health")
async def health_check(request: Request):
    logger.info("\n===== 健康检查请求信息 =====")
    logger.info(f"请求方法: {request.method}")
    logger.info(f"请求URL: {request.url}")
    logger.info("请求头部:")
    for name, value in request.headers.items():
        logger.info(f"  {name}: {value}")
    logger.info("===============================\n")
    return {"status": "healthy", "models": list(MODEL_MAPPING.keys()) + list(models.keys())}

if __name__ == "__main__":
    import uvicorn
    logger.info("启动服务")
    uvicorn.run(app, host="0.0.0.0", port=7860)