Add hybrid_retriever_tool for RAG and update application
- app.py +22 -1
- tools/hybrid_retriever_tool.py +70 -0
app.py
CHANGED
@@ -6,6 +6,7 @@ import os
 from dotenv import load_dotenv
 from pathlib import Path
 import gradio as gr
+from tools.hybrid_retriever_tool import HybridRetrieverTool
 
 # control warnings
 warnings.filterwarnings("ignore")
@@ -19,16 +20,21 @@ llm_writer = LLM(model="gpt-5-mini", temperature=1.0)
 llm_editor = LLM(model="gpt-4-turbo", temperature=0.3)
 llm_fact = LLM(model="gpt-4o-mini", temperature=0.3)
 
+# Define tools
+hybrid_tool = HybridRetrieverTool(alpha=0.6)  # hybrid BM25 + semantic retrieval for RAG
+
 # Creating Agents
 
 planner = Agent(
     role="Content Planner",
     goal="Plan engaging and factually accurate content on {topic}",
     backstory="You are working on planning a blog article about the topic {topic}. "
+        "Use the retriever tool to gather accurate, recent information before outlining. "  # RAG search
         "You collect relevant information that helps the audience learn something and make informed decisions. "
        "Your work is the basis for the Content Writer to write an article on this topic.",
     allow_delegation=False,
     verbose=True,
+    tools=[hybrid_tool],  # RAG search
     llm=llm_planner
 )
 
@@ -48,9 +54,11 @@ writer = Agent(
 fact_checker = Agent(
     role="Fact Checker",
     goal="Verify factual accuracy, detect unsupported claims, and identify missing references or sources.",
-    backstory="You are a meticulous research analyst who checks every claim against known facts and relaible sources",
+    backstory="You are a meticulous research analyst who checks every claim against known facts and reliable sources. "
+        "Use the retriever tool to cross-check the Content Writer's statements against reliable, recent information.",
     allow_delegation=False,
     verbose=True,
+    tools=[hybrid_tool],  # RAG search
     llm=llm_fact
 )
 
@@ -133,6 +141,13 @@ crew = Crew(
     verbose=True
 )
 
+# Fetch and summarize retrieved context (RAG search)
+def fetch_context(topic):
+    passages = hybrid_tool._run(topic)  # _run always returns a string
+    if passages == "No relevant content found.":
+        return passages  # nothing retrieved, so nothing to summarize
+    return hybrid_tool.summarize_passages(topic, passages)
+
 # Define Gradio handler
 def generate_blog(topic, tone):
     yield "⏳ Generating blog - this may take a few moments..."
@@ -180,10 +197,14 @@ with gr.Blocks(css="""
         label="Select Writing Tone",
         value="academic"
     )
+    fetch_btn = gr.Button("🔍 Fetch & Summarize Context", variant="secondary")  # RAG search
+    context_output = gr.Markdown(label="📚 Retrieved Context Summary")  # RAG search
+
 
     run_button = gr.Button("🚀 Generate Blog", variant="primary")
     output = gr.Textbox(label="📰 Generated Blog Post", elem_id="output-box", lines=25, interactive=False, show_label=False)
 
+    fetch_btn.click(fetch_context, inputs=[topic], outputs=[context_output])  # RAG search
     run_button.click(generate_blog, inputs=[topic, tone], outputs=[output])
 
 # Launch app
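Note: CrewAI validates the objects passed to an agent's `tools=` list, and a plain Python class with a `_run` method may not pass that check in every CrewAI version. A minimal wrapper sketch, assuming CrewAI's `BaseTool` interface (the import path, the `name`/`description` fields, and the `HybridSearchTool` name are assumptions, not part of this commit):

from crewai.tools import BaseTool  # assumed import path; some releases expose BaseTool from crewai_tools
from tools.hybrid_retriever_tool import HybridRetrieverTool

_retriever = HybridRetrieverTool(alpha=0.6)  # module-level instance so the embedder loads once

class HybridSearchTool(BaseTool):
    name: str = "hybrid_search"  # hypothetical tool name
    description: str = "Hybrid BM25 + semantic web search returning recent, relevant passages."

    def _run(self, query: str) -> str:
        # Delegate to the retriever; returns joined top passages or a fallback message.
        return _retriever._run(query)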
tools/hybrid_retriever_tool.py
ADDED
@@ -0,0 +1,70 @@
+import numpy as np
+from rank_bm25 import BM25Okapi
+from sentence_transformers import SentenceTransformer
+from tavily import TavilyClient
+from openai import OpenAI
+import os
+
+class HybridRetrieverTool:
+    """
+    Dynamically builds a hybrid BM25 + semantic retriever from live Tavily results.
+    """
+
+    def __init__(self, alpha=0.6):
+        self.alpha = alpha  # weight on semantic scores in the fusion
+        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
+        self.tavily = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
+        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+    def _build_corpus(self, topic):
+        """Fetch up-to-date search results."""
+        results = self.tavily.search(query=topic, max_results=30)
+        corpus = []
+        for r in results.get("results", []):
+            content = r.get("content") or ""
+            if len(content.strip()) > 0:
+                corpus.append(content)
+        return corpus
+
+    def _run(self, query, top_k=8):
+        """
+        Run hybrid search: BM25 + semantic similarity.
+        """
+        corpus = self._build_corpus(query)
+        if not corpus:
+            return "No relevant content found."
+
+        bm25 = BM25Okapi([doc.split() for doc in corpus])
+        bm25_scores = np.array(bm25.get_scores(query.split()))
+
+        emb_corpus = self.embedder.encode(corpus, convert_to_numpy=True, normalize_embeddings=True)
+        emb_query = self.embedder.encode(query, convert_to_numpy=True, normalize_embeddings=True)
+        sem_scores = np.dot(emb_corpus, emb_query)
+
+        # Normalize scores to [0, 1]
+        bm25_norm = (bm25_scores - bm25_scores.min()) / (np.ptp(bm25_scores) + 1e-8)
+        sem_norm = (sem_scores - sem_scores.min()) / (np.ptp(sem_scores) + 1e-8)
+
+        # Weighted fusion
+        hybrid_scores = self.alpha * sem_norm + (1 - self.alpha) * bm25_norm
+        top_indices = np.argsort(hybrid_scores)[::-1][:top_k]
+
+        top_passages = [corpus[i] for i in top_indices]
+        return "\n\n".join(top_passages)
+
+    def summarize_passages(self, topic, passages):
+        if isinstance(passages, str):
+            passages = [passages]
+        text_block = "\n".join(passages)
+        try:
+            response = self.client.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[
+                    {"role": "system", "content": "You are an expert summarizer."},
+                    {"role": "user", "content": f"Summarize these passages about {topic}:\n\n{text_block}"}
+                ],
+                temperature=0.3
+            )
+            return response.choices[0].message.content.strip()
+        except Exception as e:
+            return f"Summarization failed: {e}"