Update localrerank.py
Browse files- localrerank.py +8 -3
localrerank.py
CHANGED
|
@@ -5,7 +5,7 @@ import uvicorn
|
|
| 5 |
import datetime
|
| 6 |
from fastapi import FastAPI, Security, HTTPException
|
| 7 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 8 |
-
from FlagEmbedding import
|
| 9 |
from pydantic import Field, BaseModel, validator
|
| 10 |
from typing import Optional, List
|
| 11 |
|
|
@@ -44,7 +44,11 @@ class ReRanker(metaclass=Singleton):
|
|
| 44 |
|
| 45 |
class Chat(object):
|
| 46 |
def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
|
| 47 |
-
self.reranker =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
|
| 50 |
if query_docs is None or len(query_docs.documents) == 0:
|
|
@@ -59,12 +63,13 @@ class Chat(object):
|
|
| 59 |
results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
| 60 |
return results
|
| 61 |
|
|
|
|
|
|
|
| 62 |
@app.post('/v1/rerank')
|
| 63 |
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
|
| 64 |
token = credentials.credentials
|
| 65 |
if sk_key is not None and token != sk_key:
|
| 66 |
raise HTTPException(status_code=401, detail="Invalid token")
|
| 67 |
-
chat = Chat()
|
| 68 |
try:
|
| 69 |
results = chat.fit_query_answer_rerank(docs)
|
| 70 |
return {"results": results}
|
|
|
|
| 5 |
import datetime
|
| 6 |
from fastapi import FastAPI, Security, HTTPException
|
| 7 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 8 |
+
from FlagEmbedding import FlagAutoReranker
|
| 9 |
from pydantic import Field, BaseModel, validator
|
| 10 |
from typing import Optional, List
|
| 11 |
|
|
|
|
| 44 |
|
| 45 |
class Chat(object):
|
| 46 |
def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
|
| 47 |
+
self.reranker = FlagAutoReranker.from_finetuned(
|
| 48 |
+
model_path,
|
| 49 |
+
use_fp16=False,
|
| 50 |
+
trust_remote_code=True,
|
| 51 |
+
)
|
| 52 |
|
| 53 |
def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
|
| 54 |
if query_docs is None or len(query_docs.documents) == 0:
|
|
|
|
| 63 |
results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
|
| 64 |
return results
|
| 65 |
|
| 66 |
+
chat = Chat()
|
| 67 |
+
|
| 68 |
@app.post('/v1/rerank')
|
| 69 |
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
|
| 70 |
token = credentials.credentials
|
| 71 |
if sk_key is not None and token != sk_key:
|
| 72 |
raise HTTPException(status_code=401, detail="Invalid token")
|
|
|
|
| 73 |
try:
|
| 74 |
results = chat.fit_query_answer_rerank(docs)
|
| 75 |
return {"results": results}
|