fyerfyer committed on
Commit
b10e29c
·
1 Parent(s): c9531de

added rerank component

Browse files
Files changed (2) hide show
  1. app.py +23 -3
  2. requirements.txt +2 -1
app.py CHANGED
@@ -4,6 +4,8 @@ import gradio as gr
4
  from openai import OpenAI
5
  from qdrant_client import QdrantClient
6
  from sentence_transformers import SentenceTransformer
 
 
7
 
8
  API_KEY = os.environ.get('DEEPSEEK_API_KEY')
9
  BASE_URL = "https://api.deepseek.com"
@@ -40,6 +42,8 @@ class HFRAG:
40
  http_client=httpx.Client(proxy=None, trust_env=False)
41
  )
42
 
 
 
43
  def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.40):
44
  query_vector = self.embed_model.encode(query).tolist()
45
 
@@ -47,18 +51,34 @@ class HFRAG:
47
  results = self.db_client.search(
48
  collection_name=COLLECTION_NAME,
49
  query_vector=query_vector,
50
- limit=top_k,
51
  score_threshold=score_threshold
52
  )
53
  else:
54
  results = self.db_client.query_points(
55
  collection_name=COLLECTION_NAME,
56
  query=query_vector,
57
- limit=top_k,
58
  with_payload=True,
59
  score_threshold=score_threshold
60
  ).points
61
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  def format_context(self, search_results):
64
  context_pieces = []
 
4
  from openai import OpenAI
5
  from qdrant_client import QdrantClient
6
  from sentence_transformers import SentenceTransformer
7
+ from flashrank import Ranker, RerankRequest
8
+ from types import SimpleNamespace
9
 
10
  API_KEY = os.environ.get('DEEPSEEK_API_KEY')
11
  BASE_URL = "https://api.deepseek.com"
 
42
  http_client=httpx.Client(proxy=None, trust_env=False)
43
  )
44
 
45
+ self.reranker = Ranker(model_name="ms-marco-TinyBERT-L-2-v2", cache_dir="/tmp")
46
+
47
  def retrieve(self, query: str, top_k: int = 5, score_threshold: float = 0.40):
48
  query_vector = self.embed_model.encode(query).tolist()
49
 
 
51
  results = self.db_client.search(
52
  collection_name=COLLECTION_NAME,
53
  query_vector=query_vector,
54
+ limit=20, # 扩大召回范围,之后进行重排序
55
  score_threshold=score_threshold
56
  )
57
  else:
58
  results = self.db_client.query_points(
59
  collection_name=COLLECTION_NAME,
60
  query=query_vector,
61
+ limit=20,
62
  with_payload=True,
63
  score_threshold=score_threshold
64
  ).points
65
+
66
+ passages = [
67
+ {"id": result.payload['metadata']['source'], "text": result.payload['text'], "meta": result.payload}
68
+ for result in results
69
+ ]
70
+ rerank_request = RerankRequest(query=query, passages=passages)
71
+ reranked_results = self.reranker.rerank(rerank_request)
72
+
73
+ # 从重排序后的序列中取出 TopK
74
+ final_results = []
75
+ for item in reranked_results[:top_k]:
76
+ final_result = SimpleNamespace()
77
+ final_result.payload = item['meta']
78
+ final_result.score = item['score']
79
+ final_results.append(final_result)
80
+
81
+ return final_results
82
 
83
  def format_context(self, search_results):
84
  context_pieces = []
requirements.txt CHANGED
@@ -4,4 +4,5 @@ qdrant-client
4
  sentence-transformers
5
  httpx
6
  torch
7
- python-dotenv
 
 
4
  sentence-transformers
5
  httpx
6
  torch
7
+ python-dotenv
8
+ flashrank