cicboy committed on
Commit
f158f29
·
1 Parent(s): 8b28e9a

update hybrid_retriever_tool file

Browse files
Files changed (1) hide show
  1. tools/hybrid_retriever_tool.py +3 -38
tools/hybrid_retriever_tool.py CHANGED
@@ -8,7 +8,6 @@ from pydantic import Field, PrivateAttr
8
  import os
9
  from html import unescape
10
  import re
11
- import json
12
 
13
  class HybridRetrieverTool(RagTool):
14
  name: str = "Hybrid Retriever Tool"
@@ -88,39 +87,8 @@ class HybridRetrieverTool(RagTool):
88
 
89
  #Deduplicate and keep top unique URLs
90
  all_urls = list(dict.fromkeys(all_urls))[:5]
91
- return corpus, all_urls
92
-
93
- def _rerank(self, query:str, passages: list[str]) -> list[str]:
94
- """Use LLM to rerank received passages for contextual relevance"""
95
- try:
96
- prompt = f"""
97
- You are a research assistant. Rank the following passages by how relevant they are to the topic:
98
- "{query}"
99
-
100
- Return a JSON array of the top 5 passages(most to least relevant).
101
-
102
- Passages:
103
- {json.dumps(passages, indent=2)}
104
- """
105
- response = self. _client.chat.completions.create(
106
- model = "gpt-4o-mini",
107
- messages=[
108
- {"role": "system", "content": "You are an expert re-ranker for information retrieval."},
109
- {"role": "user", "content": prompt}
110
- ],
111
- temperature=0
112
- )
113
- content = response.choices[0].message.content.strip()
114
- try:
115
- ranked = json.loads(response.choices[0].message.content)
116
- # Keep only valid strings
117
- ranked = [p for p in ranked if isinstance(p, str)]
118
- return ranked if ranked else passages
119
- except json.JSONDecodeError:
120
- print("⚠️ Reranker returned non-JSON output, using original order.")
121
- except Exception as e:
122
- print(f"Re-ranker failed: {e}")
123
- return passages
124
 
125
  def _run(self, query: str, top_k: int = 8) -> str:
126
  """
@@ -152,9 +120,6 @@ class HybridRetrieverTool(RagTool):
152
  top_indices= np.argsort(hybrid_scores)[::-1][:top_k]
153
 
154
  top_passages = [corpus[i] for i in top_indices]
155
-
156
- #LLM-based re-ranker
157
- reranked = self._rerank(query, top_passages)
158
  return "\n\n".join(top_passages)
159
 
160
  def summarize_passages(self, topic: str, passages):
@@ -214,7 +179,7 @@ Return output in Markdown format.
214
  summary = response.choices[0].message.content.strip()
215
 
216
  if unique_urls:
217
- summary += "\n\n**Sources:**\n" + "\n".join(f"- {u}" for u in unique_urls)
218
 
219
  return summary
220
 
 
8
  import os
9
  from html import unescape
10
  import re
 
11
 
12
  class HybridRetrieverTool(RagTool):
13
  name: str = "Hybrid Retriever Tool"
 
87
 
88
  #Deduplicate and keep top unique URLs
89
  all_urls = list(dict.fromkeys(all_urls))[:5]
90
+ print(f"[HybridRetrieverTool] Retrieved {len(corpus)} docs, {len(all_urls)} unique URLs for '{topic}'")
91
+ return corpus, all_urls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def _run(self, query: str, top_k: int = 8) -> str:
94
  """
 
120
  top_indices= np.argsort(hybrid_scores)[::-1][:top_k]
121
 
122
  top_passages = [corpus[i] for i in top_indices]
 
 
 
123
  return "\n\n".join(top_passages)
124
 
125
  def summarize_passages(self, topic: str, passages):
 
179
  summary = response.choices[0].message.content.strip()
180
 
181
  if unique_urls:
182
+ summary += "\n\n**Sources:**\n" + "\n".join(f"- {u}" for u in unique_urls) + "\n"
183
 
184
  return summary
185