Spaces:
Sleeping
Sleeping
| """ | |
| app.py – Scientific RAG Interactive Application (Hugging Face Safe) | |
| """ | |
| import sys | |
| from pathlib import Path | |
| import time | |
| import os | |
| sys.path.insert(0, ".") | |
| import streamlit as st | |
# =============================================================================
# PAGE CONFIG
# =============================================================================
# Page title, icon, wide layout and an initially expanded sidebar.
_PAGE_CONFIG = {
    "page_title": "Scientific RAG – AIMS Sénégal",
    "page_icon": "🔬",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# =============================================================================
# SESSION STATE
# =============================================================================
def _init_state():
    """Seed ``st.session_state`` with a default for every key not yet present.

    Keys that already exist are left untouched, so values stored by an earlier
    run of the script (loaded models, index, last answer) are not overwritten.
    """
    defaults = {
        "embedder": None,
        "store": None,
        "llm": None,
        "reranker": None,
        "papers_loaded": 0,
        "last_result": None,
        "comparison": None,
    }
    missing = (key for key in defaults if key not in st.session_state)
    for key in missing:
        st.session_state[key] = defaults[key]


_init_state()
# =============================================================================
# SIDEBAR
# =============================================================================
def _ingest_uploads(files):
    """Ingest uploaded PDF files and return the resulting paper objects.

    Each upload is written into a temporary directory so ``ingest_paper`` can
    read it from a real path. The paper title falls back to, and ``arxiv_id``
    is set to, the file name without its extension. Files that fail to ingest
    are reported with ``st.warning`` and skipped. The temporary directory is
    removed when this function returns, even if an error escapes the loop.
    """
    import tempfile

    from src.ingestion.scientific_ingestor import ingest_paper

    papers = []
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        for uf in files:
            # Keep only the base name: a browser-supplied name must not be
            # able to point outside the temporary directory.
            file_name = Path(uf.name).name
            stem = Path(file_name).stem
            dest = tmp_dir / file_name
            # getvalue() returns the whole upload regardless of read position.
            dest.write_bytes(uf.getvalue())
            try:
                paper = ingest_paper(str(dest))
                paper.title = paper.title or stem
                paper.arxiv_id = stem
                papers.append(paper)
            except Exception as e:
                st.warning(f"Error {uf.name}: {e}")
    return papers


def _index_papers(papers, embedder, store, strategy, size):
    """Chunk ``papers``, embed the chunk texts and add them to ``store``.

    Returns the number of chunks indexed. When chunking yields no chunks,
    nothing is encoded or written to the store and 0 is returned.
    """
    from src.ingestion.scientific_chunker import chunk_corpus

    chunks = chunk_corpus(papers, strategy=strategy, chunk_size=size)
    if not chunks:
        return 0
    embeddings = embedder.encode([c.text for c in chunks])
    store.index_chunks(chunks, embeddings)
    return len(chunks)


def _activate_index(embedder, store, n_papers):
    """Record ``embedder``/``store`` as the active retrieval pair in session state."""
    st.session_state.embedder = embedder
    st.session_state.store = store
    st.session_state.papers_loaded = n_papers


with st.sidebar:
    st.title("⚙️ Configuration")
    st.divider()

    # ---------------- Corpus ----------------
    st.subheader("📚 Corpus")
    corpus_source = st.radio(
        "Source",
        ["Load from disk", "Upload PDFs"],
        index=0,
    )
    uploaded_files = None
    if corpus_source == "Upload PDFs":
        uploaded_files = st.file_uploader(
            "Upload scientific PDFs",
            type="pdf",
            accept_multiple_files=True,
        )
    else:
        papers_dir = st.text_input("Papers directory", value="data/papers")
        metadata_file = st.text_input("Metadata file", value="data/metadata.jsonl")
    # Needed in both modes: uploaded PDFs are indexed into this store as well.
    chroma_dir = st.text_input("Chroma persist dir", value="data/chroma_scientific")
    chunking_strategy = st.selectbox(
        "Chunking strategy",
        ["section_aware", "fixed_size"],
    )
    chunk_size = st.slider("Chunk size (chars)", 300, 1000, 600, 50)

    # ---------------- BUILD INDEX ----------------
    if st.button("🔄 Build / Load Index", use_container_width=True):
        if corpus_source == "Upload PDFs" and not uploaded_files:
            # Checked before loading the embedding model so the warning is instant.
            st.warning("Upload at least one PDF.")
        else:
            with st.spinner("Processing corpus…"):
                try:
                    from src.retrieval.embedder import Embedder
                    from src.retrieval.vector_store import ScientificChromaStore
                    from src.ingestion.scientific_ingestor import ingest_corpus

                    # "_emb_key" is written by the Models section further down,
                    # so this reads the value stored on a previous script run.
                    embedder = Embedder(
                        st.session_state.get("_emb_key", "multilingual-e5-small")
                    )
                    store = ScientificChromaStore(persist_dir=chroma_dir)

                    # ================= UPLOAD MODE =================
                    if corpus_source == "Upload PDFs":
                        papers = _ingest_uploads(uploaded_files)
                        if not papers:
                            st.error("None of the uploaded PDFs could be ingested.")
                        else:
                            n_chunks = _index_papers(
                                papers, embedder, store, chunking_strategy, chunk_size
                            )
                            if n_chunks:
                                _activate_index(embedder, store, len(papers))
                                st.success(f"Indexed {n_chunks} chunks")
                            else:
                                st.warning("No chunks were produced; nothing was indexed.")
                    # ================= DISK MODE =================
                    else:
                        existing = store.count()
                        if existing > 0:
                            # Reuse the persisted index as-is; chunking settings
                            # are not re-applied to it.
                            _activate_index(embedder, store, len(store.list_papers()))
                            st.success(f"Loaded existing index ({existing} chunks)")
                        else:
                            papers = ingest_corpus(
                                papers_dir=papers_dir,
                                metadata_file=metadata_file,
                            )
                            n_chunks = _index_papers(
                                papers, embedder, store, chunking_strategy, chunk_size
                            )
                            if n_chunks:
                                _activate_index(embedder, store, len(papers))
                                st.success(f"Indexed {n_chunks} chunks")
                            else:
                                st.warning("No chunks were produced; nothing was indexed.")
                except Exception as e:
                    import traceback

                    st.error(f"Error: {e}")
                    st.code(traceback.format_exc())

    if st.session_state.store is not None:
        st.caption(
            f"📊 {st.session_state.store.count()} chunks | "
            f"{st.session_state.papers_loaded} papers"
        )
    st.divider()

    # ---------------- MODELS ----------------
    st.subheader("🤖 Models")
    emb_model = st.selectbox(
        "Embedding model",
        ["multilingual-e5-small", "bge-small-en", "minilm-l6"],
    )
    # Consumed by the "Build / Load Index" handler above.
    st.session_state["_emb_key"] = emb_model
    llm_choice = st.selectbox(
        "LLM Backend",
        ["Groq", "Qwen", "Gemma", "Claude"],
    )
    if st.button("Load LLM", use_container_width=True):
        with st.spinner("Loading LLM…"):
            try:
                from src.generation.llm_backend import (
                    make_llm,
                    GroqBackend,
                    ClaudeHaikuBackend,
                )

                if llm_choice == "Groq":
                    llm = GroqBackend()
                elif llm_choice == "Claude":
                    llm = ClaudeHaikuBackend()
                elif llm_choice == "Gemma":
                    # NOTE(review): presumably make_llm uses fallback_hf when the
                    # local .task model cannot be loaded — confirm in llm_backend.
                    llm = make_llm(
                        model_path="data/gemma4-e2b-it.task",
                        fallback_hf="Qwen/Qwen3-0.6B",
                    )
                else:  # "Qwen"
                    llm = make_llm(fallback_hf="Qwen/Qwen3-0.6B")
                st.session_state.llm = llm
                st.success("LLM loaded")
            except Exception as e:
                st.error(f"LLM error: {e}")
    st.divider()

    # ---------------- RETRIEVAL ----------------
    st.subheader("🔍 Retrieval")
    top_k = st.slider("Top-k chunks", 1, 20, 5)
    use_reranker = st.checkbox("Enable reranker", value=False)
    # NOTE(review): presumably reranker_k is the candidate pool that is re-scored
    # down to top_k; a value below top_k would then return fewer chunks — confirm.
    reranker_k = st.slider(
        "Reranker k",
        5,
        50,
        20,
        disabled=not use_reranker,
    )
    if use_reranker and st.session_state.reranker is None:
        if st.button("Load reranker"):
            try:
                from src.retrieval.reranker import CrossEncoderReranker

                st.session_state.reranker = CrossEncoderReranker()
                st.success("Reranker loaded")
            except Exception as e:
                st.error(str(e))
    elif not use_reranker:
        # Unticking the box drops the loaded reranker; it must be reloaded later.
        st.session_state.reranker = None
    st.divider()

    st.subheader("Filters")
    year_min = st.text_input("Min year")
    section_filter = st.multiselect(
        "Sections",
        ["abstract", "introduction", "methodology", "experiments", "conclusion"],
    )
# =============================================================================
# MAIN UI
# =============================================================================
st.title("🔬 RAG on Scientific Corpus")
# Order matters: the sections below address these tabs as tabs[0], [1], [2].
_TAB_LABELS = ["Q&A", "Comparison", "Corpus"]
tabs = st.tabs(_TAB_LABELS)
# =============================================================================
# TAB 1
# =============================================================================
def _parse_min_year(raw):
    """Return the sidebar "Min year" value as an int, or None when left blank.

    The ``$gte`` filter operator compares numbers, so the free-text input must
    be converted before it goes into the metadata filter.

    Raises:
        ValueError: if the text is not a whole number.
    """
    text = "" if raw is None else str(raw).strip()
    return int(text) if text else None


def _build_metadata_filter(min_year, sections):
    """Build the metadata ``where`` clause for the sidebar filters, or None.

    A single condition is returned as-is; two conditions are combined with
    ``$and``, since a Chroma ``where`` dict may hold only one top-level key.
    NOTE(review): assumes ``metadata_filter`` reaches Chroma unchanged and that
    the "year" metadata is stored as a number — confirm in ScientificChromaStore.
    """
    conditions = []
    if min_year is not None:
        conditions.append({"year": {"$gte": min_year}})
    if sections:
        conditions.append({"section": {"$in": list(sections)}})
    if not conditions:
        return None
    if len(conditions) == 1:
        return conditions[0]
    return {"$and": conditions}


with tabs[0]:
    question = st.text_area("Question")
    col1, col2 = st.columns([3, 1])
    with col2:
        mode = st.selectbox("Mode", ["RAG", "RAG+Reranker", "LLM"])
    with col1:
        run = st.button("Generate")

    if run:
        llm = st.session_state.llm
        store = st.session_state.store
        embedder = st.session_state.embedder
        reranker = st.session_state.reranker
        query = question.strip()

        # The year filter only matters for the retrieval modes.
        min_year = None
        year_ok = True
        if mode != "LLM":
            try:
                min_year = _parse_min_year(year_min)
            except ValueError:
                year_ok = False

        if not query:
            st.warning("Enter a question first")
        elif llm is None:
            st.warning("Load LLM first")
        elif mode != "LLM" and store is None:
            st.warning("Build index first")
        elif not year_ok:
            st.error(f"Min year must be a whole number, got {year_min!r}")
        else:
            if mode == "RAG+Reranker" and reranker is None:
                st.warning("Reranker not loaded – answering with plain RAG.")
            with st.spinner("Generating…"):
                from src.generation.rag_pipeline import (
                    answer_with_rag,
                    answer_without_rag,
                )

                if mode == "LLM":
                    t0 = time.perf_counter()
                    ans = answer_without_rag(query, llm)
                    result = {
                        "answer": ans,
                        "sources": [],
                        "latency": {"total_s": round(time.perf_counter() - t0, 2)},
                    }
                else:
                    result = answer_with_rag(
                        query,
                        embedder,
                        store,
                        llm,
                        k=top_k,
                        # None (not False) when reranking is off or unavailable.
                        reranker=reranker if mode == "RAG+Reranker" else None,
                        reranker_k=reranker_k,
                        metadata_filter=_build_metadata_filter(min_year, section_filter),
                    )
            st.session_state.last_result = result

    if st.session_state.last_result:
        res = st.session_state.last_result
        st.subheader("Answer")
        st.write(res["answer"])
# =============================================================================
# TAB 2
# =============================================================================
comparison_tab = tabs[1]
with comparison_tab:
    # TODO: the comparison view is not implemented; only a placeholder is shown.
    st.write("Comparison mode placeholder")
# =============================================================================
# TAB 3
# =============================================================================
with tabs[2]:
    # Corpus overview for the index held in session state (None until built).
    store = st.session_state.store
    if store is None:
        st.info("No index loaded")
    else:
        st.metric("Chunks", store.count())
        st.metric("Papers", len(store.list_papers()))