fiewolf1000 commited on
Commit
37bee57
·
verified ·
1 Parent(s): 36ef6fd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends
2
+ from fastapi.security import APIKeyQuery, APIKeyHeader
3
+ from pydantic import BaseModel
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import torch
6
+ import os
7
+ from typing import List, Optional
8
+
9
+ # 1. 初始化FastAPI应用
10
+ app = FastAPI(
11
+ title="Cross-Encoder 重排序API",
12
+ description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
13
+ version="1.0.0"
14
+ )
15
+
16
+ # 2. API Key 认证配置(支持Header或Query参数传递)
17
+ API_KEY = os.getenv("CROSS_ENCODER_API_KEY") # 生产环境从环境变量获取,避免硬编码
18
+ if not API_KEY:
19
+ raise ValueError("请先设置环境变量 CROSS_ENCODER_API_KEY")
20
+
21
+ # 支持两种认证方式:Header(推荐,更安全)或 Query(备用)
22
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False, description="通过Header传递API Key")
23
+ api_key_query = APIKeyQuery(name="api_key", auto_error=False, description="通过URL参数传递API Key,如 ?api_key=xxx")
24
+
25
+ def get_api_key(
26
+ header_key: Optional[str] = Depends(api_key_header),
27
+ query_key: Optional[str] = Depends(api_key_query)
28
+ ) -> str:
29
+ """验证API Key,优先取Header中的值,其次取Query中的值"""
30
+ if header_key == API_KEY:
31
+ return header_key
32
+ elif query_key == API_KEY:
33
+ return query_key
34
+ raise HTTPException(
35
+ status_code=401,
36
+ detail="无效或缺失API Key(支持Header: X-API-Key 或 Query: ?api_key=xxx)",
37
+ headers={"WWW-Authenticate": "X-API-Key"}
38
+ )
39
+
40
+ # 3. 定义请求/响应数据模型(标准化格式)
41
+ class RerankRequest(BaseModel):
42
+ """重排序请求模型"""
43
+ query: str # 用户查询(如“什么是机器学习?”)
44
+ documents: List[str] # 候选文档列表(需排序的文本集合)
45
+ top_k: Optional[int] = 5 # 需返回的Top N高相关文档,默认5
46
+ truncation: Optional[bool] = True # 是否截断过长文本,默认True
47
+
48
+ class DocumentScore(BaseModel):
49
+ """单篇文档的排序结果(含分数)"""
50
+ document: str # 文档内容
51
+ score: float # 相关性分数(越高越相关)
52
+ rank: int # 排序名次(1为最高)
53
+
54
+ class RerankResponse(BaseModel):
55
+ """重排序响应模型"""
56
+ request_id: str # 请求唯一标识(便于排查问题)
57
+ query: str # 回显请求的查询
58
+ top_k: int # 回显请求的Top K
59
+ results: List[DocumentScore] # 排序结果列表
60
+ model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" # 使用的模型名称
61
+ timestamp: str = str(pd.Timestamp.now()) # 响应时间戳(需安装pandas:pip install pandas)
62
+
63
+ # 4. 加载Cross-Encoder模型(全局初始化,避免重复加载)
64
+ class CrossEncoderLoader:
65
+ def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
66
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
67
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
68
+ # 自动使用GPU(若有),否则用CPU
69
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
70
+ self.model.to(self.device)
71
+ self.model.eval() # 推理模式,关闭Dropout
72
+ print(f"模型加载完成,使用设备:{self.device}")
73
+
74
+ def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
75
+ """
76
+ 核心重排序逻辑
77
+ :param query: 用户查询
78
+ :param documents: 候选文档列表
79
+ :param top_k: 返回Top N
80
+ :param truncation: 是否截断文本
81
+ :return: 排序后的DocumentScore列表
82
+ """
83
+ if not documents:
84
+ raise ValueError("候选文档列表不能为空")
85
+ if top_k <= 0:
86
+ raise ValueError("top_k必须为正整数")
87
+
88
+ # 计算每篇文档的相关性分数
89
+ doc_scores = []
90
+ for doc in documents:
91
+ # 模型输入格式:query [SEP] document(SEP为模型默认分隔符)
92
+ inputs = self.tokenizer(
93
+ text=f"{query} {self.tokenizer.sep_token} {doc}",
94
+ return_tensors="pt",
95
+ padding="max_length",
96
+ truncation=truncation,
97
+ max_length=512 # 模型最大输入长度,MiniLM-L-6-v2支持512
98
+ ).to(self.device)
99
+
100
+ # 推理(关闭梯度计算,提升速度)
101
+ with torch.no_grad():
102
+ outputs = self.model(**inputs)
103
+ # 模型输出的logits即为相关性分数(无需softmax,直接用原始值)
104
+ score = outputs.logits.item()
105
+ doc_scores.append((doc, score))
106
+
107
+ # 按分数降序排序,取Top K,并添加名次
108
+ sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
109
+ results = [
110
+ DocumentScore(
111
+ document=doc,
112
+ score=round(score, 4), # 分数保留4位小数,便于阅读
113
+ rank=i+1 # 名次从1开始
114
+ ) for i, (doc, score) in enumerate(sorted_docs)
115
+ ]
116
+ return results
117
+
118
+ # 初始化模型(全局唯一实例)
119
+ reranker = CrossEncoderLoader()
120
+
121
+ # 5. 定义API端点(标准POST接口)
122
+ @app.post(
123
+ path="/api/v1/rerank",
124
+ response_model=RerankResponse,
125
+ description="文本相关性重排序接口:输入查询和候选文档,返回Top K高相关文档及分数"
126
+ )
127
+ async def rerank_endpoint(
128
+ request: RerankRequest,
129
+ api_key: str = Depends(get_api_key) # 强制API Key认证
130
+ ) -> RerankResponse:
131
+ try:
132
+ # 生成请求唯一标识(用UUID,需安装:pip install python-uuid)
133
+ import uuid
134
+ request_id = str(uuid.uuid4())
135
+
136
+ # 调用重排序逻辑
137
+ results = reranker.rerank(
138
+ query=request.query,
139
+ documents=request.documents,
140
+ top_k=request.top_k,
141
+ truncation=request.truncation
142
+ )
143
+
144
+ # 构造响应
145
+ return RerankResponse(
146
+ request_id=request_id,
147
+ query=request.query,
148
+ top_k=request.top_k,
149
+ results=results
150
+ )
151
+ except ValueError as e:
152
+ # 业务逻辑错误(如参数无效)
153
+ raise HTTPException(status_code=400, detail=str(e))
154
+ except Exception as e:
155
+ # 服务器内部错误(如模型加载失败)
156
+ raise HTTPException(status_code=500, detail=f"服务器内部错误:{str(e)}")
157
+
158
+ # 6. 健康检查接口(用于监控服务状态)
159
+ @app.get("/api/v1/health", description="服务健康检查接口")
160
+ async def health_check():
161
+ return {
162
+ "status": "healthy",
163
+ "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
164
+ "device": reranker.device,
165
+ "timestamp": str(pd.Timestamp.now())
166
+ }
167
+
168
+ # 7. 本地运行入口(开发环境用)
169
+ if __name__ == "__main__":
170
+ import uvicorn
171
+ # 安装uvicorn:pip install uvicorn
172
+ uvicorn.run(
173
+ app="app:app",
174
+ host="0.0.0.0", # 允许外部访问
175
+ port=8000, # 端口号
176
+ reload=False # 生产环境关闭reload
177
+ )