1bincooo commited on
Commit
f71153e
·
verified ·
1 Parent(s): 03252c8

Create localrerank.py

Browse files
Files changed (1) hide show
  1. localrerank.py +76 -0
localrerank.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import logging
4
+ import uvicorn
5
+ import datetime
6
+ from fastapi import FastAPI, Security, HTTPException
7
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
+ from FlagEmbedding import FlagReranker
9
+ from pydantic import Field, BaseModel, validator
10
+ from typing import Optional, List
11
+
12
+ app = FastAPI()
13
+ security = HTTPBearer()
14
+
15
+ #环境变量传入
16
+ sk_key = os.environ.get('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')
17
+
18
+ class QADocs(BaseModel):
19
+ query: Optional[str]
20
+ documents: Optional[List[str]]
21
+
22
+
23
+ class Singleton(type):
24
+ def __call__(cls, *args, **kwargs):
25
+ if not hasattr(cls, '_instance'):
26
+ cls._instance = super().__call__(*args, **kwargs)
27
+ return cls._instance
28
+
29
+
30
+ RERANK_MODEL_PATH = "Qwen/Qwen3-Reranker-4B"
31
+
32
+ class ReRanker(metaclass=Singleton):
33
+ def __init__(self, model_path):
34
+ self.reranker = FlagReranker(model_path, use_fp16=False)
35
+
36
+ def compute_score(self, pairs: List[List[str]]):
37
+ if len(pairs) > 0:
38
+ result = self.reranker.compute_score(pairs, normalize=True)
39
+ if isinstance(result, float):
40
+ result = [result]
41
+ return result
42
+ else:
43
+ return None
44
+
45
+ class Chat(object):
46
+ def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
47
+ self.reranker = ReRanker(rerank_model_path)
48
+
49
+ def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
50
+ if query_docs is None or len(query_docs.documents) == 0:
51
+ return []
52
+
53
+ pair = [[query_docs.query, doc] for doc in query_docs.documents]
54
+ scores = self.reranker.compute_score(pair)
55
+
56
+ new_docs = []
57
+ for index, score in enumerate(scores):
58
+ new_docs.append({"index": index, "text": query_docs.documents[index], "score": score})
59
+ results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
60
+ return results
61
+
62
+ @app.post('/v1/rerank')
63
+ async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
64
+ token = credentials.credentials
65
+ if sk_key is not None and token != sk_key:
66
+ raise HTTPException(status_code=401, detail="Invalid token")
67
+ chat = Chat()
68
+ try:
69
+ results = chat.fit_query_answer_rerank(docs)
70
+ return {"results": results}
71
+ except Exception as e:
72
+ print(f"报错:\n{e}")
73
+ return {"error": "重排出错"}
74
+
75
+ if __name__ == "__main__":
76
+ uvicorn.run("localrerank:app", host='0.0.0.0', port=7860, workers=1)