fiewolf1000 commited on
Commit
72de27b
·
verified ·
1 Parent(s): 8fb64e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -59
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import uuid
 
3
  from datetime import datetime
4
  from fastapi import FastAPI, HTTPException, Depends, Request
5
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
@@ -9,36 +10,63 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  import torch
10
  from typing import List, Optional
11
 
12
- # ------------------- 1. 基础配置(缓存 + 环境变量) -------------------
 
 
 
 
 
 
 
 
13
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
14
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
15
 
16
  # 从环境变量获取 API Key(OpenAI 风格)
17
  API_KEY = os.getenv("OPENAI_API_KEY")
18
  if not API_KEY:
 
19
  raise ValueError("请设置环境变量 OPENAI_API_KEY")
 
20
 
21
- # ------------------- 2. 初始化 FastAPI 应用 -------------------
22
  app = FastAPI(
23
  title="OpenAI 兼容的 Cross-Encoder 重排序 API",
24
  description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
25
  version="1.0.0"
26
  )
27
 
28
- # ------------------- 3. OpenAI 风格认证(Bearer Token) -------------------
29
  oauth2_scheme = HTTPBearer(auto_error=False)
30
 
31
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme)):
32
  """验证 API Key:必须通过 Authorization: Bearer YOUR_API_KEY 传递"""
33
- if not credentials or credentials.scheme != "Bearer" or credentials.credentials != API_KEY:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  raise HTTPException(
35
  status_code=401,
36
- detail="无效的 API Key(请使用 'Authorization: Bearer YOUR_API_KEY')",
37
  headers={"WWW-Authenticate": "Bearer"}
38
  )
39
- return credentials.credentials
 
40
 
41
- # ------------------- 4. 数据模型定义 -------------------
42
  class RerankRequest(BaseModel):
43
  query: str
44
  documents: List[str]
@@ -81,10 +109,12 @@ class GPTResponse(BaseModel):
81
  choices: List[Choice]
82
  usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
83
 
84
- # ------------------- 5. 加载 Cross-Encoder 模型 -------------------
85
  class CrossEncoderModel:
86
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
87
  self.model_name = model_name
 
 
88
  # 验证缓存目录可写
89
  cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
90
  try:
@@ -92,50 +122,94 @@ class CrossEncoderModel:
92
  with open(test_file, "w") as f:
93
  f.write("test")
94
  os.remove(test_file)
95
- print(f"缓存目录可写:{cache_dir}")
96
  except Exception as e:
 
97
  raise RuntimeError(f"缓存目录不可写:{str(e)}")
 
98
  # 加载模型
99
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
100
- self.model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=cache_dir)
101
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
102
- self.model.to(self.device)
103
- self.model.eval()
104
- print(f"模型加载完成,设备:{self.device}")
 
 
 
 
 
 
 
 
 
 
105
 
106
- def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
 
 
 
 
107
  if not documents:
 
108
  raise ValueError("候选文档不能为空")
109
  if top_k <= 0:
 
110
  raise ValueError("top_k 必须为正整数")
 
111
  # 自动将 top_k 限制为文档数量(避免超出)
112
- top_k = min(top_k, len(documents))
 
 
 
 
113
  doc_scores = []
114
- for doc in documents:
115
- inputs = self.tokenizer(
116
- f"{query} {self.tokenizer.sep_token} {doc}",
117
- return_tensors="pt",
118
- padding="max_length",
119
- truncation=truncation,
120
- max_length=512
121
- ).to(self.device)
122
- with torch.no_grad():
123
- outputs = self.model(**inputs)
124
- score = outputs.logits.item()
125
- doc_scores.append((doc, score))
126
- sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:top_k]
127
- return [
128
- DocumentScore(document=doc, score=round(score, 4), rank=i+1)
129
- for i, (doc, score) in enumerate(sorted_docs)
130
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- reranker = CrossEncoderModel()
 
 
 
 
 
133
 
134
- # ------------------- 6. API 端点(OpenAI 风格路径) -------------------
135
- # 6.1 根路径首页
136
  @app.get("/", response_class=HTMLResponse)
137
- async def home_page():
 
138
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
139
  return f"""
140
  <!DOCTYPE html>
141
  <html lang="zh-CN">
@@ -208,55 +282,94 @@ print(response.choices[0].message.content)</code></pre>
208
  </html>
209
  """
210
 
211
- # 6.2 基础重排序接口(/v1/rerank)
212
  @app.post("/v1/rerank", response_model=RerankResponse)
213
  async def base_rerank(
214
  request: RerankRequest,
215
- api_key: str = Depends(verify_api_key)
216
  ):
 
217
  try:
218
- print(f"接收到的请求:{request.dict()}") # 打印请求内容
 
 
219
  results = reranker.rerank(
220
  query=request.query,
221
  documents=request.documents,
222
  top_k=request.top_k,
223
- truncation=request.truncation
 
224
  )
225
- return RerankResponse(
226
- request_id=str(uuid.uuid4()),
 
 
227
  query=request.query,
228
- top_k=request.top_k,
229
  results=results
230
  )
 
 
 
 
231
  except ValueError as e:
 
232
  raise HTTPException(status_code=400, detail=str(e))
233
  except Exception as e:
 
234
  raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
235
 
236
- # 6.3 GPT 兼容接口(/v1/chat/completions)
237
  @app.post("/v1/chat/completions", response_model=GPTResponse)
238
  async def gpt_compatible_rerank(
239
  request: GPTRequest,
240
- api_key: str = Depends(verify_api_key)
241
  ):
 
242
  try:
 
 
 
243
  if request.model != reranker.model_name:
244
- raise ValueError(f"仅支持模型:{reranker.model_name}")
245
- if not request.messages or request.messages[-1].role != "user":
246
- raise ValueError("最后一条消息必须是 'user' 角色")
 
 
 
 
 
 
 
 
 
 
 
247
  content = request.messages[-1].content
 
 
248
  if "; documents: " not in content:
249
- raise ValueError("输入格式需为 'query: [查询]; documents: [文档1]; [文档2]; ...'")
 
 
 
250
  query_part, docs_part = content.split("; documents: ")
251
  query = query_part.replace("query: ", "").strip()
252
  documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
 
 
 
 
253
  results = reranker.rerank(
254
  query=query,
255
  documents=documents,
256
  top_k=request.top_k,
257
- truncation=True
 
258
  )
259
- return GPTResponse(
 
 
260
  model=request.model,
261
  choices=[
262
  Choice(
@@ -268,22 +381,38 @@ async def gpt_compatible_rerank(
268
  )
269
  ]
270
  )
 
 
 
 
271
  except ValueError as e:
 
272
  raise HTTPException(status_code=400, detail=str(e))
273
  except Exception as e:
 
274
  raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
275
 
276
- # 6.4 健康检查接口(/v1/health)
277
  @app.get("/v1/health")
278
- async def health_check():
279
- return {
 
280
  "status": "healthy",
281
  "model": reranker.model_name,
282
  "device": reranker.device,
283
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 
284
  }
 
 
285
 
286
- # ------------------- 7. 本地运行入口 -------------------
287
  if __name__ == "__main__":
288
  import uvicorn
289
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
1
  import os
2
  import uuid
3
+ import logging
4
  from datetime import datetime
5
  from fastapi import FastAPI, HTTPException, Depends, Request
6
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
10
  import torch
11
  from typing import List, Optional
12
 
13
+ # ------------------- 1. 日志配置 -------------------
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
17
+ datefmt="%Y-%m-%d %H:%M:%S"
18
+ )
19
+ logger = logging.getLogger("cross-encoder-api")
20
+
21
+ # ------------------- 2. 基础配置(缓存 + 环境变量) -------------------
22
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
23
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface_cache"
24
 
25
  # 从环境变量获取 API Key(OpenAI 风格)
26
  API_KEY = os.getenv("OPENAI_API_KEY")
27
  if not API_KEY:
28
+ logger.error("环境变量 OPENAI_API_KEY 未设置")
29
  raise ValueError("请设置环境变量 OPENAI_API_KEY")
30
+ logger.info("API Key 加载成功")
31
 
32
+ # ------------------- 3. 初始化 FastAPI 应用 -------------------
33
  app = FastAPI(
34
  title="OpenAI 兼容的 Cross-Encoder 重排序 API",
35
  description="基于 cross-encoder/ms-marco-MiniLM-L-6-v2 的文本相关性排序接口",
36
  version="1.0.0"
37
  )
38
 
39
+ # ------------------- 4. OpenAI 风格认证(Bearer Token) -------------------
40
  oauth2_scheme = HTTPBearer(auto_error=False)
41
 
42
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(oauth2_scheme)):
43
  """验证 API Key:必须通过 Authorization: Bearer YOUR_API_KEY 传递"""
44
+ request_id = str(uuid.uuid4())[:8] # 生成短请求ID用于日志追踪
45
+ if not credentials:
46
+ logger.warning(f"请求 {request_id}:缺少认证信息")
47
+ raise HTTPException(
48
+ status_code=401,
49
+ detail="缺少认证信息(请使用 'Authorization: Bearer YOUR_API_KEY')",
50
+ headers={"WWW-Authenticate": "Bearer"}
51
+ )
52
+ if credentials.scheme != "Bearer":
53
+ logger.warning(f"请求 {request_id}:认证方案错误,应为 Bearer,实际为 {credentials.scheme}")
54
+ raise HTTPException(
55
+ status_code=401,
56
+ detail="认证方案错误(请使用 'Bearer' 方案)",
57
+ headers={"WWW-Authenticate": "Bearer"}
58
+ )
59
+ if credentials.credentials != API_KEY:
60
+ logger.warning(f"请求 {request_id}:无效的 API Key")
61
  raise HTTPException(
62
  status_code=401,
63
+ detail="无效的 API Key",
64
  headers={"WWW-Authenticate": "Bearer"}
65
  )
66
+ logger.info(f"请求 {request_id}:API Key 验证通过")
67
+ return (credentials.credentials, request_id) # 返回API Key和请求ID
68
 
69
+ # ------------------- 5. 数据模型定义 -------------------
70
  class RerankRequest(BaseModel):
71
  query: str
72
  documents: List[str]
 
109
  choices: List[Choice]
110
  usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
111
 
112
+ # ------------------- 6. 加载 Cross-Encoder 模型 -------------------
113
  class CrossEncoderModel:
114
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
115
  self.model_name = model_name
116
+ logger.info(f"开始加载模型:{model_name}")
117
+
118
  # 验证缓存目录可写
119
  cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
120
  try:
 
122
  with open(test_file, "w") as f:
123
  f.write("test")
124
  os.remove(test_file)
125
+ logger.info(f"缓存目录可写:{cache_dir}")
126
  except Exception as e:
127
+ logger.error(f"缓存目录不可写:{str(e)}")
128
  raise RuntimeError(f"缓存目录不可写:{str(e)}")
129
+
130
  # 加载模型
131
+ try:
132
+ logger.info("开始加载分词器...")
133
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
134
+ logger.info("分词器加载完成")
135
+
136
+ logger.info("开始加载模型权重...")
137
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=cache_dir)
138
+ logger.info("模型权重加载完成")
139
+
140
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
141
+ self.model.to(self.device)
142
+ self.model.eval()
143
+ logger.info(f"模型加载完成,使用设备:{self.device}")
144
+ except Exception as e:
145
+ logger.error(f"模型加载失败:{str(e)}")
146
+ raise
147
 
148
+ def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool, request_id: str) -> List[DocumentScore]:
149
+ """核心重排序逻辑,增加详细日志"""
150
+ logger.info(f"请求 {request_id}:开始重排序处理,查询长度: {len(query)}, 文档数量: {len(documents)}, top_k: {top_k}")
151
+
152
+ # 参数校验
153
  if not documents:
154
+ logger.warning(f"请求 {request_id}:候选文档列表为空")
155
  raise ValueError("候选文档不能为空")
156
  if top_k <= 0:
157
+ logger.warning(f"请求 {request_id}:无效的 top_k 值: {top_k}")
158
  raise ValueError("top_k 必须为正整数")
159
+
160
  # 自动将 top_k 限制为文档数量(避免超出)
161
+ adjusted_top_k = min(top_k, len(documents))
162
+ if adjusted_top_k != top_k:
163
+ logger.info(f"请求 {request_id}:top_k 从 {top_k} 调整为 {adjusted_top_k}(文档数量限制)")
164
+
165
+ # 计算每篇文档的相关性分数
166
  doc_scores = []
167
+ try:
168
+ for i, doc in enumerate(documents):
169
+ if i % 5 == 0: # 每处理5个文档输出一次日志
170
+ logger.info(f"请求 {request_id}:正在处理第 {i+1}/{len(documents)} 个文档")
171
+
172
+ inputs = self.tokenizer(
173
+ f"{query} {self.tokenizer.sep_token} {doc}",
174
+ return_tensors="pt",
175
+ padding="max_length",
176
+ truncation=truncation,
177
+ max_length=512
178
+ ).to(self.device)
179
+
180
+ with torch.no_grad():
181
+ outputs = self.model(**inputs)
182
+
183
+ score = outputs.logits.item()
184
+ doc_scores.append((doc, score))
185
+ logger.debug(f"请求 {request_id}:文档 {i+1} 分数: {score:.4f}")
186
+
187
+ # 排序并返回结果
188
+ sorted_docs = sorted(doc_scores, key=lambda x: x[1], reverse=True)[:adjusted_top_k]
189
+ logger.info(f"请求 {request_id}:重排序完成,返回 {len(sorted_docs)} 个结果")
190
+
191
+ return [
192
+ DocumentScore(document=doc, score=round(score, 4), rank=i+1)
193
+ for i, (doc, score) in enumerate(sorted_docs)
194
+ ]
195
+ except Exception as e:
196
+ logger.error(f"请求 {request_id}:重排序过程出错: {str(e)}")
197
+ raise
198
 
199
+ # 初始化模型(全局唯一)
200
+ try:
201
+ reranker = CrossEncoderModel()
202
+ except Exception as e:
203
+ logger.critical(f"模型初始化失败,服务无法启动: {str(e)}")
204
+ raise
205
 
206
+ # ------------------- 7. API 端点(OpenAI 风格路径) -------------------
207
+ # 7.1 根路径首页
208
  @app.get("/", response_class=HTMLResponse)
209
+ async def home_page(request: Request):
210
+ client_ip = request.client.host
211
  current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
212
+ logger.info(f"首页访问来自 {client_ip}")
213
  return f"""
214
  <!DOCTYPE html>
215
  <html lang="zh-CN">
 
282
  </html>
283
  """
284
 
285
+ # 7.2 基础重排序接口(/v1/rerank)
286
  @app.post("/v1/rerank", response_model=RerankResponse)
287
  async def base_rerank(
288
  request: RerankRequest,
289
+ auth_result: tuple = Depends(verify_api_key)
290
  ):
291
+ api_key, request_id = auth_result
292
  try:
293
+ logger.info(f"请求 {request_id}:收到 /v1/rerank 请求,query: {request.query[:50]}...(截断显示)")
294
+
295
+ # 执行重排序
296
  results = reranker.rerank(
297
  query=request.query,
298
  documents=request.documents,
299
  top_k=request.top_k,
300
+ truncation=request.truncation,
301
+ request_id=request_id
302
  )
303
+
304
+ # 构建响应
305
+ response = RerankResponse(
306
+ request_id=request_id,
307
  query=request.query,
308
+ top_k=min(request.top_k, len(request.documents)),
309
  results=results
310
  )
311
+
312
+ logger.info(f"请求 {request_id}:处理完成,返回 {len(results)} 个结果")
313
+ return response
314
+
315
  except ValueError as e:
316
+ logger.warning(f"请求 {request_id}:参数错误 - {str(e)}")
317
  raise HTTPException(status_code=400, detail=str(e))
318
  except Exception as e:
319
+ logger.error(f"请求 {request_id}:服务器错误 - {str(e)}", exc_info=True)
320
  raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
321
 
322
+ # 7.3 GPT 兼容接口(/v1/chat/completions)
323
  @app.post("/v1/chat/completions", response_model=GPTResponse)
324
  async def gpt_compatible_rerank(
325
  request: GPTRequest,
326
+ auth_result: tuple = Depends(verify_api_key)
327
  ):
328
+ api_key, request_id = auth_result
329
  try:
330
+ logger.info(f"请求 {request_id}:收到 /v1/chat/completions 请求,模型: {request.model}")
331
+
332
+ # 验证模型名
333
  if request.model != reranker.model_name:
334
+ error_msg = f"仅支持模型:{reranker.model_name},实际请求:{request.model}"
335
+ logger.warning(f"请求 {request_id}:{error_msg}")
336
+ raise ValueError(error_msg)
337
+
338
+ # 验证消息格式
339
+ if not request.messages:
340
+ logger.warning(f"请求 {request_id}:消息列表为空")
341
+ raise ValueError("消息列表不能为空")
342
+ if request.messages[-1].role != "user":
343
+ error_msg = f"最后一条消息必须是 'user' 角色,实际为:{request.messages[-1].role}"
344
+ logger.warning(f"请求 {request_id}:{error_msg}")
345
+ raise ValueError(error_msg)
346
+
347
+ # 解析输入内容
348
  content = request.messages[-1].content
349
+ logger.info(f"请求 {request_id}:用户输入: {content[:100]}...(截断显示)")
350
+
351
  if "; documents: " not in content:
352
+ error_msg = "输入格式需为 'query: [查询]; documents: [文档1]; [文档2]; ...'"
353
+ logger.warning(f"请求 {request_id}:{error_msg}")
354
+ raise ValueError(error_msg)
355
+
356
  query_part, docs_part = content.split("; documents: ")
357
  query = query_part.replace("query: ", "").strip()
358
  documents = [doc.strip() for doc in docs_part.split(";") if doc.strip()]
359
+
360
+ logger.info(f"请求 {request_id}:解析完成,query: {query[:50]}..., 文档数量: {len(documents)}")
361
+
362
+ # 执行重排序
363
  results = reranker.rerank(
364
  query=query,
365
  documents=documents,
366
  top_k=request.top_k,
367
+ truncation=True,
368
+ request_id=request_id
369
  )
370
+
371
+ # 构建 GPT 风格响应
372
+ response = GPTResponse(
373
  model=request.model,
374
  choices=[
375
  Choice(
 
381
  )
382
  ]
383
  )
384
+
385
+ logger.info(f"请求 {request_id}:处理完成,返回 {len(results)} 个结果")
386
+ return response
387
+
388
  except ValueError as e:
389
+ logger.warning(f"请求 {request_id}:参数错误 - {str(e)}")
390
  raise HTTPException(status_code=400, detail=str(e))
391
  except Exception as e:
392
+ logger.error(f"请求 {request_id}:服务器错误 - {str(e)}", exc_info=True)
393
  raise HTTPException(status_code=500, detail=f"服务器错误:{str(e)}")
394
 
395
+ # 7.4 健康检查接口(/v1/health)
396
  @app.get("/v1/health")
397
+ async def health_check(request: Request):
398
+ client_ip = request.client.host
399
+ status = {
400
  "status": "healthy",
401
  "model": reranker.model_name,
402
  "device": reranker.device,
403
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
404
+ "uptime": datetime.now().strftime("%Y-%m-%d %H:%M:%S") # 简化版uptime
405
  }
406
+ logger.info(f"健康检查来自 {client_ip}:{status['status']}")
407
+ return status
408
 
409
+ # ------------------- 8. 本地运行入口 -------------------
410
  if __name__ == "__main__":
411
  import uvicorn
412
+ logger.info("启动本地开发服务器...")
413
+ uvicorn.run(
414
+ app,
415
+ host="0.0.0.0",
416
+ port=7860,
417
+ log_config=None # 使用自定义日志配置
418
+ )