# genAI-Project / app.py — last commit 828ade4 ("Update app.py") by OGB2000, 11.6 kB.
"""
app.py – Scientific RAG Interactive Application (Hugging Face Safe)
"""
import sys
from pathlib import Path
import time
import os
sys.path.insert(0, ".")
import streamlit as st
# =============================================================================
# PAGE CONFIG
# =============================================================================
# Browser-tab title and icon, wide main area, and a sidebar that starts open.
_PAGE_CONFIG = {
    "page_title": "Scientific RAG – AIMS Sénégal",
    "page_icon": "🔬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# =============================================================================
# SESSION STATE
# =============================================================================
def _init_state():
    """Populate ``st.session_state`` with defaults for any missing keys.

    Keys that already exist keep their current value. Calling this on every
    Streamlit rerun therefore does not discard a previously loaded embedder,
    store, LLM, reranker or result.
    """
    state = st.session_state
    for key, default in {
        "embedder": None,
        "store": None,
        "llm": None,
        "reranker": None,
        "papers_loaded": 0,
        "last_result": None,
        "comparison": None,
    }.items():
        if key in state:
            continue
        state[key] = default


_init_state()
# =============================================================================
# SIDEBAR
# =============================================================================
def _ingest_uploads(files):
    """Parse uploaded PDFs into paper objects with ``ingest_paper``.

    Each upload is written into a temporary directory. The directory is
    removed when the ``with`` block exits, even if an unexpected error
    escapes. A file whose parsing fails is reported via ``st.warning`` and
    skipped.

    Args:
        files: Uploaded-file objects returned by ``st.file_uploader``.

    Returns:
        List of parsed papers. ``title`` falls back to the file name without
        its extension when empty, and ``arxiv_id`` is always set to that
        extension-less name.
    """
    import tempfile

    from src.ingestion.scientific_ingestor import ingest_paper

    papers = []
    with tempfile.TemporaryDirectory() as tmp:
        for uf in files:
            # Keep only the last path component of the client-supplied name
            # so the write cannot land outside the temporary directory.
            filename = Path(uf.name).name
            stem = Path(filename).stem
            dest = Path(tmp) / filename
            # getvalue() returns the whole upload independent of the current
            # stream position, unlike read().
            dest.write_bytes(uf.getvalue())
            try:
                paper = ingest_paper(str(dest))
                paper.title = paper.title or stem
                paper.arxiv_id = stem
                papers.append(paper)
            except Exception as e:
                st.warning(f"Error {uf.name}: {e}")
    return papers


def _index_and_activate(papers, embedder, store, strategy, size):
    """Chunk, embed and index *papers*, then make the index the active one.

    Sets ``embedder``, ``store`` and ``papers_loaded`` in session state and
    shows a success message. If chunking yields nothing, the session state is
    left unchanged and a warning is shown instead.
    """
    from src.ingestion.scientific_chunker import chunk_corpus

    chunks = chunk_corpus(papers, strategy=strategy, chunk_size=size)
    if not chunks:
        st.warning("No chunks were produced; nothing was indexed.")
        return
    embeddings = embedder.encode([c.text for c in chunks])
    store.index_chunks(chunks, embeddings)
    st.session_state.embedder = embedder
    st.session_state.store = store
    st.session_state.papers_loaded = len(papers)
    st.success(f"Indexed {len(chunks)} chunks")


with st.sidebar:
    st.title("⚙️ Configuration")
    st.divider()
    # ---------------- Corpus ----------------
    st.subheader("📚 Corpus")
    corpus_source = st.radio(
        "Source",
        ["Load from disk", "Upload PDFs"],
        index=0,
    )
    uploaded_files = None
    if corpus_source == "Upload PDFs":
        uploaded_files = st.file_uploader(
            "Upload scientific PDFs",
            type="pdf",
            accept_multiple_files=True,
        )
    else:
        papers_dir = st.text_input("Papers directory", value="data/papers")
        metadata_file = st.text_input("Metadata file", value="data/metadata.jsonl")
    # Defined outside the branch: upload mode also writes to the Chroma store
    # created below, so this name must exist in both modes.
    chroma_dir = st.text_input("Chroma persist dir", value="data/chroma_scientific")
    chunking_strategy = st.selectbox(
        "Chunking strategy",
        ["section_aware", "fixed_size"],
    )
    chunk_size = st.slider("Chunk size (chars)", 300, 1000, 600, 50)

    # ---------------- BUILD INDEX ----------------
    if st.button("🔄 Build / Load Index", use_container_width=True):
        with st.spinner("Processing corpus…"):
            try:
                from src.retrieval.embedder import Embedder
                from src.retrieval.vector_store import ScientificChromaStore

                # "_emb_key" is written by the embedding-model selectbox in
                # the Models section of the sidebar.
                embedder = Embedder(
                    st.session_state.get("_emb_key", "multilingual-e5-small")
                )
                store = ScientificChromaStore(persist_dir=chroma_dir)

                if corpus_source == "Upload PDFs":
                    # ================= UPLOAD MODE =================
                    if not uploaded_files:
                        st.warning("Upload at least one PDF.")
                    else:
                        papers = _ingest_uploads(uploaded_files)
                        if papers:
                            _index_and_activate(
                                papers, embedder, store, chunking_strategy, chunk_size
                            )
                        else:
                            st.warning("None of the uploaded PDFs could be parsed.")
                else:
                    # ================= DISK MODE =================
                    existing = store.count()
                    if existing > 0:
                        # Reuse the persisted index instead of re-ingesting.
                        # NOTE(review): the embedder comes from the current
                        # selection and is not checked against the model that
                        # built the persisted index.
                        st.session_state.store = store
                        st.session_state.embedder = embedder
                        st.session_state.papers_loaded = len(store.list_papers())
                        st.success(f"Loaded existing index ({existing} chunks)")
                    else:
                        from src.ingestion.scientific_ingestor import ingest_corpus

                        papers = ingest_corpus(
                            papers_dir=papers_dir,
                            metadata_file=metadata_file,
                        )
                        _index_and_activate(
                            papers, embedder, store, chunking_strategy, chunk_size
                        )
            except Exception as e:
                import traceback

                st.error(f"Error: {e}")
                st.code(traceback.format_exc())

    if st.session_state.store:
        st.caption(
            f"📊 {st.session_state.store.count()} chunks | "
            f"{st.session_state.papers_loaded} papers"
        )
with st.sidebar:
    st.divider()
    # ---------------- MODELS ----------------
    st.subheader("🤖 Models")
    emb_model = st.selectbox(
        "Embedding model",
        ["multilingual-e5-small", "bge-small-en", "minilm-l6"],
    )
    # Read by the "Build / Load Index" handler to choose the embedder.
    st.session_state["_emb_key"] = emb_model

    llm_choice = st.selectbox("LLM Backend", ["Groq", "Qwen", "Gemma", "Claude"])

    if st.button("Load LLM", use_container_width=True):
        with st.spinner("Loading LLM…"):
            try:
                from src.generation.llm_backend import (
                    ClaudeHaikuBackend,
                    GroqBackend,
                    make_llm,
                )

                # Backend factories keyed by the selectbox label. Any other
                # label (i.e. "Qwen") falls back to the default make_llm call.
                backend_factories = {
                    "Groq": GroqBackend,
                    "Claude": ClaudeHaikuBackend,
                    "Gemma": lambda: make_llm(
                        model_path="data/gemma4-e2b-it.task",
                        fallback_hf="Qwen/Qwen3-0.6B",
                    ),
                }
                build_llm = backend_factories.get(
                    llm_choice,
                    lambda: make_llm(fallback_hf="Qwen/Qwen3-0.6B"),
                )
                llm = build_llm()
                st.session_state.llm = llm
                st.success("LLM loaded")
            except Exception as e:
                st.error(f"LLM error: {e}")
    st.divider()
with st.sidebar:
    # ---------------- RETRIEVAL ----------------
    st.subheader("🔍 Retrieval")
    top_k = st.slider("Top-k chunks", 1, 20, 5)
    use_reranker = st.checkbox("Enable reranker", value=False)
    reranker_k = st.slider("Reranker k", 5, 50, 20, disabled=not use_reranker)

    if not use_reranker:
        # Unchecking the box clears any previously loaded reranker.
        st.session_state.reranker = None
    elif st.session_state.reranker is None and st.button("Load reranker"):
        # The button is only rendered while reranking is enabled and no
        # reranker is loaded yet.
        try:
            from src.retrieval.reranker import CrossEncoderReranker

            st.session_state.reranker = CrossEncoderReranker()
            st.success("Reranker loaded")
        except Exception as e:
            st.error(str(e))

    st.divider()
    # ---------------- FILTERS ----------------
    st.subheader("Filters")
    # Kept as raw text; consumed by the Q&A tab's metadata filter.
    year_min = st.text_input("Min year")
    section_filter = st.multiselect(
        "Sections",
        ["abstract", "introduction", "methodology", "experiments", "conclusion"],
    )
# =============================================================================
# MAIN UI
# =============================================================================
st.title("🔬 RAG on Scientific Corpus")
# Positional order matters: tabs[0] = Q&A, tabs[1] = Comparison,
# tabs[2] = Corpus.
tabs = st.tabs(
    [
        "Q&A",
        "Comparison",
        "Corpus",
    ]
)
# =============================================================================
# TAB 1
# =============================================================================
def _build_metadata_filter(year_text, sections):
    """Translate the sidebar "Min year" / "Sections" widgets into a filter.

    Args:
        year_text: Raw text of the "Min year" input. It is converted to
            ``int`` because Chroma's ``$gte`` comparison expects a numeric
            operand. A non-integer value is ignored and a warning is shown.
        sections: Selected section names (may be empty).

    Returns:
        A metadata filter dict, or ``None`` when no filter applies.
    """
    meta_filter = {}
    year_text = (year_text or "").strip()
    if year_text:
        try:
            # NOTE(review): assumes the indexed "year" metadata is numeric —
            # confirm in ScientificChromaStore.index_chunks.
            meta_filter["year"] = {"$gte": int(year_text)}
        except ValueError:
            st.warning(f"Ignoring invalid min year {year_text!r} (expected e.g. 2020).")
    if sections:
        meta_filter["section"] = {"$in": list(sections)}
    # NOTE(review): a raw Chroma `where` clause allows a single top-level key.
    # If this dict reaches Chroma unchanged, combining year and section needs
    # {"$and": [...]} — confirm how answer_with_rag / the store consume it.
    return meta_filter or None


with tabs[0]:
    question = st.text_area("Question")
    col1, col2 = st.columns([3, 1])
    with col2:
        mode = st.selectbox("Mode", ["RAG", "RAG+Reranker", "LLM"])
    with col1:
        run = st.button("Generate")

    if run and not question.strip():
        st.warning("Enter a question first")
    elif run:
        llm = st.session_state.llm
        store = st.session_state.store
        embedder = st.session_state.embedder
        reranker = st.session_state.reranker
        if not llm:
            st.warning("Load LLM first")
        elif mode != "LLM" and not store:
            st.warning("Build index first")
        else:
            if mode == "RAG+Reranker" and reranker is None:
                st.info(
                    "Reranker not loaded – answering with plain retrieval. "
                    "Enable and load it in the sidebar."
                )
            with st.spinner("Generating…"):
                try:
                    from src.generation.rag_pipeline import (
                        answer_with_rag,
                        answer_without_rag,
                    )

                    if mode == "LLM":
                        t0 = time.perf_counter()
                        ans = answer_without_rag(question, llm)
                        result = {
                            "answer": ans,
                            "sources": [],
                            "latency": {"total_s": round(time.perf_counter() - t0, 2)},
                        }
                    else:
                        result = answer_with_rag(
                            question,
                            embedder,
                            store,
                            llm,
                            k=top_k,
                            # None (never False) when reranking is off or
                            # no reranker is loaded.
                            reranker=reranker if mode == "RAG+Reranker" else None,
                            reranker_k=reranker_k,
                            metadata_filter=_build_metadata_filter(year_min, section_filter),
                        )
                    st.session_state.last_result = result
                except Exception as e:
                    # Drop the previous answer so it is not mistaken for this one.
                    st.session_state.last_result = None
                    st.error(f"Generation error: {e}")

    res = st.session_state.last_result
    if res:
        st.subheader("Answer")
        st.write(res["answer"])
# =============================================================================
# TAB 2
# =============================================================================
# Only a placeholder is rendered; no comparison logic is implemented here.
tabs[1].write("Comparison mode placeholder")
# =============================================================================
# TAB 3
# =============================================================================
with tabs[2]:
    store = st.session_state.store
    if not store:
        st.info("No index loaded")
    else:
        # Summary of the active index: total chunks and distinct papers.
        st.metric("Chunks", store.count())
        st.metric("Papers", len(store.list_papers()))