fiewolf1000 commited on
Commit
8bfbcde
·
verified ·
1 Parent(s): 74e864a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -182
app.py CHANGED
@@ -1,210 +1,166 @@
1
-
2
- from fastapi.security import APIKeyQuery, APIKeyHeader
3
- from typing import List, Optional
4
- from datetime import datetime # 需在文件开头导入
5
  import os
6
  import uuid
7
  from datetime import datetime
8
  from fastapi import FastAPI, HTTPException, Depends
9
- from fastapi.responses import MarkdownResponse # 确保导入这行
 
10
  from pydantic import BaseModel
11
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
  import torch
13
- from fastapi.responses import HTMLResponse # 替换为 HTMLResponse
14
 
15
- # 【双重保障】设置 Hugging Face 缓存目录(与 Dockerfile 一致)
 
16
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
17
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
18
 
19
- # 初始化 FastAPI 应用
20
- app = FastAPI(title="Cross-Encoder 重排序 API(兼容 GPT 格式)")
21
-
22
- # 后续代码不变(API 认证、模型加载、根路径首页、/v1/chat/completions 接口等)
23
- # ...
24
-
25
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
26
- class CrossEncoderWrapper:
27
- def __init__(self):
28
- self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
29
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
30
- self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
31
 
32
- # 1. 初始化FastAPI应用
33
  app = FastAPI(
34
- title="Cross-Encoder 重排序API",
35
- description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
36
  version="1.0.0"
37
  )
38
 
39
- # 2. API Key 认证配置(支持HeaderQuery参数传递)
40
- API_KEY = os.getenv("CROSS_ENCODER_API_KEY") # 生产环境从环境变量获取,避免硬编码
41
- if not API_KEY:
42
- raise ValueError("请先设置环境变量 CROSS_ENCODER_API_KEY")
43
-
44
- # 支持两种认证方式:Header(推荐,更安全)或 Query(备用)
45
- api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False, description="通过Header传递API Key")
46
- api_key_query = APIKeyQuery(name="api_key", auto_error=False, description="通过URL参数传递API Key,如 ?api_key=xxx")
 
 
 
 
47
 
48
- def get_api_key(
49
  header_key: Optional[str] = Depends(api_key_header),
50
  query_key: Optional[str] = Depends(api_key_query)
51
  ) -> str:
52
- """验证API Key,优先取Header中的值,其次取Query中的值"""
53
  if header_key == API_KEY:
54
  return header_key
55
  elif query_key == API_KEY:
56
  return query_key
57
  raise HTTPException(
58
  status_code=401,
59
- detail="无效或缺失API Key(支持Header: X-API-Key 或 Query: ?api_key=xxx)",
60
  headers={"WWW-Authenticate": "X-API-Key"}
61
  )
62
 
63
- # 3. 定义请求/响应数据模型(标准化格式)
64
  class RerankRequest(BaseModel):
65
- """重排序请求模型"""
66
  query: str # 用户查询(如“什么是机器学习?”)
67
- documents: List[str] # 候选文档列表(需排序的文本集合)
68
- top_k: Optional[int] = 5 # 需返回的Top N高相关文档,默认5
69
- truncation: Optional[bool] = True # 是否截断过长文本,默认True
70
 
71
  class DocumentScore(BaseModel):
72
- """单篇文档的排序结果(含分数)"""
73
  document: str # 文档内容
74
  score: float # 相关性分数(越高越相关)
75
- rank: int # 排序名次(1为最高)
76
 
77
  class RerankResponse(BaseModel):
78
- """重排序响应模型"""
79
-
80
-
81
-
82
- request_id: str # 请求唯一标识(便于排查问题)
83
- query: str # 回显请求的查询
84
- top_k: int # 回显请求的Top K
85
  results: List[DocumentScore] # 排序结果列表
86
- model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" # 使用的模型名称
87
- timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] # 用标准库生成时间戳
88
-
89
-
90
- # 4. 加载Cross-Encoder模型(全局初始化,避免重复加载)
91
- class CrossEncoderLoader:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
 
 
93
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
94
  self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
95
- # 自动使用GPU(若有),否则用CPU
96
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
97
  self.model.to(self.device)
98
- self.model.eval() # 推理模式,关闭Dropout
99
- print(f"模型加载完成,使用设备:{self.device}")
100
 
101
  def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
102
- """
103
- 核心重排序逻辑
104
- :param query: 用户查询
105
- :param documents: 候选文档列表
106
- :param top_k: 返回Top N
107
- :param truncation: 是否截断文本
108
- :return: 排序后的DocumentScore列表
109
- """
110
  if not documents:
111
  raise ValueError("候选文档列表不能为空")
112
- if top_k <= 0:
113
- raise ValueError("top_k必须为正整数")
114
 
115
  # 计算每篇文档的相关性分数
116
  doc_scores = []
117
  for doc in documents:
118
- # 模型输入格式:query [SEP] document(SEP为模型默认分隔符)
 
119
  inputs = self.tokenizer(
120
- text=f"{query} {self.tokenizer.sep_token} {doc}",
121
  return_tensors="pt",
122
  padding="max_length",
123
  truncation=truncation,
124
- max_length=512 # 模型最大输入长度,MiniLM-L-6-v2支持512
125
  ).to(self.device)
126
 
127
  # 推理(关闭梯度计算,提升速度)
128
  with torch.no_grad():
129
  outputs = self.model(**inputs)
130
- # 模型输出的logits即为相关性分数(无需softmax,直接用原始值)
131
  score = outputs.logits.item()
132
  doc_scores.append((doc, score))
133
 
134
- # 按分数降序排序,取Top K,并添加名次
135
  sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
136
- results = [
137
  DocumentScore(
138
  document=doc,
139
- score=round(score, 4), # 分数保留4位小数,便于阅读
140
- rank=i+1 # 名次从1开始
141
  ) for i, (doc, score) in enumerate(sorted_docs)
142
  ]
143
- return results
144
-
145
- # 初始化模型(全局唯一实例)
146
- reranker = CrossEncoderLoader()
147
-
148
- # 5. 定义API端点(标准POST接口)
149
- @app.post(
150
- path="/api/v1/rerank",
151
- response_model=RerankResponse,
152
- description="文本相关性重排序接口:输入查询和候选文档,返回Top K高相关文档及分数"
153
- )
154
- async def rerank_endpoint(
155
- request: RerankRequest,
156
- api_key: str = Depends(get_api_key) # 强制API Key认证
157
- ) -> RerankResponse:
158
- try:
159
- # 生成请求唯一标识(用UUID,需安装:pip install python-uuid)
160
- import uuid
161
- request_id = str(uuid.uuid4())
162
-
163
- # 调用重排序逻辑
164
- results = reranker.rerank(
165
- query=request.query,
166
- documents=request.documents,
167
- top_k=request.top_k,
168
- truncation=request.truncation
169
- )
170
-
171
- # 构造响应
172
- return RerankResponse(
173
- request_id=request_id,
174
- query=request.query,
175
- top_k=request.top_k,
176
- results=results
177
- )
178
- except ValueError as e:
179
- # 业务逻辑错误(如参数无效)
180
- raise HTTPException(status_code=400, detail=str(e))
181
- except Exception as e:
182
- # 服务器内部错误(如模型加载失败)
183
- raise HTTPException(status_code=500, detail=f"服务器内部错误:{str(e)}")
184
-
185
- # 6. 健康检查接口(用于监控服务状态)
186
- @app.get("/api/v1/health", description="服务健康检查接口")
187
- async def health_check():
188
- return {
189
- "status": "healthy",
190
- "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
191
- "device": reranker.device,
192
- "timestamp": str(pd.Timestamp.now())
193
- }
194
 
 
 
195
 
196
-
197
- # ------------------- 新增:根路径(/)首页路由 -------------------
198
  @app.get("/", response_class=HTMLResponse, description="API 首页(含调用指南)")
199
  async def home_page():
200
- """根路径首页:用 HTML 渲染,避免 MarkdownResponse 依赖问题"""
201
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
202
  return f"""
203
  <!DOCTYPE html>
204
  <html lang="zh-CN">
205
  <head>
206
  <meta charset="UTF-8">
207
- <title>Cross-Encoder 重排序 API(兼容 GPT 格式)</title>
208
  <style>
209
  body {{ font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }}
210
  h1 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
@@ -214,31 +170,52 @@ async def home_page():
214
  th, td {{ border: 1px solid #e9ecef; padding: 12px; text-align: left; }}
215
  th {{ background-color: #f1f5f9; }}
216
  .note {{ color: #6c757d; font-size: 0.9em; }}
 
217
  </style>
218
  </head>
219
  <body>
220
- <h1>Cross-Encoder 重排序 API(兼容 GPT 格式)</h1>
221
  <p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,提供文本相关性排序服务,支持 GPT 标准 API 调用格式。</p>
222
 
223
  <h2>核心功能</h2>
224
  <ul>
225
  <li>输入「查询语句 + 候选文档列表」,返回按相关性降序排列的结果(含分数、排名)</li>
226
- <li>兼容 OpenAI 风格 API 格式,可直接用 OpenAI 库调用</li>
227
- <li>支持 API Key 认证,保障接口安全</li>
228
  </ul>
229
 
230
- <h2>接口地址</h2>
231
- <h3>重排序接口(兼容 GPT)</h3>
232
- <ul>
233
- <li><strong>URL</strong>: <code>{app.root_path}/v1/chat/completions</code></li>
234
- <li><strong>方法</strong>: <code>POST</code></li>
235
- <li><strong>认证</strong>: 需在 Header 中添加 <code>Authorization: Bearer &lt;你的 API Key&gt;</code>(API Key 在 Hugging Face Spaces 环境变量中配置)</li>
236
- </ul>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
- <h2>调用示例(Python)</h2>
239
  <pre><code>from openai import OpenAI
240
 
241
- # 配置客户端(指向当前 Space 地址)
242
  client = OpenAI(
243
  api_key="your-api-key-here", # 替换为你的 API Key
244
  base_url="https://&lt;your-username&gt;-&lt;your-space-name&gt;.hf.space/v1" # 替换为你的 Space URL
@@ -253,64 +230,120 @@ response = client.chat.completions.create(
253
  "content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言; 深度学习是机器学习的子集;"
254
  }}
255
  ],
256
- top_k=2 # 可选,返回 Top 2 高相关文档
257
  )
258
 
259
  # 打印结果
260
  print(response.choices[0].message.content)</code></pre>
261
 
262
- <h2>请求参数说明</h2>
263
- <table>
264
- <tr>
265
- <th>参数</th>
266
- <th>类型</th>
267
- <th>说明</th>
268
- </tr>
269
- <tr>
270
- <td><code>model</code></td>
271
- <td>string</td>
272
- <td>固定为 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code>,不可修改</td>
273
- </tr>
274
- <tr>
275
- <td><code>messages</code></td>
276
- <td>list</td>
277
- <td>消息列表,最后一条必须是 <code>role: user</code> 的消息</td>
278
- </tr>
279
- <tr>
280
- <td><code>messages[].content</code></td>
281
- <td>string</td>
282
- <td>格式:<code>query: [你的查询]; documents: [文档1]; [文档2]; ...</code>(文档用分号分隔)</td>
283
- </tr>
284
- <tr>
285
- <td><code>top_k</code></td>
286
- <td>int</td>
287
- <td>可选,默认返回 Top 3 文档,范围 1~20</td>
288
- </tr>
289
- </table>
290
-
291
- <h2>健康检查</h2>
292
  <ul>
293
- <li><strong>URL</strong>: <code>{app.root_path}/health</code></li>
294
- <li><strong>方法</strong>: <code>GET</code></li>
295
- <li><strong>说明</strong>: 无需认证,用于检查服务是否正常运行</li>
296
  </ul>
297
 
298
- <p class="note">页面生成时间: {current_time}</p>
299
  </body>
300
  </html>
301
  """
302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
 
 
 
 
 
 
 
 
 
 
305
 
306
-
307
- # 7. 本地运行入口(开发环境用)
308
  if __name__ == "__main__":
309
  import uvicorn
310
- # 安装uvicorn:pip install uvicorn
311
  uvicorn.run(
312
  app="app:app",
313
- host="0.0.0.0", # 允许外部访问
314
- port=7860, # 端口号
315
- reload=False # 生产环境关闭reload
316
  )
 
 
 
 
 
1
  import os
2
  import uuid
3
  from datetime import datetime
4
  from fastapi import FastAPI, HTTPException, Depends
5
+ from fastapi.security import APIKeyHeader, APIKeyQuery
6
+ from fastapi.responses import HTMLResponse # 仅保留 HTMLResponse,删除 MarkdownResponse
7
  from pydantic import BaseModel
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  import torch
10
+ from typing import List, Optional
11
 
12
+ # ------------------- 1. 基础配置(缓存目录 + 环境变量) -------------------
13
+ # 设置 Hugging Face 缓存目录(可写目录,解决权限问题)
14
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
15
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
16
 
17
+ # 从环境变量获取 API Key(认证用,需在 Hugging Face Spaces 中配置)
18
+ API_KEY = os.getenv("CROSS_ENCODER_API_KEY")
19
+ if not API_KEY:
20
+ raise ValueError("请在 Hugging Face Spaces 中设置环境变量 CROSS_ENCODER_API_KEY")
 
 
 
 
 
 
 
 
21
 
22
+ # ------------------- 2. 初始化 FastAPI 应用(仅初始化一次) -------------------
23
  app = FastAPI(
24
+ title="Cross-Encoder 重排序 API",
25
+ description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口(兼容 GPT 格式)",
26
  version="1.0.0"
27
  )
28
 
29
+ # ------------------- 3. API Key 认证配置(支持 Header/Query 两种方式) -------------------
30
+ # 支持 Header(推荐)和 Query(备用)传递 API Key
31
+ api_key_header = APIKeyHeader(
32
+ name="X-API-Key",
33
+ auto_error=False,
34
+ description="通过 Header 传递 API Key(推荐)"
35
+ )
36
+ api_key_query = APIKeyQuery(
37
+ name="api_key",
38
+ auto_error=False,
39
+ description="通过 URL 参数传递 API Key(如 ?api_key=xxx)"
40
+ )
41
 
42
+ def verify_api_key(
43
  header_key: Optional[str] = Depends(api_key_header),
44
  query_key: Optional[str] = Depends(api_key_query)
45
  ) -> str:
46
+ """验证 API Key,优先使用 Header 中的值"""
47
  if header_key == API_KEY:
48
  return header_key
49
  elif query_key == API_KEY:
50
  return query_key
51
  raise HTTPException(
52
  status_code=401,
53
+ detail="无效或缺失 API Key(支持:Header: X-API-Key 或 Query: ?api_key=xxx)",
54
  headers={"WWW-Authenticate": "X-API-Key"}
55
  )
56
 
57
+ # ------------------- 4. 数据模型定义(请求/响应格式) -------------------
58
  class RerankRequest(BaseModel):
59
+ """重排序请求模型(支持基础重排序 + GPT 兼容格式)"""
60
  query: str # 用户查询(如“什么是机器学习?”)
61
+ documents: List[str] # 候选文档列表(需排序的文本)
62
+ top_k: Optional[int] = 3 # 返回 Top N 高相关文档,默认 3
63
+ truncation: Optional[bool] = True # 是否截断过长文本(模型最大输入 512 Token)
64
 
65
  class DocumentScore(BaseModel):
66
+ """单篇文档的排序结果(含分数��排名)"""
67
  document: str # 文档内容
68
  score: float # 相关性分数(越高越相关)
69
+ rank: int # 排序名次(1 为最高)
70
 
71
  class RerankResponse(BaseModel):
72
+ """重排序响应模型(标准化格式)"""
73
+ request_id: str # 请求唯一标识(用于排查问题)
74
+ query: str # 回显用户查询
75
+ top_k: int # 回显返回的 Top N 数量
 
 
 
76
  results: List[DocumentScore] # 排序结果列表
77
+ model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" # 使用的模型名
78
+ timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] # 时间戳
79
+
80
+ # GPT 兼容格式的请求模型(适配 /v1/chat/completions 接口)
81
+ class GPTMessage(BaseModel):
82
+ role: str # 仅支持 "user" 角色
83
+ content: str # 格式:"query: [查询]; documents: [文档1]; [文档2]; ..."
84
+
85
+ class GPTRequest(BaseModel):
86
+ model: str # 固定为模型名,用于兼容 GPT 调用格式
87
+ messages: List[GPTMessage] # GPT 风格的消息列表
88
+ top_k: Optional[int] = 3 # 同 RerankRequest 的 top_k
89
+
90
+ class GPTResponse(BaseModel):
91
+ """GPT 兼容的响应模型(模仿 OpenAI 格式)"""
92
+ id: str = f"rerank-{uuid.uuid4().hex[:10]}"
93
+ object: str = "chat.completion"
94
+ created: int = int(datetime.now().timestamp())
95
+ model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
96
+ choices: List[dict] = [] # 存储排序结果
97
+ usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
98
+
99
+ # ------------------- 5. 加载 Cross-Encoder 模型(全局唯一实例) -------------------
100
+ class CrossEncoderModel:
101
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
102
+ self.model_name = model_name
103
+ # 加载分词器和模型(从缓存目录加载,避免权限问题)
104
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
105
  self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
106
+ # 自动选择设备(GPU 优先,无则用 CPU
107
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
108
  self.model.to(self.device)
109
+ self.model.eval() # 推理模式,关闭 Dropout
110
+ print(f"模型加载完成!使用设备:{self.device}")
111
 
112
  def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
113
+ """核心重排序逻辑:计算查询与文档的相关性并排序"""
114
+ # 参数校验
 
 
 
 
 
 
115
  if not documents:
116
  raise ValueError("候选文档列表不能为空")
117
+ if top_k <= 0 or top_k > len(documents):
118
+ raise ValueError(f"top_k 需在 1~{len(documents)} 之间")
119
 
120
  # 计算每篇文档的相关性分数
121
  doc_scores = []
122
  for doc in documents:
123
+ # 模型输入格式:query [SEP] document(SEP 是模型默认分隔符)
124
+ input_text = f"{query} {self.tokenizer.sep_token} {doc}"
125
  inputs = self.tokenizer(
126
+ input_text,
127
  return_tensors="pt",
128
  padding="max_length",
129
  truncation=truncation,
130
+ max_length=512 # 模型最大输入长度(MiniLM-L-6-v2 支持 512 Token)
131
  ).to(self.device)
132
 
133
  # 推理(关闭梯度计算,提升速度)
134
  with torch.no_grad():
135
  outputs = self.model(**inputs)
136
+ # 模型输出的 logits 即为相关性分数(无需 softmax,直接使用原始值)
137
  score = outputs.logits.item()
138
  doc_scores.append((doc, score))
139
 
140
+ # 按分数降序排序,取 Top K 并生成结果
141
  sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
142
+ return [
143
  DocumentScore(
144
  document=doc,
145
+ score=round(score, 4), # 分数保留 4 位小数,便于阅读
146
+ rank=i+1 # 名次从 1 开始
147
  ) for i, (doc, score) in enumerate(sorted_docs)
148
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
+ # 初始化模型(全局唯一,避免重复加载)
151
+ reranker = CrossEncoderModel()
152
 
153
+ # ------------------- 6. API 端点定义 -------------------
154
+ # 6.1 根路径首页(HTML 格式,无 Markdown 依赖)
155
  @app.get("/", response_class=HTMLResponse, description="API 首页(含调用指南)")
156
  async def home_page():
 
157
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
158
  return f"""
159
  <!DOCTYPE html>
160
  <html lang="zh-CN">
161
  <head>
162
  <meta charset="UTF-8">
163
+ <title>Cross-Encoder 重排序 API</title>
164
  <style>
165
  body {{ font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }}
166
  h1 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
 
170
  th, td {{ border: 1px solid #e9ecef; padding: 12px; text-align: left; }}
171
  th {{ background-color: #f1f5f9; }}
172
  .note {{ color: #6c757d; font-size: 0.9em; }}
173
+ .api-url {{ color: #3498db; font-weight: bold; }}
174
  </style>
175
  </head>
176
  <body>
177
+ <h1>Cross-Encoder 重排序 API</h1>
178
  <p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,提供文本相关性排序服务,支持 GPT 标准 API 调用格式。</p>
179
 
180
  <h2>核心功能</h2>
181
  <ul>
182
  <li>输入「查询语句 + 候选文档列表」,返回按相关性降序排列的结果(含分数、排名)</li>
183
+ <li>支持两种 API 格式:基础重排序接口(/api/v1/rerank)和 GPT 兼容接口(/v1/chat/completions)</li>
184
+ <li>API Key 认证,保障接口安全</li>
185
  </ul>
186
 
187
+ <h2>接口列表</h2>
188
+ <table>
189
+ <tr>
190
+ <th>接口名称</th>
191
+ <th>URL</th>
192
+ <th>方法</th>
193
+ <th>说明</th>
194
+ </tr>
195
+ <tr>
196
+ <td>基础重排序接口</td>
197
+ <td class="api-url">{app.root_path}/api/v1/rerank</td>
198
+ <td>POST</td>
199
+ <td>标准化重排序接口,返回结构化结果</td>
200
+ </tr>
201
+ <tr>
202
+ <td>GPT 兼容接口</td>
203
+ <td class="api-url">{app.root_path}/v1/chat/completions</td>
204
+ <td>POST</td>
205
+ <td>模仿 OpenAI 格式,可直接用 OpenAI 库调用</td>
206
+ </tr>
207
+ <tr>
208
+ <td>健康检查</td>
209
+ <td class="api-url">{app.root_path}/api/v1/health</td>
210
+ <td>GET</td>
211
+ <td>无需认证,检查服务状态</td>
212
+ </tr>
213
+ </table>
214
 
215
+ <h2>调用示例(GPT 兼容接口)</h2>
216
  <pre><code>from openai import OpenAI
217
 
218
+ # 配置客户端(指向你的 Space 地址)
219
  client = OpenAI(
220
  api_key="your-api-key-here", # 替换为你的 API Key
221
  base_url="https://&lt;your-username&gt;-&lt;your-space-name&gt;.hf.space/v1" # 替换为你的 Space URL
 
230
  "content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言; 深度学习是机器学习的子集;"
231
  }}
232
  ],
233
+ top_k=2 # 返回 Top 2 高相关文档
234
  )
235
 
236
  # 打印结果
237
  print(response.choices[0].message.content)</code></pre>
238
 
239
+ <h2>API Key 认证方式</h2>
240
+ <p>所有 POST 接口需通过以下方式之一传递 API Key:</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  <ul>
242
+ <li><strong>Header 方式(推荐)</strong>:在请求 Header 中添加 <code>X-API-Key: your-api-key</code></li>
243
+ <li><strong>Query 方式(备用)</strong>:在 URL 后添加 <code>?api_key=your-api-key</code></li>
 
244
  </ul>
245
 
246
+ <p class="note">页面生成时间: {current_time} | 模型运行设备: {reranker.device}</p>
247
  </body>
248
  </html>
249
  """
250
 
251
+ # 6.2 基础重排序接口(标准化格式)
252
+ @app.post(
253
+ "/api/v1/rerank",
254
+ response_model=RerankResponse,
255
+ description="基础重排序接口,返回结构化的排序结果"
256
+ )
257
+ async def base_rerank(
258
+ request: RerankRequest,
259
+ api_key: str = Depends(verify_api_key)
260
+ ):
261
+ try:
262
+ # 执行重排序
263
+ results = reranker.rerank(
264
+ query=request.query,
265
+ documents=request.documents,
266
+ top_k=request.top_k,
267
+ truncation=request.truncation
268
+ )
269
+ # 生成响应
270
+ return RerankResponse(
271
+ request_id=str(uuid.uuid4()),
272
+ query=request.query,
273
+ top_k=request.top_k,
274
+ results=results
275
+ )
276
+ except ValueError as e:
277
+ raise HTTPException(status_code=400, detail=str(e))
278
+ except Exception as e:
279
+ raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
280
 
281
+ # 6.3 GPT 兼容接口(模仿 OpenAI 格式)
282
+ @app.post(
283
+ "/v1/chat/completions",
284
+ response_model=GPTResponse,
285
+ description="GPT 兼容接口,支持用 OpenAI 库调用"
286
+ )
287
+ async def gpt_compatible_rerank(
288
+ request: GPTRequest,
289
+ api_key: str = Depends(verify_api_key)
290
+ ):
291
+ try:
292
+ # 校验模型名(确保兼容 GPT 调用格式)
293
+ if request.model != reranker.model_name:
294
+ raise ValueError(f"仅支持模型:{reranker.model_name}")
295
+ # 校验消息(仅支持最后一条为 user 角色)
296
+ if not request.messages or request.messages[-1].role != "user":
297
+ raise ValueError("最后一条消息必须是 'user' 角色")
298
+
299
+ # 解析用户输入(从 content 中提取 query 和 documents)
300
+ content = request.messages[-1].content
301
+ if "; documents: " not in content:
302
+ raise ValueError("输入格式错误,��为:'query: [查询]; documents: [文档1]; [文档2]; ...'")
303
+ query_part, docs_part = content.split("; documents: ")
304
+ query = query_part.replace("query: ", "").strip()
305
+ documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
306
+
307
+ # 执行重排序
308
+ results = reranker.rerank(
309
+ query=query,
310
+ documents=documents,
311
+ top_k=request.top_k,
312
+ truncation=True
313
+ )
314
+ # 格式化 GPT 风格的响应
315
+ return GPTResponse(
316
+ choices=[{
317
+ "index": 0,
318
+ "message": {
319
+ "role": "assistant",
320
+ "content": f"重排序结果(按相关性降序):\n{[{'文档': r.document, '分数': r.score, '排名': r.rank} for r in results]}"
321
+ },
322
+ "finish_reason": "stop"
323
+ }]
324
+ )
325
+ except ValueError as e:
326
+ raise HTTPException(status_code=400, detail=str(e))
327
+ except Exception as e:
328
+ raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
329
 
330
+ # 6.4 健康检查接口
331
+ @app.get("/api/v1/health", description="服务健康检查接口(无需认证)")
332
+ async def health_check():
333
+ return {
334
+ "status": "healthy",
335
+ "model": reranker.model_name,
336
+ "device": reranker.device,
337
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
338
+ "message": "服务正常运行"
339
+ }
340
 
341
+ # ------------------- 7. 本地运行入口(开发环境用) -------------------
 
342
  if __name__ == "__main__":
343
  import uvicorn
 
344
  uvicorn.run(
345
  app="app:app",
346
+ host="0.0.0.0",
347
+ port=7860,
348
+ reload=False # 生产环境关闭 reload
349
  )