cicboy commited on
Commit
42d47d3
·
1 Parent(s): b60e7d3

update changes to hybrid_retriever_tool.py file

Browse files
Files changed (1) hide show
  1. tools/hybrid_retriever_tool.py +60 -5
tools/hybrid_retriever_tool.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import Field, PrivateAttr
8
  import os
9
  from html import unescape
10
  import re
 
11
 
12
  class HybridRetrieverTool(RagTool):
13
  name: str = "Hybrid Retriever Tool"
@@ -76,7 +77,7 @@ class HybridRetrieverTool(RagTool):
76
 
77
  def _build_corpus(self, topic: str):
78
  """Fetch up-to-date search results."""
79
- results = self._tavily.search(query=topic, max_results=30)
80
  raw_texts = [r.get("content", "").strip() for r in results.get("results", []) if r.get("content")]
81
  corpus, all_urls = [], []
82
  for t in raw_texts:
@@ -87,7 +88,58 @@ class HybridRetrieverTool(RagTool):
87
 
88
  #Deduplicate and keep top unique URLs
89
  all_urls = list(dict.fromkeys(all_urls))[:5]
90
- return corpus, all_urls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  def _run(self, query: str, top_k: int = 8) -> str:
93
  """
@@ -119,7 +171,8 @@ class HybridRetrieverTool(RagTool):
119
  top_indices= np.argsort(hybrid_scores)[::-1][:top_k]
120
 
121
  top_passages = [corpus[i] for i in top_indices]
122
- return "\n\n".join(top_passages)
 
123
 
124
  def summarize_passages(self, topic: str, passages):
125
  """Summarize retrieved content into a coherent short digest, keeping citations."""
@@ -142,7 +195,7 @@ class HybridRetrieverTool(RagTool):
142
  unique_texts = list(dict.fromkeys(main_text))[:5] # prevent duplication
143
  text_block = " ".join(unique_texts)
144
  text_block = re.sub(r"\s{2,}", " ", text_block).strip()
145
- text_block = text_block[:4000] # safety limit for token size
146
 
147
  unique_urls = list(dict.fromkeys(urls))[:5]
148
 
@@ -178,7 +231,9 @@ Return output in Markdown format.
178
  summary = response.choices[0].message.content.strip()
179
 
180
  if unique_urls:
181
- summary += "\n\n**Sources:**\n" + "\n".join(f"- {u}" for u in unique_urls) + "\n"
 
 
182
 
183
  return summary
184
 
 
8
  import os
9
  from html import unescape
10
  import re
11
+ import logging
12
 
13
  class HybridRetrieverTool(RagTool):
14
  name: str = "Hybrid Retriever Tool"
 
77
 
78
  def _build_corpus(self, topic: str):
79
  """Fetch up-to-date search results."""
80
+ results = self._tavily.search(query=topic, max_results=50)
81
  raw_texts = [r.get("content", "").strip() for r in results.get("results", []) if r.get("content")]
82
  corpus, all_urls = [], []
83
  for t in raw_texts:
 
88
 
89
  #Deduplicate and keep top unique URLs
90
  all_urls = list(dict.fromkeys(all_urls))[:5]
91
+ return corpus, all_urls
92
+
93
+ # LLM reranker
94
+ def _rerank(self, query: str, passages: list[str], top_n: int = 5) -> list[str]:
95
+ """
96
+ Use an LLM to re-rank retrieved passages for contextual relevance to the query.
97
+ """
98
+ if not passages:
99
+ return []
100
+
101
+ try:
102
+ formatted_passages = "\n\n".join(
103
+ [f"Passage {i+1}:\n{p}" for i, p in enumerate(passages)]
104
+ )
105
+
106
+ prompt = f"""
107
+ You are a precise research assistant that ranks text passages for relevance.
108
+
109
+ Query:
110
+ "{query}"
111
+
112
+ Passages:
113
+ {formatted_passages}
114
+
115
+ Instructions:
116
+ - Rank passages by how directly and substantively they address the query.
117
+ - Ignore repetitive, boilerplate, or promotional content.
118
+ - Return ONLY the top {top_n} most relevant passages, in their original text form.
119
+ """
120
+
121
+ response = self._client.chat.completions.create(
122
+ model="gpt-4o-mini",
123
+ messages=[
124
+ {"role": "system", "content": "You are an expert LLM reranker for information retrieval."},
125
+ {"role": "user", "content": prompt},
126
+ ],
127
+ temperature=0,
128
+ )
129
+
130
+ ranked_text = response.choices[0].message.content.strip()
131
+ reranked = re.split(r"Passage\s*\d+:", ranked_text)
132
+ reranked = [p.strip() for p in reranked if len(p.strip()) > 20]
133
+
134
+ if len(reranked) == 0:
135
+ print("⚠️ Reranker returned no valid text, using original order.")
136
+ return passages[:top_n]
137
+
138
+ return reranked[:top_n]
139
+
140
+ except Exception as e:
141
+ logging.warning(f"Reranker failed: {e}")
142
+ return passages[:top_n]
143
 
144
  def _run(self, query: str, top_k: int = 8) -> str:
145
  """
 
171
  top_indices= np.argsort(hybrid_scores)[::-1][:top_k]
172
 
173
  top_passages = [corpus[i] for i in top_indices]
174
+ reranked = self._rerank(query, top_passages)
175
+ return "\n\n".join(reranked)
176
 
177
  def summarize_passages(self, topic: str, passages):
178
  """Summarize retrieved content into a coherent short digest, keeping citations."""
 
195
  unique_texts = list(dict.fromkeys(main_text))[:5] # prevent duplication
196
  text_block = " ".join(unique_texts)
197
  text_block = re.sub(r"\s{2,}", " ", text_block).strip()
198
+ text_block = text_block[:5000] # safety limit for token size
199
 
200
  unique_urls = list(dict.fromkeys(urls))[:5]
201
 
 
231
  summary = response.choices[0].message.content.strip()
232
 
233
  if unique_urls:
234
+ if unique_urls:
235
+ summary += "\n\n**Sources:**\n" + "\n".join(f"- [{u}]({u})" for u in unique_urls)
236
+
237
 
238
  return summary
239