cicboy committed on
Commit
baf9e10
Β·
1 Parent(s): e30c01e

hybrid_retriever_tool for RAG and update application

Browse files
Files changed (2) hide show
  1. app.py +22 -1
  2. tools/hybrid_retriever_tool.py +70 -0
app.py CHANGED
@@ -6,6 +6,7 @@ import os
6
  from dotenv import load_dotenv
7
  from pathlib import Path
8
  import gradio as gr
 
9
 
10
  # control warnings
11
  warnings.filterwarnings("ignore")
@@ -19,16 +20,21 @@ llm_writer = LLM(model="gpt-5-mini", temperature=1.0)
19
  llm_editor = LLM(model="gpt-4-turbo", temperature=0.3)
20
  llm_fact=LLM(model="gpt-4o-mini", temperature=0.3)
21
 
 
 
 
22
  #Creating Agents
23
 
24
  planner = Agent(
25
  role="Content Planner",
26
  goal="Plan engaging and factually accurate content on {topic}",
27
  backstory="You are working on planning a blog article about the topic {topic}."
 
28
  "You collect relevant information that helps the audience learn something and make informed decisions. "
29
  "Your work is the basis for the Content Writer to write an article on this topic.",
30
  allow_delegation=False,
31
  verbose=True,
 
32
  llm=llm_planner
33
  )
34
 
@@ -48,9 +54,11 @@ writer = Agent(
48
  fact_checker = Agent(
49
  role="Fact Checker",
50
  goal="Verify factual accuracy, detect unsupported claims and identify missing references or sources.",
51
- backstory="You are a meticulous research analyst who checks every claim against known facts and relaible sources",
 
52
  allow_delegation=False,
53
  verbose=True,
 
54
  llm=llm_fact
55
  )
56
 
@@ -133,6 +141,15 @@ crew = Crew(
133
  verbose=True
134
  )
135
 
 
 
 
 
 
 
 
 
 
136
  # Define Gradio Handler
137
  def generate_blog(topic, tone):
138
  yield "⏳ Generating blog β€” this may take a few moments..."
@@ -180,10 +197,14 @@ with gr.Blocks(css="""
180
  label="Select Writing Tone",
181
  value="academic"
182
  )
 
 
 
183
 
184
  run_button = gr.Button("πŸš€ Generate Blog", variant="primary")
185
  output = gr.Textbox(label="πŸ“° Generated Blog Post", elem_id="output-box", lines=25, interactive=False, show_label=False)
186
 
 
187
  run_button.click(generate_blog, inputs=[topic, tone], outputs=[output])
188
 
189
  #Launch app
 
6
  from dotenv import load_dotenv
7
  from pathlib import Path
8
  import gradio as gr
9
+ from tools.hybrid_retriever_tool import HybridRetrieverTool
10
 
11
  # control warnings
12
  warnings.filterwarnings("ignore")
 
20
  llm_editor = LLM(model="gpt-4-turbo", temperature=0.3)
21
  llm_fact=LLM(model="gpt-4o-mini", temperature=0.3)
22
 
23
+ #Define tools
24
+ hybrid_tool = HybridRetrieverTool(alpha=0.6) #including RAG in search
25
+
26
  #Creating Agents
27
 
28
  planner = Agent(
29
  role="Content Planner",
30
  goal="Plan engaging and factually accurate content on {topic}",
31
  backstory="You are working on planning a blog article about the topic {topic}."
32
+ "Use the retriever tool to gather accurate, recent information before outlining." #RAG search
33
  "You collect relevant information that helps the audience learn something and make informed decisions. "
34
  "Your work is the basis for the Content Writer to write an article on this topic.",
35
  allow_delegation=False,
36
  verbose=True,
37
+ tools = [hybrid_tool], #Rag search -- NOTE: comma is required here; without it the next keyword argument (llm=...) is a syntax error
38
  llm=llm_planner
39
  )
40
 
 
54
  fact_checker = Agent(
55
  role="Fact Checker",
56
  goal="Verify factual accuracy, detect unsupported claims and identify missing references or sources.",
57
+ backstory="You are a meticulous research analyst who checks every claim against known facts and reliable sources. "
58
+ "Use the retriever tool to cross-check the Content Writer's statements against reliable, recent information.",
59
  allow_delegation=False,
60
  verbose=True,
61
+ tools = [hybrid_tool], #Rag search
62
  llm=llm_fact
63
  )
64
 
 
141
  verbose=True
142
  )
143
 
144
+ # fetch context for RAG search
145
+ def fetch_context(topic):
146
+ passages = hybrid_tool._run(topic)
147
+ if isinstance(passages, str):
148
+ summary = passages
149
+ else:
150
+ summary = hybrid_tool.summarize_passages(topic, passages)
151
+ return summary
152
+
153
  # Define Gradio Handler
154
  def generate_blog(topic, tone):
155
  yield "⏳ Generating blog β€” this may take a few moments..."
 
197
  label="Select Writing Tone",
198
  value="academic"
199
  )
200
+ fetch_btn = gr.Button("🌐 Fetch & Summarize Context", variant="secondary") # Rag Search
201
+ context_output = gr.Markdown(label="πŸ“š Retrieved Context Summary") # Rag Search
202
+
203
 
204
  run_button = gr.Button("πŸš€ Generate Blog", variant="primary")
205
  output = gr.Textbox(label="πŸ“° Generated Blog Post", elem_id="output-box", lines=25, interactive=False, show_label=False)
206
 
207
+ fetch_btn.click(fetch_context, inputs=[topic], outputs=[context_output]) # Rag Search
208
  run_button.click(generate_blog, inputs=[topic, tone], outputs=[output])
209
 
210
  #Launch app
tools/hybrid_retriever_tool.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from rank_bm25 import BM25Okapi
3
+ from sentence_transformers import SentenceTransformer
4
+ from tavily import TavilyClient
5
+ from openai import OpenAI
6
+ import os
7
+
8
+ class HybridRetrieverTool:
9
+ """
10
+ Dynamically builds a hybrid BM25 + semantic retriever from live Tavily results.
11
+ """
12
+
13
+ def __init__(self, alpha=0.6):
14
+ self.alpha = alpha
15
+ self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
16
+ self.tavily = TavilyClient(api_key=os.getenv("TAVILITY_API_KEY"))
17
+ self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
18
+
19
+ def _build_corpus(self, topic):
20
+ """Fetch up-to-date search results."""
21
+ results = self.tavily.search(query=topic, max_results=30)
22
+ corpus = []
23
+ for r in results.get("results", []):
24
+ content = r.get("content") or ""
25
+ if len(content.strip()) > 0:
26
+ corpus.append(content)
27
+ return corpus
28
+
29
+ def _run(self, query, top_k=8):
30
+ """
31
+ Run hybrid search: BM25 + semantic similarity.
32
+ """
33
+ corpus = self._build_corpus(query)
34
+ if not corpus:
35
+ return "No relevant content found."
36
+
37
+ bm25 = BM25Okapi([doc.split() for doc in corpus])
38
+ bm25_scores = np.array(bm25.get(query.split()))
39
+
40
+ emb_corpus = self.embedder.encode(corpus, convert_to_numpy=True, normalize_embeddings=True)
41
+ emb_query = self.embedder.encode(query, convert_to_numpy=True, normalize_embeddings=True)
42
+ sem_scores = np.dot(emb_corpus, emb_query)
43
+
44
+ # Normalize scores
45
+ bm25_norm = (bm25_scores - bm25_scores.min()) / (bm25_scores.ptp() + 1e-8)
46
+ sem_norm = (sem_scores - sem_scores.min()) / (sem_scores.ptp() + 1e-8)
47
+
48
+ # Weighted fusion
49
+ hybrid_scores = self.alpha * sem_norm + (1 - self.alpha) * bm25_norm
50
+ top_indices= np.argsort(hybrid_scores)[::-1][:top_k]
51
+
52
+ top_passages = [corpus[i] for i in top_indices]
53
+ return "\n\n".join(top_passages)
54
+
55
+ def summarize_passages(self, topic, passages):
56
+ if isinstance(passages, str):
57
+ passages = [passages]
58
+ text_block = "\n".join(passages)
59
+ try:
60
+ response = self.client.chat.completions.create(
61
+ model="gpt-4o-mini",
62
+ messages=[
63
+ {"role": "system", "content": "You are an expert summarizer."},
64
+ {"role": "user", "content": f"Summarize these passages about {topic}"}
65
+ ],
66
+ temperature=0.3
67
+ )
68
+ return response.choices[0].message.content.strip()
69
+ except Exception as e:
70
+ return f"Summarization failed: {e}"