cicboy commited on
Commit
2237728
·
1 Parent(s): 9ab1b3b

update hybrid_retriever_tool file

Browse files
Files changed (1) hide show
  1. tools/hybrid_retriever_tool.py +36 -0
tools/hybrid_retriever_tool.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import Field, PrivateAttr
8
  import os
9
  from html import unescape
10
  import re
 
11
 
12
  class HybridRetrieverTool(RagTool):
13
  name: str = "Hybrid Retriever Tool"
@@ -89,6 +90,38 @@ class HybridRetrieverTool(RagTool):
89
  all_urls = list(dict.fromkeys(all_urls))[:5]
90
  return corpus, all_urls
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def _run(self, query: str, top_k: int = 8) -> str:
93
  """
94
  Run hybrid search: BM25 + semantic similarity.
@@ -119,6 +152,9 @@ class HybridRetrieverTool(RagTool):
119
  top_indices= np.argsort(hybrid_scores)[::-1][:top_k]
120
 
121
  top_passages = [corpus[i] for i in top_indices]
 
 
 
122
  return "\n\n".join(top_passages)
123
 
124
  def summarize_passages(self, topic: str, passages):
 
8
  import os
9
  from html import unescape
10
  import re
11
+ import json
12
 
13
  class HybridRetrieverTool(RagTool):
14
  name: str = "Hybrid Retriever Tool"
 
90
  all_urls = list(dict.fromkeys(all_urls))[:5]
91
  return corpus, all_urls
92
 
93
+ def _rerank(self, query:str, passages: list[str]) -> list[str]:
94
+ """Use LLM to rerank received passages for contextual relevance"""
95
+ try:
96
+ prompt = f"""
97
+ You are a research assistant. Rank the following passages by how relevant they are to the topic:
98
+ "{query}"
99
+
100
+ Return a JSON array of the top 5 passages(most to least relevant).
101
+
102
+ Passages:
103
+ {json.dumps(passages, indent=2)}
104
+ """
105
+ response = self. _client.chat.completions.create(
106
+ model = "gpt-4o-mini",
107
+ messages=[
108
+ {"role": "system", "content": "You are an expert re-ranker for information retrieval."},
109
+ {"role": "user", "content": prompt}
110
+ ],
111
+ temperature=0
112
+ )
113
+ content = response.choices[0].message.conten.strip()
114
+ try:
115
+ ranked = json.loads(response.choices[0].message.content)
116
+ # Keep only valid strings
117
+ ranked = [p for p in ranked if isinstance(p, str)]
118
+ return ranked if ranked else passages
119
+ except json.JSONDecodeError:
120
+ print("⚠️ Reranker returned non-JSON output, using original order.")
121
+ except Exception as e:
122
+ print(f"Re-ranker failed: {e}")
123
+ return passages
124
+
125
  def _run(self, query: str, top_k: int = 8) -> str:
126
  """
127
  Run hybrid search: BM25 + semantic similarity.
 
152
  top_indices= np.argsort(hybrid_scores)[::-1][:top_k]
153
 
154
  top_passages = [corpus[i] for i in top_indices]
155
+
156
+ #LLM-based re-ranker
157
+ reranked = self._rerank(query, top_passages)
158
  return "\n\n".join(top_passages)
159
 
160
  def summarize_passages(self, topic: str, passages):