fiewolf1000 committed on
Commit
6935049
·
verified ·
1 Parent(s): c025244

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -192
app.py CHANGED
@@ -1,179 +1,137 @@
1
  import os
2
  import uuid
3
  from datetime import datetime
4
- from fastapi import FastAPI, HTTPException, Depends
5
- from fastapi.security import APIKeyHeader, APIKeyQuery
6
- from fastapi.responses import HTMLResponse # 仅保留 HTMLResponse,删除 MarkdownResponse
7
  from pydantic import BaseModel
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  import torch
10
  from typing import List, Optional
11
 
12
- # ------------------- 1. 基础配置(缓存目录 + 环境变量) -------------------
13
- # 设置 Hugging Face 缓存目录(可写目录,解决权限问题)
14
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
15
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
16
 
17
- # 从环境变量获取 API Key(认证用,需在 Hugging Face Spaces 中配置)
18
- API_KEY = os.getenv("CROSS_ENCODER_API_KEY")
19
  if not API_KEY:
20
- raise ValueError("请在 Hugging Face Spaces 中设置环境变量 CROSS_ENCODER_API_KEY")
21
 
22
- # ------------------- 2. 初始化 FastAPI 应用(仅初始化一次) -------------------
23
  app = FastAPI(
24
- title="Cross-Encoder 重排序 API",
25
- description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口(兼容 GPT 格式)",
26
  version="1.0.0"
27
  )
28
 
29
- # ------------------- 3. API Key 认证配置(支持 Header/Query 两种方式) -------------------
30
- # 支持 Header(推荐)和 Query(备用)传递 API Key
31
- api_key_header = APIKeyHeader(
32
- name="X-API-Key",
33
- auto_error=False,
34
- description="通过 Header 传递 API Key(推荐)"
35
- )
36
- api_key_query = APIKeyQuery(
37
- name="api_key",
38
- auto_error=False,
39
- description="通过 URL 参数传递 API Key(如 ?api_key=xxx)"
40
- )
41
 
42
- def verify_api_key(
43
- header_key: Optional[str] = Depends(api_key_header),
44
- query_key: Optional[str] = Depends(api_key_query)
45
- ) -> str:
46
- """验证 API Key,优先使用 Header 中的值"""
47
- if header_key == API_KEY:
48
- return header_key
49
- elif query_key == API_KEY:
50
- return query_key
51
- raise HTTPException(
52
- status_code=401,
53
- detail="无效或缺失 API Key(支持:Header: X-API-Key 或 Query: ?api_key=xxx)",
54
- headers={"WWW-Authenticate": "X-API-Key"}
55
- )
56
 
57
- # ------------------- 4. 数据模型定义(请求/响应格式) -------------------
58
  class RerankRequest(BaseModel):
59
- """重排序请求模型(支持基础重排序 + GPT 兼容格式)"""
60
- query: str # 用户查询(如“什么是机器学习?”)
61
- documents: List[str] # 候选文档列表(需排序的文本)
62
- top_k: Optional[int] = 3 # 返回 Top N 高相关文档,默认 3
63
- truncation: Optional[bool] = True # 是否截断过长文本(模型最大输入 512 Token)
64
 
65
  class DocumentScore(BaseModel):
66
- """单篇文档的排序结果(含分数和排名)"""
67
- document: str # 文档内容
68
- score: float # 相关性分数(越高越相关)
69
- rank: int # 排序名次(1 为最高)
70
 
71
  class RerankResponse(BaseModel):
72
- """重排序响应模型(标准化格式)"""
73
- request_id: str # 请求唯一标识(用于排查问题)
74
- query: str # 回显用户查询
75
- top_k: int # 回显返回的 Top N 数量
76
- results: List[DocumentScore] # 排序结果列表
77
- model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2" # 使用的模型名
78
- timestamp: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] # 时间戳
79
 
80
- # GPT 兼容格式的请求模型(适配 /v1/chat/completions 接口)
81
  class GPTMessage(BaseModel):
82
- role: str # 仅支持 "user" 角色
83
- content: str # 格式:"query: [查询]; documents: [文档1]; [文档2]; ..."
84
 
85
  class GPTRequest(BaseModel):
86
- model: str # 固定为模型名,用于兼容 GPT 调用格式
87
- messages: List[GPTMessage] # GPT 风格的消息列表
88
- top_k: Optional[int] = 3 # 同 RerankRequest 的 top_k
 
 
 
 
 
89
 
90
  class GPTResponse(BaseModel):
91
- """GPT 兼容的响应模型(模仿 OpenAI 格式)"""
92
- id: str = f"rerank-{uuid.uuid4().hex[:10]}"
93
  object: str = "chat.completion"
94
  created: int = int(datetime.now().timestamp())
95
- model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
96
- choices: List[dict] = [] # 存储排序结果
97
  usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
98
 
99
- # ------------------- 5. 加载 Cross-Encoder 模型(全局唯一实例) -------------------
100
- # 在 CrossEncoderModel 类的 __init__ 方法前添加缓存目录验证
101
  class CrossEncoderModel:
102
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
103
  self.model_name = model_name
104
-
105
- # 【新增】验证缓存目录是否可写
106
  cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
107
  try:
108
- # 尝试在缓存目录创建测试文件,验证权限
109
- test_file = os.path.join(cache_dir, "test_write_permission.txt")
110
  with open(test_file, "w") as f:
111
  f.write("test")
112
- os.remove(test_file) # 验证后删除测试文件
113
- print(f"缓存目录权限验证通过:{cache_dir}")
114
  except Exception as e:
115
- raise RuntimeError(f"缓存目录不可写,请检查权限:{cache_dir},错误:{str(e)}")
116
-
117
- # 加载模型(确保使用指定的缓存目录)
118
- self.tokenizer = AutoTokenizer.from_pretrained(
119
- model_name,
120
- cache_dir=cache_dir # 显式指定缓存目录
121
- )
122
- self.model = AutoModelForSequenceClassification.from_pretrained(
123
- model_name,
124
- cache_dir=cache_dir # 显式指定缓存目录
125
- )
126
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
127
  self.model.to(self.device)
128
  self.model.eval()
129
- print(f"模型加载完成!使用设备:{self.device}")
130
-
131
-
132
 
133
  def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
134
- """核心重排序逻辑:计算查询与文档的相关性并排序"""
135
- # 参数校验
136
  if not documents:
137
- raise ValueError("候选文档列表不能为空")
138
  if top_k <= 0 or top_k > len(documents):
139
  raise ValueError(f"top_k 需在 1~{len(documents)} 之间")
140
-
141
- # 计算每篇文档的相关性分数
142
  doc_scores = []
143
  for doc in documents:
144
- # 模型输入格式:query [SEP] document(SEP 是模型默认分隔符)
145
- input_text = f"{query} {self.tokenizer.sep_token} {doc}"
146
  inputs = self.tokenizer(
147
- input_text,
148
  return_tensors="pt",
149
  padding="max_length",
150
  truncation=truncation,
151
- max_length=512 # 模型最大输入长度(MiniLM-L-6-v2 支持 512 Token)
152
  ).to(self.device)
153
-
154
- # 推理(关闭梯度计算,提升速度)
155
  with torch.no_grad():
156
  outputs = self.model(**inputs)
157
- # 模型输出的 logits 即为相关性分数(无需 softmax,直接使用原始值)
158
  score = outputs.logits.item()
159
  doc_scores.append((doc, score))
160
-
161
- # 按分数降序排序,取 Top K 并生成结果
162
  sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
163
  return [
164
- DocumentScore(
165
- document=doc,
166
- score=round(score, 4), # 分数保留 4 位小数,便于阅读
167
- rank=i+1 # 名次从 1 开始
168
- ) for i, (doc, score) in enumerate(sorted_docs)
169
  ]
170
 
171
- # 初始化模型(全局唯一,避免重复加载)
172
  reranker = CrossEncoderModel()
173
 
174
- # ------------------- 6. API 端点定义 -------------------
175
- # 6.1 根路径首页(HTML 格式,无 Markdown 依赖)
176
- @app.get("/", response_class=HTMLResponse, description="API 首页(含调用指南)")
177
  async def home_page():
178
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
179
  return f"""
@@ -181,7 +139,7 @@ async def home_page():
181
  <html lang="zh-CN">
182
  <head>
183
  <meta charset="UTF-8">
184
- <title>Cross-Encoder 重排序 API</title>
185
  <style>
186
  body {{ font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }}
187
  h1 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
@@ -190,104 +148,77 @@ async def home_page():
190
  table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
191
  th, td {{ border: 1px solid #e9ecef; padding: 12px; text-align: left; }}
192
  th {{ background-color: #f1f5f9; }}
193
- .note {{ color: #6c757d; font-size: 0.9em; }}
194
- .api-url {{ color: #3498db; font-weight: bold; }}
195
  </style>
196
  </head>
197
  <body>
198
- <h1>Cross-Encoder 重排序 API</h1>
199
- <p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,提供文本相关性排序服务,支持 GPT 标准 API 调用格式。</p>
200
-
201
- <h2>核心功能</h2>
202
- <ul>
203
- <li>输入「查询语句 + 候选文档列表」,返回按相关性降序排列的结果(含分数、排名)</li>
204
- <li>支持两种 API 格式:基础重排序接口(/api/v1/rerank)和 GPT 兼容接口(/v1/chat/completions)</li>
205
- <li>API Key 认证,保障接口安全</li>
206
- </ul>
207
 
208
  <h2>接口列表</h2>
209
  <table>
210
  <tr>
211
- <th>接口名称</th>
212
  <th>URL</th>
213
  <th>方法</th>
214
- <th>说明</th>
215
  </tr>
216
  <tr>
217
- <td>基础重排序接口</td>
218
- <td class="api-url">{app.root_path}/api/v1/rerank</td>
219
  <td>POST</td>
220
- <td>标准化重排序接口,返回结构化结果</td>
221
  </tr>
222
  <tr>
223
- <td>GPT 兼容接口</td>
224
- <td class="api-url">{app.root_path}/v1/chat/completions</td>
225
  <td>POST</td>
226
- <td>模仿 OpenAI 格式,可直接用 OpenAI 库调用</td>
227
  </tr>
228
  <tr>
229
  <td>健康检查</td>
230
- <td class="api-url">{app.root_path}/api/v1/health</td>
231
  <td>GET</td>
232
- <td>无需认证,检查服务状态</td>
233
  </tr>
234
  </table>
235
 
236
- <h2>调用示例(GPT 兼容接口)</h2>
237
- <pre><code>from openai import OpenAI
238
 
239
- # 配置客户端(指向你的 Space 地址)
240
- client = OpenAI(
241
- api_key="your-api-key-here", # 替换为你的 API Key
242
- base_url="https://&lt;your-username&gt;-&lt;your-space-name&gt;.hf.space/v1" # 替换为你的 Space URL
243
  )
244
 
245
- # 发送重排序请求
246
  response = client.chat.completions.create(
247
- model="cross-encoder/ms-marco-MiniLM-L-6-v2", # 固定模型名
248
  messages=[
249
  {{
250
  "role": "user",
251
- "content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言; 深度学习是机器学习的子集;"
252
  }}
253
  ],
254
- top_k=2 # 返回 Top 2 高相关文档
255
  )
256
 
257
- # 打印结果
258
  print(response.choices[0].message.content)</code></pre>
259
-
260
- <h2>API Key 认证方式</h2>
261
- <p>所有 POST 接口需通过以下方式之一传递 API Key:</p>
262
- <ul>
263
- <li><strong>Header 方式(推荐)</strong>:在请求 Header 中添加 <code>X-API-Key: your-api-key</code></li>
264
- <li><strong>Query 方式(备用)</strong>:在 URL 后添加 <code>?api_key=your-api-key</code></li>
265
- </ul>
266
-
267
- <p class="note">页面生成时间: {current_time} | 模型运行设备: {reranker.device}</p>
268
  </body>
269
  </html>
270
  """
271
 
272
- # 6.2 基础重排序接口(标准化格式)
273
- @app.post(
274
- "/api/v1/rerank",
275
- response_model=RerankResponse,
276
- description="基础重排序接口,返回结构化的排序结果"
277
- )
278
  async def base_rerank(
279
  request: RerankRequest,
280
  api_key: str = Depends(verify_api_key)
281
  ):
282
  try:
283
- # 执行重排序
284
  results = reranker.rerank(
285
  query=request.query,
286
  documents=request.documents,
287
  top_k=request.top_k,
288
  truncation=request.truncation
289
  )
290
- # 生成响应
291
  return RerankResponse(
292
  request_id=str(uuid.uuid4()),
293
  query=request.query,
@@ -299,72 +230,57 @@ async def base_rerank(
299
  except Exception as e:
300
  raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
301
 
302
- # 6.3 GPT 兼容接口(模仿 OpenAI 格式)
303
- @app.post(
304
- "/v1/chat/completions",
305
- response_model=GPTResponse,
306
- description="GPT 兼容接口,支持用 OpenAI 库调用"
307
- )
308
  async def gpt_compatible_rerank(
309
  request: GPTRequest,
310
  api_key: str = Depends(verify_api_key)
311
  ):
312
  try:
313
- # 校验模型名(确保兼容 GPT 调用格式)
314
  if request.model != reranker.model_name:
315
  raise ValueError(f"仅支持模型:{reranker.model_name}")
316
- # 校验消息(仅支持最后一条为 user 角色)
317
  if not request.messages or request.messages[-1].role != "user":
318
  raise ValueError("最后一条消息必须是 'user' 角色")
319
-
320
- # 解析用户输入(从 content 中提取 query 和 documents)
321
  content = request.messages[-1].content
322
  if "; documents: " not in content:
323
- raise ValueError("输入格式错误,需为:'query: [查询]; documents: [文档1]; [文档2]; ...'")
324
  query_part, docs_part = content.split("; documents: ")
325
  query = query_part.replace("query: ", "").strip()
326
  documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
327
-
328
- # 执行重排序
329
  results = reranker.rerank(
330
  query=query,
331
  documents=documents,
332
  top_k=request.top_k,
333
  truncation=True
334
  )
335
- # 格式化 GPT 风格的响应
336
  return GPTResponse(
337
- choices=[{
338
- "index": 0,
339
- "message": {
340
- "role": "assistant",
341
- "content": f"重排序结果(按相关性降序):\n{[{'文档': r.document, '分数': r.score, '排名': r.rank} for r in results]}"
342
- },
343
- "finish_reason": "stop"
344
- }]
 
 
345
  )
346
  except ValueError as e:
347
  raise HTTPException(status_code=400, detail=str(e))
348
  except Exception as e:
349
  raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
350
 
351
- # 6.4 健康检查接口
352
- @app.get("/api/v1/health", description="服务健康检查接口(无需认证)")
353
  async def health_check():
354
  return {
355
  "status": "healthy",
356
  "model": reranker.model_name,
357
  "device": reranker.device,
358
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
359
- "message": "服务正常运行"
360
  }
361
 
362
- # ------------------- 7. 本地运行入口(开发环境用) -------------------
363
  if __name__ == "__main__":
364
  import uvicorn
365
- uvicorn.run(
366
- app="app:app",
367
- host="0.0.0.0",
368
- port=7860,
369
- reload=False # 生产环境关闭 reload
370
- )
 
1
import os
import uuid
from datetime import datetime
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
 
12
# ------------------- 1. Basic configuration (cache + environment) -------------------
# Point both Hugging Face cache variables at the same writable directory.
_HF_CACHE_DIR = "/tmp/huggingface_cache"
for _cache_var in ("TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE"):
    os.environ[_cache_var] = _HF_CACHE_DIR

# Read the OpenAI-style API key from the environment; refuse to start without it.
API_KEY = os.getenv("OPENAI_API_KEY")
if not API_KEY:
    raise ValueError("请设置环境变量 OPENAI_API_KEY")
20
 
21
# ------------------- 2. FastAPI application -------------------
app = FastAPI(
    title="OpenAI 兼容的 Cross-Encoder 重排序 API",
    description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
    version="1.0.0",
)

# ------------------- 3. OpenAI-style authentication (Bearer token) -------------------
# auto_error=False: a missing/malformed header yields None instead of an
# immediate 403, so verify_api_key can return a uniform 401 itself.
oauth2_scheme = HTTPBearer(auto_error=False)
 
 
 
 
 
 
 
 
 
 
30
 
31
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme)) -> str:
    """Validate the OpenAI-style API key.

    The key must arrive as ``Authorization: Bearer YOUR_API_KEY``.

    Returns:
        The validated key string, so endpoints can depend on it.

    Raises:
        HTTPException: 401 when the header is missing or the key is wrong.
    """
    import hmac  # local import: only needed for the constant-time comparison

    # BUG FIX: HTTPBearer(auto_error=False) returns None for a missing or
    # non-Bearer header and already matches the "Bearer" scheme
    # case-insensitively, so the original case-sensitive
    # `credentials.scheme != "Bearer"` check wrongly rejected valid
    # "Authorization: bearer <key>" requests.
    # hmac.compare_digest keeps the key comparison constant-time, avoiding
    # leaking key prefixes through response timing.
    if credentials is None or not hmac.compare_digest(
        credentials.credentials.encode(), API_KEY.encode()
    ):
        raise HTTPException(
            status_code=401,
            detail="无效的 API Key(请使用 'Authorization: Bearer YOUR_API_KEY')",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
 
 
 
 
 
40
 
41
# ------------------- 4. Data models -------------------
class RerankRequest(BaseModel):
    """Payload for the plain rerank endpoint (/v1/rerank)."""

    query: str                         # the search query to score documents against
    documents: List[str]               # candidate documents to reorder
    top_k: Optional[int] = 3           # how many of the best documents to return
    truncation: Optional[bool] = True  # truncate inputs beyond the model's 512-token limit
 
47
 
48
class DocumentScore(BaseModel):
    """One ranked document together with its relevance score."""

    document: str  # original document text
    score: float   # relevance score (model logit, rounded); higher = more relevant
    rank: int      # 1-based position after sorting by descending score
 
52
 
53
class RerankResponse(BaseModel):
    """Structured response returned by /v1/rerank."""

    request_id: str               # unique id for tracing this request
    query: str                    # echo of the input query
    top_k: int                    # echo of the requested result count
    results: List[DocumentScore]  # documents sorted by descending relevance
    model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    # BUG FIX: the original used `datetime.now()...` directly as the default,
    # which is evaluated ONCE at class-definition time, so every response
    # carried the process start time. default_factory re-evaluates per instance.
    timestamp: str = Field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
    )
 
60
 
61
# GPT-compatible request/response models
class GPTMessage(BaseModel):
    """A single chat message in the OpenAI request/response shape."""

    role: str     # "user" on input, "assistant" on output
    content: str  # message text; requests encode "query: ...; documents: ...;"
65
 
66
class GPTRequest(BaseModel):
    """OpenAI-style chat-completion request used to carry a rerank query."""

    model: str                  # must equal the loaded cross-encoder model name
    messages: List[GPTMessage]  # the last message must have role "user"
    top_k: Optional[int] = 3    # number of top documents to return
70
+
71
class Choice(BaseModel):
    """One entry of the OpenAI-style ``choices`` array."""

    index: int                  # position within the choices list
    message: GPTMessage         # assistant message carrying the rerank result
    finish_reason: str = "stop"
75
 
76
class GPTResponse(BaseModel):
    """OpenAI-style chat-completion response."""

    # BUG FIX: `id` and `created` originally used plain defaults, which are
    # evaluated ONCE at class-definition time — every response shared the same
    # id and creation timestamp. default_factory computes them per response.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    model: str             # echoed model name (required, no default)
    choices: List[Choice]  # rerank results wrapped as assistant messages
    usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
83
 
84
# ------------------- 5. Cross-Encoder model wrapper -------------------
class CrossEncoderModel:
    """Thin wrapper around a Hugging Face cross-encoder for query/document scoring."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model_name = model_name
        # Verify the cache directory is writable before downloading anything,
        # so permission problems fail fast with a clear message.
        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
        try:
            # Robustness: create the directory first, otherwise the probe fails
            # with a misleading "not writable" error when it simply doesn't exist.
            os.makedirs(cache_dir, exist_ok=True)
            test_file = os.path.join(cache_dir, "test.txt")
            with open(test_file, "w") as f:
                f.write("test")
            os.remove(test_file)
            print(f"缓存目录可写:{cache_dir}")
        except Exception as e:
            raise RuntimeError(f"缓存目录不可写:{str(e)}")
        # Load tokenizer + sequence-classification head into the explicit cache.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, cache_dir=cache_dir
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()
        print(f"模型加载完成,设备:{self.device}")

    def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
        """Score every document against ``query`` and return the top_k, ranked.

        Args:
            query: the search query.
            documents: non-empty list of candidate documents.
            top_k: number of results to return (1..len(documents)).
            truncation: whether to truncate inputs beyond 512 tokens.

        Raises:
            ValueError: if ``documents`` is empty or ``top_k`` is out of range.
        """
        if not documents:
            raise ValueError("候选文档不能为空")
        if top_k <= 0 or top_k > len(documents):
            raise ValueError(f"top_k 需在 1~{len(documents)} 之间")

        doc_scores = []
        for doc in documents:
            # BUG FIX: the original manually concatenated
            # f"{query} {sep_token} {doc}" and tokenized it as ONE sequence,
            # which yields wrong special tokens / token_type_ids for this
            # pair-classification model. Passing the two texts separately lets
            # the tokenizer build the proper (query, document) pair encoding.
            inputs = self.tokenizer(
                query,
                doc,
                return_tensors="pt",
                padding=True,  # "max_length" padding of a single item only wastes compute
                truncation=truncation,
                max_length=512,
            ).to(self.device)
            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Single-logit head: the raw logit is the relevance score.
            score = outputs.logits.item()
            doc_scores.append((doc, score))

        sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
        return [
            DocumentScore(document=doc, score=round(score, 4), rank=i + 1)
            for i, (doc, score) in enumerate(sorted_docs)
        ]
129
 
 
130
# Module-level singleton so the model is loaded exactly once at startup.
reranker = CrossEncoderModel()
131
 
132
+ # ------------------- 6. API 端点(OpenAI 风格路径) -------------------
133
+ # 6.1 根路径首页
134
+ @app.get("/", response_class=HTMLResponse)
135
  async def home_page():
136
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
137
  return f"""
 
139
  <html lang="zh-CN">
140
  <head>
141
  <meta charset="UTF-8">
142
+ <title>OpenAI 兼容重排序 API</title>
143
  <style>
144
  body {{ font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }}
145
  h1 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
 
148
  table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
149
  th, td {{ border: 1px solid #e9ecef; padding: 12px; text-align: left; }}
150
  th {{ background-color: #f1f5f9; }}
 
 
151
  </style>
152
  </head>
153
  <body>
154
+ <h1>OpenAI 兼容的 Cross-Encoder 重排序 API</h1>
155
+ <p>基于 <code>cross-encoder/ms-marco-MiniLM-L-6-v2</code> 模型,支持 OpenAI 风格 API 调用。</p>
 
 
 
 
 
 
 
156
 
157
  <h2>接口列表</h2>
158
  <table>
159
  <tr>
160
+ <th>接口</th>
161
  <th>URL</th>
162
  <th>方法</th>
163
+ <th>认证</th>
164
  </tr>
165
  <tr>
166
+ <td>基础重排序</td>
167
+ <td class="api-url">/v1/rerank</td>
168
  <td>POST</td>
169
+ <td>Authorization: Bearer API_KEY</td>
170
  </tr>
171
  <tr>
172
+ <td>GPT 兼容重排序</td>
173
+ <td class="api-url">/v1/chat/completions</td>
174
  <td>POST</td>
175
+ <td>Authorization: Bearer API_KEY</td>
176
  </tr>
177
  <tr>
178
  <td>健康检查</td>
179
+ <td class="api-url">/v1/health</td>
180
  <td>GET</td>
181
+ <td>无需认证</td>
182
  </tr>
183
  </table>
184
 
185
+ <h2>调用示例(Python)</h2>
186
+ <pre><code>import openai
187
 
188
+ client = openai.OpenAI(
189
+ api_key="YOUR_API_KEY",
190
+ base_url="https://your-space.hf.space/v1" # 替换为你的 Space URL
 
191
  )
192
 
 
193
  response = client.chat.completions.create(
194
+ model="cross-encoder/ms-marco-MiniLM-L-6-v2",
195
  messages=[
196
  {{
197
  "role": "user",
198
+ "content": "query: 什么是机器学习?; documents: 机器学习是AI的分支; Python是编程语言;"
199
  }}
200
  ],
201
+ top_k=2
202
  )
203
 
 
204
  print(response.choices[0].message.content)</code></pre>
 
 
 
 
 
 
 
 
 
205
  </body>
206
  </html>
207
  """
208
 
209
+ # 6.2 基础重排序接口(/v1/rerank)
210
+ @app.post("/v1/rerank", response_model=RerankResponse)
 
 
 
 
211
  async def base_rerank(
212
  request: RerankRequest,
213
  api_key: str = Depends(verify_api_key)
214
  ):
215
  try:
 
216
  results = reranker.rerank(
217
  query=request.query,
218
  documents=request.documents,
219
  top_k=request.top_k,
220
  truncation=request.truncation
221
  )
 
222
  return RerankResponse(
223
  request_id=str(uuid.uuid4()),
224
  query=request.query,
 
230
  except Exception as e:
231
  raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
232
 
233
# 6.3 GPT-compatible endpoint (/v1/chat/completions)
@app.post("/v1/chat/completions", response_model=GPTResponse)
async def gpt_compatible_rerank(
    request: GPTRequest,
    api_key: str = Depends(verify_api_key)
):
    """Accept an OpenAI-style chat request, parse query/documents out of the
    last user message, rerank them, and wrap the result as an assistant reply."""
    try:
        # Guard clauses: model name and message shape must match expectations.
        if request.model != reranker.model_name:
            raise ValueError(f"仅支持模型:{reranker.model_name}")
        if not request.messages or request.messages[-1].role != "user":
            raise ValueError("最后一条消息必须是 'user' 角色")

        # The payload is encoded in the message text as
        # "query: ...; documents: doc1; doc2; ..."
        raw = request.messages[-1].content
        if "; documents: " not in raw:
            raise ValueError("输入格式需为 'query: [查询]; documents: [文档1]; [文档2]; ...'")
        query_text, doc_blob = raw.split("; documents: ")
        query = query_text.replace("query: ", "").strip()
        documents = [piece.strip() for piece in doc_blob.split(";") if piece.strip()]

        results = reranker.rerank(
            query=query,
            documents=documents,
            top_k=request.top_k,
            truncation=True,
        )
        assistant_msg = GPTMessage(role="assistant", content=f"重排序结果:{results}")
        return GPTResponse(
            model=request.model,
            choices=[Choice(index=0, message=assistant_msg)],
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
272
 
273
# 6.4 Health-check endpoint (/v1/health) — no authentication required
@app.get("/v1/health")
async def health_check():
    """Report service liveness plus the loaded model name and device."""
    return {
        "status": "healthy",
        "model": reranker.model_name,
        "device": reranker.device,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
282
 
283
# ------------------- 7. Local development entry point -------------------
if __name__ == "__main__":
    # Imported lazily so the server dependency is only required when running
    # this module directly (Spaces launches the app object itself).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)