import os
import re
import pickle
from pathlib import Path
from typing import List, Dict, Any

import streamlit as st
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

# ========= LLM backend config =========
USE_OPENAI = os.getenv("USE_OPENAI", "0") == "1"
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

if USE_OPENAI:
    from openai import OpenAI

    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
    openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
else:
    import google.generativeai as genai

    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
    if GOOGLE_API_KEY:
        genai.configure(api_key=GOOGLE_API_KEY)

# ========= Page config =========
st.set_page_config(
    page_title="Halassa Lab Literature Chatbot",
    page_icon="🧠",
    layout="wide",
)

BASE_DIR = Path(__file__).resolve().parent  # points to src/
DATA_DIR = BASE_DIR / "data"
VECTOR_PATH = DATA_DIR / "vector_store.index"
PKL_PATH = DATA_DIR / "data.pkl"

EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "BAAI/bge-large-en-v1.5")
TOP_K = int(os.getenv("TOP_K", "5"))
MAX_CONTEXT_CHARS = int(os.getenv("MAX_CONTEXT_CHARS", "12000"))
SUGGESTED_Q = int(os.getenv("SUGGESTED_Q", "4"))


# ========= Helpers =========
@st.cache_resource(show_spinner=False)  # avoid re-reading index/pickle on every rerun
def load_index_and_data():
    if not VECTOR_PATH.exists() or not PKL_PATH.exists():
        st.error(f"Missing index or data:\n- {VECTOR_PATH}\n- {PKL_PATH}")
        st.stop()
    index = faiss.read_index(str(VECTOR_PATH))
    with open(PKL_PATH, "rb") as f:
        stored = pickle.load(f)
    texts = stored.get("texts", [])
    sources = stored.get("sources", [])
    meta = stored.get("meta", [None] * len(texts))
    if len(texts) == 0 or len(texts) != len(sources):
        st.error("data.pkl must contain 'texts' and 'sources' of equal length.")
        st.stop()
    return index, texts, sources, meta


@st.cache_resource(show_spinner=False)
def get_embedder():
    return SentenceTransformer(EMBED_MODEL_NAME)


def encode_query(query: str, embedder) -> np.ndarray:
    # Note: whether to normalize embeddings here must match how the stored
    # index was built (e.g., normalized vectors with an inner-product index).
    vec = embedder.encode([query])
    return vec.astype(np.float32)


def retrieve(query: str, index, texts, sources, meta, k=TOP_K):
    embedder = get_embedder()
    qvec = encode_query(query, embedder)
    D, I = index.search(qvec, k)  # distances, indices
    results = []
    for rank, idx in enumerate(I[0].tolist()):
        if 0 <= idx < len(texts):
            results.append({
                "rank": rank + 1,
                "text": texts[idx],
                "source": sources[idx],
                "meta": meta[idx] if meta and idx < len(meta) else None,
            })
    return results


def build_context(retrieved: List[Dict[str, Any]]) -> str:
    parts, total = [], 0
    for r in retrieved:
        src = r["source"]
        txt = r["text"].strip()
        chunk = f"Source: {src}\nContent: {txt}\n"
        if total + len(chunk) > MAX_CONTEXT_CHARS:
            break
        parts.append(chunk)
        total += len(chunk)
    return "\n---\n".join(parts)


def call_llm(system_prompt: str, user_prompt: str) -> str:
    # OpenAI path
    if USE_OPENAI and openai_client:
        resp = openai_client.chat.completions.create(
            model=OPENAI_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.2,
        )
        return resp.choices[0].message.content
    # Gemini path
    if not USE_OPENAI and os.getenv("GOOGLE_API_KEY"):
        model = genai.GenerativeModel(GEMINI_MODEL)
        resp = model.generate_content(system_prompt + "\n\n" + user_prompt)
        return resp.text
    # Fallback (no key) for UI testing
    return "(LLM disabled) " + user_prompt[:800]
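
# `load_index_and_data` expects `data/vector_store.index` plus a `data.pkl`
# holding parallel "texts"/"sources" lists (and optional "meta"). The sketch
# below shows one hypothetical offline build step that would produce
# compatible files; the chunk schema, index type, and chunking strategy are
# assumptions, not the actual pipeline, and the function is never called here.
def _build_index_sketch(chunks: List[Dict[str, Any]]) -> None:
    """Hypothetical offline step: embed chunks, write the two files the app reads.

    `chunks` is assumed to look like:
        [{"text": "...", "source": "paper.pdf - Page 12", "meta": {...}}, ...]
    """
    embedder = SentenceTransformer(EMBED_MODEL_NAME)
    texts = [c["text"] for c in chunks]
    vecs = embedder.encode(texts).astype(np.float32)
    idx = faiss.IndexFlatL2(vecs.shape[1])  # assumed index type; must match query-time search
    idx.add(vecs)
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    faiss.write_index(idx, str(VECTOR_PATH))
    with open(PKL_PATH, "wb") as f:
        pickle.dump({
            "texts": texts,
            "sources": [c["source"] for c in chunks],
            "meta": [c.get("meta") for c in chunks],
        }, f)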

def highlight_terms(text: str, query: str) -> str:
    # Lightweight term highlighter: wrap each query term (3+ chars) in <mark>
    # tags so it stands out when rendered with unsafe_allow_html.
    terms = [t for t in re.split(r"\W+", query) if len(t) >= 3]
    out = text
    for t in set(terms):
        out = re.sub(rf"({re.escape(t)})", r"<mark>\1</mark>", out, flags=re.IGNORECASE)
    return out


def suggest_questions(last_answer: str, k=SUGGESTED_Q) -> List[str]:
    prompt = f"""Generate {k} concise follow-up questions a user might ask next,
given the expert answer below. Each question should be short (max ~12 words)
and deepen the discussion. Answer only with a plain list, one question per
line, with no bullets or numbering.

Expert answer:
{last_answer}
"""
    out = call_llm(
        system_prompt="You are a helpful assistant that proposes follow-up questions.",
        user_prompt=prompt,
    )
    # Strip any bullets/numbering the model adds anyway.
    qs = [re.sub(r"^[\-\*\d\.\)\s]+", "", q).strip() for q in out.splitlines() if q.strip()]
    return [q for q in qs if q][:k]


# ========= Load index/data =========
index, TEXTS, SOURCES, META = load_index_and_data()

# ========= Sidebar =========
with st.sidebar:
    st.title("⚙️ Settings")
    st.write("**Retrieval**")
    TOP_K = st.slider("Top-K passages", 3, 10, TOP_K)
    st.divider()
    st.write("**Models**")
    st.write(f"Embedding: `{EMBED_MODEL_NAME}`")
    st.write(
        "LLM:",
        "OpenAI" if USE_OPENAI else "Gemini",
        f"({OPENAI_MODEL if USE_OPENAI else GEMINI_MODEL})",
    )
    st.caption("Switch with env vars: USE_OPENAI, OPENAI_API_KEY, GOOGLE_API_KEY.")
    st.divider()
    st.write("**Files**")
    st.write(f"Index: `{VECTOR_PATH}`")
    st.write(f"Data : `{PKL_PATH}`")

# ========= Main Layout =========
st.title("Halassa Lab Onboarder 🧠📄")
st.caption("Ask questions; see exactly which passages were used.")

if "chat" not in st.session_state:
    # list[dict]: {"role": "user"/"assistant", "content": str, "retrieved": list}
    st.session_state.chat = []
if "last_suggestions" not in st.session_state:
    st.session_state.last_suggestions = []

# Input row
with st.container():
    cols = st.columns([6, 1])
    with cols[0]:
        user_message = st.text_input(
            "Ask your question",
            "",
            placeholder="e.g., How does MD dopamine shape error-driven flexibility?",
        )
    with cols[1]:
        ask = st.button("Send", use_container_width=True)


def answer_query(query: str):
    retrieved = retrieve(query, index, TEXTS, SOURCES, META, k=TOP_K)
    context_str = build_context(retrieved)
    sys_prompt = (
        "You are an expert scientist in the Halassa Lab at MIT, specializing in "
        "computational neuroscience. Answer thoroughly and clearly. Synthesize "
        "from the provided context; write in your own words. If you cite "
        "directly from a provided paper, add citations at the end as "
        "[filename - Page X]. If context is partial, add helpful background."
    )
    user_prompt = f"""Context:
---
{context_str}
---

User Question: {query}

Expert Answer:
"""
    answer = call_llm(sys_prompt, user_prompt)
    st.session_state.chat.append({"role": "user", "content": query})
    st.session_state.chat.append({"role": "assistant", "content": answer, "retrieved": retrieved})
    try:
        st.session_state.last_suggestions = suggest_questions(answer, k=SUGGESTED_Q)
    except Exception:
        st.session_state.last_suggestions = []


# Trigger on click, or on Enter for the first question
if ask and user_message.strip():
    answer_query(user_message.strip())
elif user_message.strip() and st.session_state.chat == []:
    # Pressing Enter reruns the script without setting `ask`; let it submit
    # the first question while the chat is still empty.
    answer_query(user_message.strip())

# Two-column layout
col_chat, col_docs = st.columns([2, 1], gap="large")

# Left: chat history
with col_chat:
    for turn in st.session_state.chat:
        if turn["role"] == "user":
            st.chat_message("user").markdown(turn["content"])
        else:
            st.chat_message("assistant").markdown(turn["content"])

    if st.session_state.last_suggestions:
        st.subheader("Try next:")
        sug_cols = st.columns(len(st.session_state.last_suggestions))
        for i, q in enumerate(st.session_state.last_suggestions):
            if sug_cols[i].button(q):
                answer_query(q)

# Right: relevant chunks (no PDF viewer)
with col_docs:
    st.subheader("Relevant Sources")
    last_assistant = None
    for t in reversed(st.session_state.chat):
        if t.get("role") == "assistant" and "retrieved" in t:
            last_assistant = t
            break
    if not last_assistant:
        st.info("Ask a question to see relevant passages.")
    else:
        # Find the preceding user query for highlighting
        query_text = ""
        for i in range(len(st.session_state.chat) - 1, -1, -1):
            if st.session_state.chat[i]["role"] == "user":
                query_text = st.session_state.chat[i]["content"]
                break
        for r in last_assistant["retrieved"]:
            src = r["source"]
            with st.expander(f"#{r['rank']} {src}"):
                html = highlight_terms(r["text"], query_text)
                st.markdown(html, unsafe_allow_html=True)
                # Small utility button
                st.download_button(
                    "Download chunk",
                    data=r["text"].encode("utf-8"),
                    file_name=f"chunk_{r['rank']}.txt",
                    use_container_width=True,
                )

st.divider()
st.caption(
    "Tip: ensure your `sources` strings match the citation format "
    "(e.g., `paper.pdf - Page 12`) so the LLM's citations stay clean."
)
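
# To run (assuming this file lives at src/app.py; adjust the path to your checkout):
#   GOOGLE_API_KEY=... streamlit run src/app.py                  # Gemini backend (default)
#   USE_OPENAI=1 OPENAI_API_KEY=... streamlit run src/app.py     # OpenAI backend
# With neither key set, the app still renders and answers with the
# "(LLM disabled)" fallback, which is handy for UI testing.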