Create localrerank.py
Browse files- localrerank.py +76 -0
localrerank.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import logging
|
| 4 |
+
import uvicorn
|
| 5 |
+
import datetime
|
| 6 |
+
from fastapi import FastAPI, Security, HTTPException
|
| 7 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 8 |
+
from FlagEmbedding import FlagReranker
|
| 9 |
+
from pydantic import Field, BaseModel, validator
|
| 10 |
+
from typing import Optional, List
|
| 11 |
+
|
| 12 |
+
# FastAPI application object and the HTTP bearer scheme used to read the
# caller's token.
app = FastAPI()
security = HTTPBearer()

# API key that request tokens are compared against. It is read from the
# 'sk-key' environment variable and falls back to the built-in placeholder
# value when that variable is unset.
sk_key = os.getenv('sk-key', 'sk-aaabbbcccdddeeefffggghhhiiijjjkkk')
|
| 17 |
+
|
| 18 |
+
class QADocs(BaseModel):
    """Request body accepted by the rerank endpoint."""

    # Query text that each document is paired with for scoring.
    query: Optional[str]
    # Candidate document texts; result indices refer to positions in this list.
    documents: Optional[List[str]]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Singleton(type):
    """Metaclass that gives each class using it exactly one cached instance.

    The first call ``Cls(*args, **kwargs)`` constructs the instance and stores
    it on the class as ``_instance``; every later call returns that same object
    and ignores its arguments.
    """

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of *cls*, creating it on first use."""
        # Look only in the class's own namespace. hasattr() would also find an
        # instance cached on a base class, making a subclass hand back its
        # parent's object instead of building its own.
        if "_instance" not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Default model identifier used when Chat is constructed without an explicit
# rerank_model_path.
# NOTE(review): looks like a Hugging Face Hub repo id rather than a local
# filesystem path — confirm how the reranker resolves it.
RERANK_MODEL_PATH = "Qwen/Qwen3-Reranker-4B"
|
| 31 |
+
|
| 32 |
+
class ReRanker(metaclass=Singleton):
    """Wrapper around a FlagReranker model, instantiated through Singleton."""

    def __init__(self, model_path):
        """Create the FlagReranker for *model_path* with fp16 disabled."""
        self.reranker = FlagReranker(model_path, use_fp16=False)

    def compute_score(self, pairs: List[List[str]]):
        """Score *pairs* with the wrapped model, requesting normalized scores.

        Returns ``None`` when *pairs* is empty. Otherwise returns whatever the
        model returns, except that a bare ``float`` is wrapped in a
        one-element list.
        """
        if len(pairs) == 0:
            return None

        scores = self.reranker.compute_score(pairs, normalize=True)
        # A lone float is wrapped so the caller can always iterate the result.
        return [scores] if isinstance(scores, float) else scores
|
| 44 |
+
|
| 45 |
+
class Chat(object):
    """Ranks a request's documents against its query using ReRanker."""

    def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
        """Obtain the reranker for *rerank_model_path*."""
        self.reranker = ReRanker(rerank_model_path)

    def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
        """Rank the request's documents by their score against the query.

        Args:
            query_docs: Request carrying ``query`` and ``documents``.

        Returns:
            A list of ``{"index": ..., "relevance_score": ...}`` dicts ordered
            from highest to lowest score, where ``index`` is the document's
            position in ``query_docs.documents``. The list is empty when the
            request is ``None`` or its documents are ``None`` or empty.
        """
        # ``documents`` is declared Optional, so it can be None as well as
        # empty; a truthiness test covers both without calling len(None).
        if query_docs is None or not query_docs.documents:
            return []

        pairs = [[query_docs.query, doc] for doc in query_docs.documents]
        scores = self.reranker.compute_score(pairs)

        # sorted() is stable, so documents with equal scores keep their
        # original relative order.
        ranked = sorted(enumerate(scores), key=lambda item: item[1], reverse=True)
        return [{"index": index, "relevance_score": score} for index, score in ranked]
|
| 61 |
+
|
| 62 |
+
@app.post('/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Authenticate the caller and return the reranked documents.

    Args:
        docs: Request body with the query and the documents to rank.
        credentials: Bearer credentials extracted by the ``security`` scheme.

    Returns:
        ``{"results": [...]}`` on success, or ``{"error": ...}`` when ranking
        raises.

    Raises:
        HTTPException: 401 when the bearer token does not match ``sk_key``.
    """
    # Imported locally so this handler does not depend on a file-level import.
    import hmac

    token = credentials.credentials
    # Compare in constant time so response timing does not reveal how much of
    # the key matched. Both sides are encoded to bytes because compare_digest
    # rejects non-ASCII str arguments.
    if sk_key is not None and not hmac.compare_digest(
        token.encode("utf-8"), sk_key.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    chat = Chat()
    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked while it runs.
        results = chat.fit_query_answer_rerank(docs)
        return {"results": results}
    except Exception as e:
        # logging.exception records the traceback, which print(e) discarded.
        logging.exception("报错:\n%s", e)
        return {"error": "重排出错"}
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
    # Start uvicorn for the app exported by this module when run as a script.
    uvicorn.run(
        "localrerank:app",
        host='0.0.0.0',
        port=7860,
        workers=1,
    )
|