# Batch_RAG / main.py
# DolAr1610 · commit a5c9fa3: "add new logic"
import streamlit as st
from search.search_classical import classical_search, classical_retrieve_chunks
from search.search_best_pair import best_pair_search
from llm import generate_response
def pick_mode(label: str) -> str:
    """Translate a retriever radio-button label into a search-mode key.

    Labels starting with "Semantic" map to "semantic" and labels starting
    with "Keyword" map to "bm25". Every other label, including the hybrid
    option, falls back to "hybrid". Matching is case-sensitive.
    """
    prefix_table = (("Semantic", "semantic"), ("Keyword", "bm25"))
    for prefix, mode_key in prefix_table:
        if label.startswith(prefix):
            return mode_key
    return "hybrid"
# --- Page header and user inputs ---
# st.set_page_config must remain the first Streamlit call in the script.
st.set_page_config(page_title="🔍 Multimodal Search The Batch")
st.image("data/the-batch-logo.webp", width=300)
st.title("Multimodal Assistant")
mode = st.selectbox("🔎 Select the search mode:", ["Classical RAG", "Multimodal RAG"])
# Strip surrounding whitespace so a blank or whitespace-only query is treated
# as "no query" by the `if query:` checks further down.
query = st.text_input("📝 Enter the text query:").strip()
# --- Classical controls ---
# The retriever picker and reranker toggle are only rendered in Classical mode;
# in any other mode both names are bound to fixed defaults instead.
if mode == "Classical RAG":
    classical_retriever = st.radio(
        "🧩 Classical retrieval:",
        ["Semantic (embeddings)", "Keyword (BM25)", "Hybrid (BM25 + Semantic)"],
        horizontal=True,
    )
    use_reranker = st.checkbox("✨ Use reranker (cross-encoder)", value=True)
else:
    classical_retriever = "Semantic (embeddings)"
    use_reranker = True
# --- Preview results ---
# Request k=3 hits for the current query: Classical mode uses the selected
# text retriever, Multimodal mode uses best_pair_search. Each hit is rendered
# from whichever metadata fields it carries.
results = []
if query:
    if mode == "Classical RAG":
        search_mode = pick_mode(classical_retriever)
        results = classical_search(query, k=3, mode=search_mode)
    else:
        results = best_pair_search(query, k=3)
    st.markdown(f"### 📄 Results found: {len(results)}")
    for i, meta in enumerate(results, start=1):
        st.markdown(f"### 🔹 Result {i}")
        if meta.get("title"):
            st.markdown(f"**📖 Name:** {meta['title']}")
        if meta.get("date"):
            st.markdown(f"**📅 Date of publication:** {meta['date']}")
        if meta.get("description"):
            st.markdown(f"**📝 Description:** {meta['description']}")
        if meta.get("image_url"):
            st.image(meta["image_url"], use_container_width=True)
        if meta.get("content"):
            st.markdown("**📚 Part of the article:**")
            content = meta["content"]
            # Add the ellipsis only when the excerpt was actually cut off.
            st.write(content[:500] + "..." if len(content) > 500 else content)
        if meta.get("source_url"):
            st.markdown(f"[🔗 Read the full article →]({meta['source_url']})")
        st.markdown("---")
# --- Generate answer ---
def _chunks_to_docs(chunks) -> list:
    """Convert retrieved chunk dicts into the numbered doc dicts sent to the LLM.

    Each chunk is read with ``.get``: ``"metadata"`` (title / description /
    source_url), ``"chunk_text"``, ``"retriever"`` and ``"rerank_score"``.
    Docs are numbered from 1; the same ids are shown as ``[n]`` in the
    Sources list.
    """
    docs = []
    for idx, chunk in enumerate(chunks, start=1):
        # `or {}` also covers a chunk whose metadata is explicitly None.
        meta = chunk.get("metadata") or {}
        docs.append({
            "id": idx,
            "title": meta.get("title", ""),
            "description": meta.get("description", ""),
            "source_url": meta.get("source_url", ""),
            "content": chunk.get("chunk_text", ""),
            "retriever": chunk.get("retriever", ""),
            "rerank_score": chunk.get("rerank_score", None),
        })
    return docs


def _render_answer(response, docs: list, sources_heading: str) -> None:
    """Show the generated answer followed by the source docs it was built from."""
    st.markdown("### 🤖 Generated Response:")
    st.success(response)
    st.markdown(sources_heading)
    for d in docs:
        st.markdown(f"**[{d['id']}] {d.get('title', '')}**")
        if d.get("source_url"):
            st.markdown(d["source_url"])
        content = d.get("content") or ""
        # Add the ellipsis only when the excerpt was actually cut off.
        st.write(content[:450] + "..." if len(content) > 450 else content)
        if d.get("retriever"):
            st.caption(f"retriever: {d['retriever']}")
        if d.get("rerank_score") is not None:
            st.caption(f"rerank_score: {d['rerank_score']:.4f}")
        st.markdown("---")


if query and st.button("🧠 Generate a response to a query"):
    if mode == "Classical RAG":
        # Answer from the retriever and reranker setting chosen in the UI.
        search_mode = pick_mode(classical_retriever)
        answer_reranker = use_reranker
        sources_heading = "### 📌 Sources"
    else:
        # Multimodal mode: the preview stays multimodal (best_pair_search),
        # but the ANSWER is generated from TEXT chunks (hybrid) for reliable
        # QA + citations.
        search_mode = "hybrid"
        answer_reranker = True
        sources_heading = "### 📌 Sources (text chunks)"
    chunks = classical_retrieve_chunks(
        query=query,
        mode=search_mode,
        fetch_k=50,
        rerank_k=5,
        use_reranker=answer_reranker,
    )
    docs = _chunks_to_docs(chunks)
    response = generate_response(query, docs)
    _render_answer(response, docs, sources_heading)