| import streamlit as st |
|
|
| from search.search_classical import classical_search, classical_retrieve_chunks |
| from search.search_best_pair import best_pair_search |
| from llm import generate_response |
|
|
|
|
def pick_mode(label: str) -> str:
    """Translate a retrieval radio-button label into a search mode name.

    Args:
        label: The option text selected in the classical-retrieval radio.

    Returns:
        ``"semantic"`` when the label starts with ``"Semantic"``,
        ``"bm25"`` when it starts with ``"Keyword"``, and ``"hybrid"``
        for every other label.
    """
    if label.startswith("Semantic"):
        return "semantic"
    if label.startswith("Keyword"):
        return "bm25"
    # Anything else, including the "Hybrid ..." option, maps to hybrid.
    return "hybrid"
|
|
|
|
# --- Page chrome: browser tab title, logo and heading. ---
st.set_page_config(page_title="π Multimodal Search The Batch")
st.image("data/the-batch-logo.webp", width=300)
st.title("Multimodal Assistant")

# --- Primary inputs: which pipeline to use and the user's query text. ---
_search_mode_options = ["Classical RAG", "Multimodal RAG"]
mode = st.selectbox("π Select the search mode:", _search_mode_options)
query = st.text_input("π Enter the text query:")
|
|
| |
# Retrieval settings. The extra widgets are shown only for "Classical RAG";
# in any other mode the settings fall back to semantic retrieval with the
# reranker switched on.
if mode == "Classical RAG":
    _retriever_labels = [
        "Semantic (embeddings)",
        "Keyword (BM25)",
        "Hybrid (BM25 + Semantic)",
    ]
    classical_retriever = st.radio(
        "π§© Classical retrieval:",
        _retriever_labels,
        horizontal=True,
    )
    use_reranker = st.checkbox("β¨ Use reranker (cross-encoder)", value=True)
else:
    classical_retriever = "Semantic (embeddings)"
    use_reranker = True
|
|
| |
# Preview search: run the selected search for the current query and render
# one card per hit. Nothing is rendered until a query has been entered.
#
# NOTE(review): the "π…" / "β" fragments in the labels look like emoji that
# were mis-decoded upstream; they are kept as found here. Restore the
# intended characters from the original file if it is available.
# NOTE(review): `use_reranker` is not passed to this preview search, only to
# the answer-generation retrieval further down; confirm that is intended.
_EXCERPT_CHARS = 500  # maximum number of article characters shown per hit

results = []
if query:
    if mode == "Classical RAG":
        search_mode = pick_mode(classical_retriever)
        results = classical_search(query, k=3, mode=search_mode)
    else:
        results = best_pair_search(query, k=3)

    st.markdown(f"### π Results found: {len(results)}")

    for i, meta in enumerate(results):
        st.markdown(f"### πΉ Result {i + 1}")
        if meta.get("title"):
            st.markdown(f"**π Name:** {meta['title']}")
        if meta.get("date"):
            # This literal was split across two physical lines by a stray
            # control character; it is rejoined into a single line here.
            st.markdown(f"**π Date of publication:** {meta['date']}")
        if meta.get("description"):
            st.markdown(f"**π Description:** {meta['description']}")
        if meta.get("image_url"):
            st.image(meta["image_url"], use_container_width=True)
        if meta.get("content"):
            st.markdown("**π Part of the article:**")
            content = meta["content"]
            excerpt = content[:_EXCERPT_CHARS]
            # Add the ellipsis only when text was actually cut off.
            if len(content) > _EXCERPT_CHARS:
                excerpt += "..."
            st.write(excerpt)
        if meta.get("source_url"):
            st.markdown(f"[π Read the full article β]({meta['source_url']})")
        st.markdown("---")
|
|
| |
_SOURCE_PREVIEW_CHARS = 450  # maximum number of chunk characters shown per source


def _chunks_to_docs(chunks):
    """Flatten retrieved chunks into the numbered doc dicts used for generation.

    Each chunk is read as a dict with optional ``metadata``, ``chunk_text``,
    ``retriever`` and ``rerank_score`` entries. Ids start at 1 and follow
    the order of ``chunks``.
    """
    docs = []
    for idx, chunk in enumerate(chunks, start=1):
        # `or {}` also covers a metadata key that is present but None.
        meta = chunk.get("metadata") or {}
        docs.append({
            "id": idx,
            "title": meta.get("title", ""),
            "description": meta.get("description", ""),
            "source_url": meta.get("source_url", ""),
            "content": chunk.get("chunk_text", ""),
            "retriever": chunk.get("retriever", ""),
            "rerank_score": chunk.get("rerank_score", None),
        })
    return docs


def _render_sources(docs, heading):
    """Render the list of source docs under the given markdown heading."""
    st.markdown(heading)
    for doc in docs:
        st.markdown(f"**[{doc['id']}] {doc.get('title', '')}**")
        if doc.get("source_url"):
            st.markdown(doc["source_url"])
        text = doc.get("content") or ""
        preview = text[:_SOURCE_PREVIEW_CHARS]
        # Add the ellipsis only when text was actually cut off.
        if len(text) > _SOURCE_PREVIEW_CHARS:
            preview += "..."
        st.write(preview)
        if doc.get("retriever"):
            st.caption(f"retriever: {doc['retriever']}")
        if doc.get("rerank_score") is not None:
            st.caption(f"rerank_score: {doc['rerank_score']:.4f}")
        st.markdown("---")


# Answer generation: retrieve text chunks, ask the LLM, then list the sources.
# NOTE(review): the "π…" fragments in the labels look like mis-decoded emoji;
# they are kept as found here.
if query and st.button("π§ Generate a response to a query"):
    if mode == "Classical RAG":
        # Use the retriever and reranker settings chosen in the UI.
        search_mode = pick_mode(classical_retriever)
        rerank = use_reranker
        sources_heading = "### π Sources"
    else:
        # Any other mode generates from text chunks retrieved with fixed
        # settings: hybrid retrieval with the reranker always on.
        search_mode = "hybrid"
        rerank = True
        sources_heading = "### π Sources (text chunks)"

    chunks = classical_retrieve_chunks(
        query=query,
        mode=search_mode,
        fetch_k=50,
        rerank_k=5,
        use_reranker=rerank,
    )
    docs = _chunks_to_docs(chunks)

    response = generate_response(query, docs)
    st.markdown("### π€ Generated Response:")
    st.success(response)

    _render_sources(docs, sources_heading)
|
|