# HalassaBot / src/streamlit_app.py
import os
import re
import pickle
from pathlib import Path
from typing import List, Dict, Any
import streamlit as st
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# ========= LLM backend config =========
USE_OPENAI = os.getenv("USE_OPENAI", "0") == "1"
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
if USE_OPENAI:
    from openai import OpenAI
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
    if OPENAI_API_KEY:
        openai_client = OpenAI(api_key=OPENAI_API_KEY)
    else:
        openai_client = None
else:
    import google.generativeai as genai
    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
    if GOOGLE_API_KEY:
        genai.configure(api_key=GOOGLE_API_KEY)
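# Backend selection: USE_OPENAI=1 plus OPENAI_API_KEY routes calls to OpenAI;
# otherwise GOOGLE_API_KEY enables Gemini. With no key configured, call_llm()
# below falls back to echoing the prompt so the UI can still be exercised.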
# ========= Page config =========
st.set_page_config(
    page_title="Halassa Lab Literature Chatbot",
    page_icon="🧠",
    layout="wide",
)
BASE_DIR = Path(__file__).resolve().parent # points to src/
DATA_DIR = BASE_DIR / "data"
VECTOR_PATH = DATA_DIR / "vector_store.index"
PKL_PATH = DATA_DIR / "data.pkl"
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "BAAI/bge-large-en-v1.5")
TOP_K = int(os.getenv("TOP_K", "5"))
MAX_CONTEXT_CHARS = int(os.getenv("MAX_CONTEXT_CHARS", "12000"))
SUGGESTED_Q = int(os.getenv("SUGGESTED_Q", "4"))
# ========= Helpers =========
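# Expected layout of data.pkl (as read below): parallel lists "texts" (chunk
# strings) and "sources" (e.g. "paper.pdf - Page 12") of equal length, plus an
# optional "meta" list; vector_store.index is the FAISS index built over the
# same chunks in the same order.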
@st.cache_resource(show_spinner=False)
def load_index_and_data():
    # Load the FAISS index and the pickled chunks once per session.
    if not VECTOR_PATH.exists() or not PKL_PATH.exists():
        st.error(f"Missing index or data:\n- {VECTOR_PATH}\n- {PKL_PATH}")
        st.stop()
    index = faiss.read_index(str(VECTOR_PATH))
    with open(PKL_PATH, "rb") as f:
        stored = pickle.load(f)
    texts = stored.get("texts", [])
    sources = stored.get("sources", [])
    meta = stored.get("meta", [None] * len(texts))
    if len(texts) == 0 or len(texts) != len(sources):
        st.error("data.pkl must contain 'texts' and 'sources' of equal length.")
        st.stop()
    return index, texts, sources, meta
@st.cache_resource(show_spinner=False)
def get_embedder():
    return SentenceTransformer(EMBED_MODEL_NAME)
def encode_query(query: str, embedder) -> np.ndarray:
    vec = embedder.encode([query])
    return vec.astype(np.float32)
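# Assumption: if vector_store.index was built on L2-normalized embeddings
# (cosine / inner-product search), the query should be normalized the same way,
# e.g. embedder.encode([query], normalize_embeddings=True).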
def retrieve(query: str, index, texts, sources, meta, k=TOP_K):
    embedder = get_embedder()
    qvec = encode_query(query, embedder)
    D, I = index.search(qvec, k)
    results = []
    for rank, idx in enumerate(I[0].tolist()):
        if 0 <= idx < len(texts):
            results.append({
                "rank": rank + 1,
                "text": texts[idx],
                "source": sources[idx],
                "meta": meta[idx] if meta and idx < len(meta) else None,
            })
    return results
def build_context(retrieved: List[Dict[str, Any]]) -> str:
    parts, total = [], 0
    for r in retrieved:
        src = r["source"]
        txt = r["text"].strip()
        chunk = f"Source: {src}\nContent: {txt}\n"
        if total + len(chunk) > MAX_CONTEXT_CHARS:
            break
        parts.append(chunk)
        total += len(chunk)
    return "\n---\n".join(parts)
def call_llm(system_prompt: str, user_prompt: str) -> str:
    # OpenAI path
    if USE_OPENAI and os.getenv("OPENAI_API_KEY") and openai_client:
        resp = openai_client.chat.completions.create(
            model=OPENAI_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.2,
        )
        return resp.choices[0].message.content
    # Gemini path
    if not USE_OPENAI and os.getenv("GOOGLE_API_KEY"):
        model = genai.GenerativeModel(GEMINI_MODEL)
        resp = model.generate_content(system_prompt + "\n\n" + user_prompt)
        return resp.text
    # Fallback (no key) for UI testing
    return "(LLM disabled) " + user_prompt[:800]
def highlight_terms(text: str, query: str) -> str:
    # Lightweight term highlighter: wrap query terms (3+ chars) in <mark> tags.
    terms = [t for t in re.split(r"\W+", query) if len(t) >= 3]
    out = text
    for t in set(terms):
        out = re.sub(rf"({re.escape(t)})", r"<mark>\1</mark>", out, flags=re.IGNORECASE)
    return out
def suggest_questions(last_answer: str, k=SUGGESTED_Q) -> List[str]:
    prompt = f"""Generate {k} concise follow-up questions a user might ask next, given the expert answer below.
Each question should be short (max ~12 words) and deepen the discussion.
Answer only with a bulletless list, one question per line.
Expert answer:
{last_answer}
"""
    out = call_llm(
        system_prompt="You are a helpful assistant that proposes follow-up questions.",
        user_prompt=prompt,
    )
    qs = [re.sub(r"^[\-\*\d\.\)\s]+", "", q).strip() for q in out.splitlines() if q.strip()]
    return [q for q in qs if q][:k]
# ========= Load index/data =========
index, TEXTS, SOURCES, META = load_index_and_data()
# ========= Sidebar =========
with st.sidebar:
    st.title("⚙️ Settings")
    st.write("**Retrieval**")
    TOP_K = st.slider("Top-K passages", 3, 10, TOP_K)
    st.divider()
    st.write("**Models**")
    st.write(f"Embedding: `{EMBED_MODEL_NAME}`")
    st.write(
        "LLM:",
        "OpenAI" if USE_OPENAI else "Gemini",
        f"({OPENAI_MODEL if USE_OPENAI else GEMINI_MODEL})",
    )
    st.caption("Switch with env vars: USE_OPENAI, OPENAI_API_KEY, GOOGLE_API_KEY.")
    st.divider()
    st.write("**Files**")
    st.write(f"Index: `{VECTOR_PATH}`")
    st.write(f"Data : `{PKL_PATH}`")
# ========= Main Layout =========
st.title("Halassa Lab Onboarder 🧠📄")
st.caption("Ask questions; see exactly which passages were used.")
if "chat" not in st.session_state:
st.session_state.chat = [] # list[dict]: {"role": "user"/"assistant", "content": str, "retrieved": list}
if "last_suggestions" not in st.session_state:
st.session_state.last_suggestions = []
# Input row
with st.container():
    cols = st.columns([6, 1])
    with cols[0]:
        user_message = st.text_input(
            "Ask your question",
            "",
            placeholder="e.g., How does MD dopamine shape error-driven flexibility?",
        )
    with cols[1]:
        ask = st.button("Send", use_container_width=True)
def answer_query(query: str):
    retrieved = retrieve(query, index, TEXTS, SOURCES, META, k=TOP_K)
    context_str = build_context(retrieved)
    sys_prompt = (
        "You are an expert scientist in the Halassa Lab at MIT, specializing in computational neuroscience. "
        "Answer thoroughly and clearly. Synthesize from the provided context; write in your own words. "
        "If you cite directly from a provided paper, add citations at the end as [filename - Page X]. "
        "If the context is partial, add helpful background."
    )
    user_prompt = f"""Context:
---
{context_str}
---
User Question: {query}
Expert Answer:
"""
    answer = call_llm(sys_prompt, user_prompt)
    st.session_state.chat.append({"role": "user", "content": query})
    st.session_state.chat.append({"role": "assistant", "content": answer, "retrieved": retrieved})
    try:
        st.session_state.last_suggestions = suggest_questions(answer, k=SUGGESTED_Q)
    except Exception:
        st.session_state.last_suggestions = []
# Trigger on click or Enter
if ask and user_message.strip():
    answer_query(user_message.strip())
elif user_message.strip() and st.session_state.chat == []:
    # Allow pressing Enter to submit the first question.
    answer_query(user_message.strip())
# Two-column layout
col_chat, col_docs = st.columns([2, 1], gap="large")
# Left: Chat
with col_chat:
    for turn in st.session_state.chat:
        if turn["role"] == "user":
            st.chat_message("user").markdown(turn["content"])
        else:
            st.chat_message("assistant").markdown(turn["content"])
    if st.session_state.last_suggestions:
        st.subheader("Try next:")
        sug_cols = st.columns(len(st.session_state.last_suggestions))
        for i, q in enumerate(st.session_state.last_suggestions):
            if sug_cols[i].button(q):
                answer_query(q)
# Right: Relevant chunks (no PDF viewer)
with col_docs:
    st.subheader("Relevant Sources")
    last_assistant = None
    for t in reversed(st.session_state.chat):
        if t.get("role") == "assistant" and "retrieved" in t:
            last_assistant = t
            break
    if not last_assistant:
        st.info("Ask a question to see relevant passages.")
    else:
        # Find the preceding user query for highlighting
        query_text = ""
        for i in range(len(st.session_state.chat) - 1, -1, -1):
            if st.session_state.chat[i]["role"] == "user":
                query_text = st.session_state.chat[i]["content"]
                break
        for r in last_assistant["retrieved"]:
            src = r["source"]
            with st.expander(f"#{r['rank']} {src}"):
                html = highlight_terms(r["text"], query_text)
                st.markdown(html, unsafe_allow_html=True)
                # Small utility buttons
                st.download_button(
                    "Download chunk",
                    data=r["text"].encode("utf-8"),
                    file_name=f"chunk_{r['rank']}.txt",
                    use_container_width=True,
                )
st.divider()
st.caption("Tip: Ensure your `sources` strings match your citation format (e.g., `paper.pdf - Page 12`) so your LLM’s citations are clean.")