geqintan committed on
Commit
97ea61b
·
1 Parent(s): 67f78b9
Files changed (1) hide show
  1. app.py +39 -33
app.py CHANGED
@@ -1,80 +1,86 @@
1
- from fastapi import FastAPI, HTTPException, Depends, Header
 
2
  from pydantic import BaseModel, Field
3
- from sentence_transformers import CrossEncoder # 关键修改:使用 CrossEncoder 而非 SentenceTransformer
4
  import logging
5
  import os
6
- from typing import List
7
 
8
- # Configure logging
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
# Dependency: validate the Authorization header.
async def verify_auth(authorization: str = Header(..., alias="Authorization")):
    """Check the bearer token carried in the ``Authorization`` header.

    The expected token is read from the ``AUTHORIZATION`` environment
    variable on every call.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The token that followed the ``"Bearer "`` prefix.

    Raises:
        HTTPException: 401 when the header does not start with ``"Bearer "``,
            when the token does not match the expected value, or when no
            expected value is configured.
    """
    import hmac  # local import so this block does not rely on file-level edits

    scheme = "Bearer "
    if not authorization.startswith(scheme):
        raise HTTPException(
            status_code=401,
            detail="Invalid token format. Use 'Bearer YOUR_TOKEN'",
        )
    token = authorization[len(scheme):]
    expected = os.getenv("AUTHORIZATION")
    # An unset or empty secret must never authenticate anyone (an empty
    # secret would otherwise equal an empty token). compare_digest is used
    # instead of != so the comparison time does not depend on where the two
    # values differ; both sides are encoded because it only accepts ASCII str.
    if not expected or not hmac.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
20
 
21
  app = FastAPI()
22
 
23
# Load the reranking model once at import time so that individual requests
# do not reload it.
try:
    model = CrossEncoder(
        "BAAI/bge-reranker-large",
        tokenizer_args={"truncation": True},  # request truncation at init time
        max_length=512,  # optional: cap the maximum input length
    )
    logger.info("Model loaded with truncation support")
except Exception as e:
    logger.critical(f"Model load failed: {str(e)}")
    # Chain explicitly so the traceback reports the loader failure as the
    # direct cause of this startup error.
    raise RuntimeError("Model initialization error") from e
34
 
35
# Request body model.
from typing import Optional  # imported here: the file header only brings in List


class RerankRequest(BaseModel):
    """Body of a ``POST /rerank`` request."""

    # Query text that every document is paired with; 1 to 8192 characters.
    query: str = Field(..., min_length=1, max_length=8192)
    # Candidate documents to score; at least one is required.
    documents: List[str] = Field(..., min_items=1)
    # Optional cap on the number of results (1 to 100). Annotated Optional so
    # the declared type agrees with the None default.
    top_k: Optional[int] = Field(None, ge=1, le=100)
40
 
41
# Response model.
class RerankResult(BaseModel):
    """One scored entry in the list produced by the rerank endpoint."""

    # Position of the document within the request's ``documents`` list.
    index: int
    # Score the model assigned to the (query, document) pair.
    score: float
    # The document text, echoed back from the request.
    document: str
46
 
47
@app.post("/rerank", response_model=List[RerankResult])
async def rerank(
    request: RerankRequest,
    token: str = Depends(verify_auth),  # enforces the bearer-token check
):
    """Score every document against the query and return them best-first.

    Args:
        request: Query, candidate documents and optional ``top_k`` limit.
        token: Bearer token returned by ``verify_auth``; unused in the body,
            present only so the dependency runs.

    Returns:
        A list of ``{"index", "score", "document"}`` dicts sorted by
        descending score, truncated to ``top_k`` entries when it is given.

    Raises:
        HTTPException: 500 carrying the error text when scoring fails.
    """
    try:
        # Build one (query, document) pair per candidate document.
        pairs = [(request.query, doc) for doc in request.documents]

        # Score all pairs in a single predict call.
        scores = model.predict(pairs)

        # Keep the original position so callers can map results back.
        results = [
            {"index": idx, "score": float(score), "document": doc}
            for idx, (doc, score) in enumerate(zip(request.documents, scores))
        ]

        # Highest score first.
        sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)

        if request.top_k is not None:
            sorted_results = sorted_results[:request.top_k]

        # Return the bare list on every path: the route declares
        # response_model=List[RerankResult], and a wrapper dict such as
        # {"object": "list", "data": [...]} does not match that declaration.
        return sorted_results

    except Exception as e:
        logger.error(f"API Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
 
1
+ from fastapi import FastAPI, HTTPException, Depends, Header, Request
2
+ from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel, Field
4
+ from sentence_transformers import CrossEncoder
5
  import logging
6
  import os
7
+ from typing import List, Dict
8
 
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
# Authentication.
async def verify_auth(authorization: str = Header(..., alias="Authorization")):
    """Validate the bearer token sent in the ``Authorization`` header.

    The expected token comes from the ``AUTHORIZATION`` environment variable,
    which is read on each call.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The token portion of the header (the text after ``"Bearer "``).

    Raises:
        HTTPException: 401 if the header lacks the ``"Bearer "`` prefix, if
            the token differs from the expected value, or if no expected
            value is configured.
    """
    import hmac  # local import so this block does not rely on file-level edits

    prefix = "Bearer "
    if not authorization.startswith(prefix):
        raise HTTPException(401, detail="Invalid token format")
    token = authorization[len(prefix):]
    secret = os.getenv("AUTHORIZATION")
    # Never authenticate against a missing or empty secret.
    if not secret:
        raise HTTPException(401, detail="Invalid token")
    # compare_digest replaces != so the comparison time does not reveal how
    # much of the token matched; bytes are passed because it only accepts
    # ASCII-only str.
    if not hmac.compare_digest(token.encode("utf-8"), secret.encode("utf-8")):
        raise HTTPException(401, detail="Invalid token")
    return token
20
 
21
  app = FastAPI()
22
 
23
# Model configuration.
MODEL_NAME = "BAAI/bge-reranker-large"  # make sure this name is correct

# Load the model once at import time; a failure here aborts startup.
try:
    model = CrossEncoder(
        MODEL_NAME,
        tokenizer_args={"truncation": True},
        max_length=512,
    )
    # Health check: score one trivial pair so a broken model fails here
    # rather than on the first request.
    test_score = model.predict([("test", "test")])[0]
    logger.info(f"Model loaded. Test score: {test_score}")
except Exception as e:
    logger.critical(f"Model load failed: {str(e)}")
    # Chain explicitly so the traceback reports the underlying failure as
    # the direct cause of this startup error.
    raise RuntimeError("Model init failed") from e
38
 
39
# Request/response models.
from typing import Optional  # imported here: the file header only brings in List, Dict


class RerankRequest(BaseModel):
    """Body of a ``POST /rerank`` request."""

    # Query text that every document is paired with; 1 to 8192 characters.
    query: str = Field(..., min_length=1, max_length=8192)
    # Candidate documents to score; at least one is required.
    documents: List[str] = Field(..., min_items=1)
    # Optional cap on the number of results (1 to 100). Annotated Optional so
    # the declared type agrees with the None default.
    top_k: Optional[int] = Field(None, ge=1, le=100)


class RerankResult(BaseModel):
    """Shape of one scored document: original index, score and text."""

    # Position of the document within the request's ``documents`` list.
    index: int
    # Score the model assigned to the (query, document) pair.
    score: float
    # The document text, echoed back from the request.
    document: str
49
 
50
# Uniform error responses.
@app.exception_handler(HTTPException)
async def handle_errors(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` as a JSON error envelope.

    Args:
        request: The request being handled (unused).
        exc: The exception raised by a route or dependency.

    Returns:
        A ``JSONResponse`` with the exception's status code and a body of the
        form ``{"error": {"message": <detail>, "type": "api_error"}}``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": {"message": exc.detail, "type": "api_error"}},
        # Forward any headers attached to the exception (for example
        # WWW-Authenticate) instead of silently dropping them.
        headers=getattr(exc, "headers", None),
    )
57
+
58
@app.post("/rerank")
async def rerank(
    request: RerankRequest,
    token: str = Depends(verify_auth),
) -> Dict:
    """Score every document against the query and return them best-first.

    Args:
        request: Query, candidate documents and optional ``top_k`` limit.
        token: Bearer token returned by ``verify_auth``; unused in the body,
            present only so the dependency runs.

    Returns:
        ``{"object": "list", "data": [...], "model": MODEL_NAME}`` where
        ``data`` holds ``{"index", "score", "document"}`` dicts sorted by
        descending score and truncated to ``top_k`` when it is given.

    Raises:
        HTTPException: 500 with a generic message when scoring fails; the
            underlying error is logged, not returned to the client.
    """
    import asyncio  # local import so this block does not rely on file-level edits

    try:
        logger.info(f"Processing query: {request.query[:50]}...")

        # One (query, document) pair per candidate document.
        pairs = [(request.query, doc) for doc in request.documents]

        # model.predict is a synchronous call; running it directly inside
        # this coroutine would stall the event loop for its whole duration.
        # Hand it to the loop's default executor and await the result.
        loop = asyncio.get_running_loop()
        scores = await loop.run_in_executor(None, model.predict, pairs)

        # Keep the original position so callers can map results back.
        results = [
            {"index": idx, "score": float(score), "document": doc}
            for idx, (doc, score) in enumerate(zip(request.documents, scores))
        ]
        sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)

        # top_k is validated as >= 1, so a truthiness test only skips None.
        if request.top_k:
            sorted_results = sorted_results[:request.top_k]

        return {
            "object": "list",
            "data": sorted_results,
            "model": MODEL_NAME,
        }

    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        raise HTTPException(500, detail="Internal server error")