# app.py — add "HF Transformers (Space/CPU)" backend option
"""Streamlit front end for a local cardiology-paper RAG assistant.

Wires a CardioRAG retrieval pipeline to one of three user-selectable LLM
backends (HF Transformers on CPU, llama.cpp via MODEL_PATH, or LM Studio)
and renders a generated explanation together with its retrieved sources.
Educational use only.
"""
import os
import streamlit as st
from rag import CardioRAG, LLMBackend

st.set_page_config(page_title="Cardiology Paper Explainer", page_icon="🫀", layout="wide")
st.title("🫀 Cardiology Paper Explainer")
st.caption("Local RAG to help researchers understand cross-discipline cardiovascular papers. Educational use only.")

# Constructed on every script run (Streamlit re-executes top-to-bottom on
# each interaction); any caching is presumably inside CardioRAG — TODO confirm.
rag = CardioRAG()

with st.sidebar:
    st.header("Settings")
    options = ["HF Transformers (Space/CPU)", "llama.cpp (MODEL_PATH)", "LM Studio (localhost)"]
    backend = st.selectbox("LLM backend", options, index=0)
    top_k = st.slider("Top-k chunks", 2, 6, 3)  # min 2, max 6, default 3
    # Reader selector
    reading_level = st.selectbox("Reader", ["Clinician", "Data scientist"])
    # Filter by a single paper (filename)
    sources = ["All"] + rag.list_sources()
    src_choice = st.selectbox("Filter to paper", sources, index=0)
    # None means "no filter"; otherwise pass the chosen filename through.
    src_filter = None if src_choice == "All" else src_choice
    show_chunks = st.checkbox("Show retrieved snippets", value=True)
    # (Optional) choose small CPU model id for Spaces
    default_model = os.getenv("HF_LOCAL_MODEL", "google/gemma-2-2b-it")
    hf_model = st.text_input(
        "HF model (CPU)",
        value=default_model,
        help="e.g., google/gemma-2-2b-it or microsoft/phi-3-mini-4k-instruct",
    )
    # Hand the chosen model id to the HF-local backend via the environment
    # (presumably LLMBackend.from_hf_local reads HF_LOCAL_MODEL — TODO confirm).
    os.environ["HF_LOCAL_MODEL"] = hf_model
    if st.button("Rebuild Index"):
        rag.rebuild_index()
        st.success("Vector store re-ingested.")

query = st.text_area(
    "Ask a question",
    height=110,
    placeholder="e.g., PICO and clinically relevant endpoints (≤50 words).",
)

# Run the pipeline only when the button is pressed AND the query is non-empty.
if st.button("Explain Paper", type="primary") and query:
    with st.spinner("Retrieving and generating..."):
        # Select the backend by prefix so the match tracks the option labels above.
        if backend.startswith("HF Transformers"):
            llm = LLMBackend.from_hf_local(temperature=0.2)
        elif backend.startswith("llama.cpp"):
            model_path = os.getenv("MODEL_PATH")
            if not model_path:
                st.error("MODEL_PATH not set; export it to your .gguf file.")
                st.stop()  # aborts this script run; nothing below executes
            llm = LLMBackend.from_llamacpp(model_path=model_path, temperature=0.2)
        else:
            llm = LLMBackend.from_lmstudio(temperature=0.2)
        answer_md, sources_list = rag.answer(
            query=query, top_k=top_k, reading_level=reading_level, llm=llm, source_filter=src_filter
        )
    st.markdown(answer_md)
    with st.expander("Sources"):
        for s in sources_list:
            st.write(f"- {s}")
    if show_chunks:
        st.divider()
        st.subheader("Retrieved snippets (preview)")
        # NOTE(review): reaches into a private CardioRAG method; consider
        # exposing a public preview/retrieve API on CardioRAG instead.
        previews = rag._retrieve(query=query, top_k=top_k, source_filter=src_filter)
        for i, (doc, meta) in enumerate(previews, 1):
            st.markdown(f"**{i}. {meta.get('source')} — {meta.get('section')}**")
            # Truncate long snippets to a 400-char preview.
            st.write((doc or "")[:400] + ("..." if doc and len(doc) > 400 else ""))