# CardioGenRAG / app.py — author: hlnicholls — "feat: bot v1" (commit 0edf71e, verified)
# app.py β€” add "HF Transformers (Space/CPU)" backend option
import os
import streamlit as st
from rag import CardioRAG, LLMBackend
# Page chrome. st.set_page_config must be the first Streamlit call in the script.
# NOTE: the emoji was mojibake ("πŸ«€" = UTF-8 bytes of U+1FAC0 misdecoded); restored to 🫀.
st.set_page_config(page_title="Cardiology Paper Explainer", page_icon="🫀", layout="wide")
st.title("🫀 Cardiology Paper Explainer")
st.caption("Local RAG to help researchers understand cross-discipline cardiovascular papers. Educational use only.")

# Shared RAG helper from rag.py; used below for list_sources/rebuild_index/answer/_retrieve.
rag = CardioRAG()
with st.sidebar:
    # Restored indentation: the pasted source had the sidebar body flush-left,
    # which is invalid Python for a `with` block.
    st.header("Settings")

    # LLM backend choice; index=0 keeps the CPU-friendly HF Transformers path as default.
    options = ["HF Transformers (Space/CPU)", "llama.cpp (MODEL_PATH)", "LM Studio (localhost)"]
    backend = st.selectbox("LLM backend", options, index=0)

    # Number of retrieved chunks passed to the generator (2..6, default 3).
    top_k = st.slider("Top-k chunks", 2, 6, 3)

    # Reader selector — controls the explanation's target audience.
    reading_level = st.selectbox("Reader", ["Clinician", "Data scientist"])

    # Filter by a single paper (filename); "All" maps to no filter (None).
    sources = ["All"] + rag.list_sources()
    src_choice = st.selectbox("Filter to paper", sources, index=0)
    src_filter = None if src_choice == "All" else src_choice

    show_chunks = st.checkbox("Show retrieved snippets", value=True)

    # (Optional) choose small CPU model id for Spaces.
    default_model = os.getenv("HF_LOCAL_MODEL", "google/gemma-2-2b-it")
    hf_model = st.text_input("HF model (CPU)", value=default_model, help="e.g., google/gemma-2-2b-it or microsoft/phi-3-mini-4k-instruct")
    # Exported via env on every rerun; presumably LLMBackend.from_hf_local reads
    # HF_LOCAL_MODEL — TODO confirm in rag.py.
    os.environ["HF_LOCAL_MODEL"] = hf_model

    if st.button("Rebuild Index"):
        rag.rebuild_index()
        st.success("Vector store re-ingested.")
# Main question box; empty string until the user types, which also gates the
# "Explain Paper" handler below.
# NOTE: placeholder previously contained mojibake "≀"; restored to "≤" — confirm intent.
query = st.text_area(
    "Ask a question",
    height=110,
    placeholder="e.g., PICO and clinically relevant endpoints (≤50 words).",
)
# Runs only when the button is clicked AND the question box is non-empty.
if st.button("Explain Paper", type="primary") and query:
    with st.spinner("Retrieving and generating..."):
        # Instantiate the chosen backend lazily, per request.
        if backend.startswith("HF Transformers"):
            llm = LLMBackend.from_hf_local(temperature=0.2)
        elif backend.startswith("llama.cpp"):
            # llama.cpp needs a local .gguf file pointed to by MODEL_PATH.
            model_path = os.getenv("MODEL_PATH")
            if not model_path:
                st.error("MODEL_PATH not set; export it to your .gguf file.")
                st.stop()  # abort this rerun; nothing below executes
            llm = LLMBackend.from_llamacpp(model_path=model_path, temperature=0.2)
        else:
            # Remaining option is "LM Studio (localhost)".
            llm = LLMBackend.from_lmstudio(temperature=0.2)

        answer_md, sources_list = rag.answer(
            query=query,
            top_k=top_k,
            reading_level=reading_level,
            llm=llm,
            source_filter=src_filter,
        )

    st.markdown(answer_md)

    with st.expander("Sources"):
        for s in sources_list:
            st.write(f"- {s}")

    # Debug preview of what retrieval fed the prompt. Placed inside the button
    # handler since it needs a non-empty `query` — the pasted source's
    # indentation was lost, so this nesting is reconstructed; confirm against
    # the original layout.
    if show_chunks:
        st.divider()
        st.subheader("Retrieved snippets (preview)")
        # NOTE(review): calls the private _retrieve API — verify it stays in
        # sync with the retrieval rag.answer performs internally.
        previews = rag._retrieve(query=query, top_k=top_k, source_filter=src_filter)
        for i, (doc, meta) in enumerate(previews, 1):
            # "β€”" in the original was mojibake for an em dash; restored.
            st.markdown(f"**{i}. {meta.get('source')} — {meta.get('section')}**")
            # Truncate long chunks to 400 chars; guard against doc being None.
            st.write((doc or "")[:400] + ("..." if doc and len(doc) > 400 else ""))