1bincooo committed on
Commit
4e680d3
·
verified ·
1 Parent(s): 5b8f932

Update localrerank.py

Browse files
Files changed (1) hide show
  1. localrerank.py +8 -3
localrerank.py CHANGED
@@ -5,7 +5,7 @@ import uvicorn
5
  import datetime
6
  from fastapi import FastAPI, Security, HTTPException
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
- from FlagEmbedding import FlagReranker
9
  from pydantic import Field, BaseModel, validator
10
  from typing import Optional, List
11
 
@@ -44,7 +44,11 @@ class ReRanker(metaclass=Singleton):
44
 
45
  class Chat(object):
46
  def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
47
- self.reranker = ReRanker(rerank_model_path)
 
 
 
 
48
 
49
  def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
50
  if query_docs is None or len(query_docs.documents) == 0:
@@ -59,12 +63,13 @@ class Chat(object):
59
  results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
60
  return results
61
 
 
 
62
  @app.post('/v1/rerank')
63
  async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
64
  token = credentials.credentials
65
  if sk_key is not None and token != sk_key:
66
  raise HTTPException(status_code=401, detail="Invalid token")
67
- chat = Chat()
68
  try:
69
  results = chat.fit_query_answer_rerank(docs)
70
  return {"results": results}
 
5
  import datetime
6
  from fastapi import FastAPI, Security, HTTPException
7
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
+ from FlagEmbedding import FlagAutoReranker
9
  from pydantic import Field, BaseModel, validator
10
  from typing import Optional, List
11
 
 
44
 
45
  class Chat(object):
46
  def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
47
+ self.reranker = FlagAutoReranker.from_finetuned(
48
+ model_path,
49
+ use_fp16=False,
50
+ trust_remote_code=True,
51
+ )
52
 
53
  def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
54
  if query_docs is None or len(query_docs.documents) == 0:
 
63
  results = [{"index": documents["index"], "relevance_score": documents["score"]} for documents in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
64
  return results
65
 
66
+ chat = Chat()
67
+
68
  @app.post('/v1/rerank')
69
  async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
70
  token = credentials.credentials
71
  if sk_key is not None and token != sk_key:
72
  raise HTTPException(status_code=401, detail="Invalid token")
 
73
  try:
74
  results = chat.fit_query_answer_rerank(docs)
75
  return {"results": results}