Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import pickle | |
| from pathlib import Path | |
| from typing import List, Dict, Any | |
| import streamlit as st | |
| import numpy as np | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
# ========= LLM backend config =========
# The backend is chosen once, at import time, from environment variables.
USE_OPENAI = os.getenv("USE_OPENAI", "0") == "1"
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

if USE_OPENAI:
    from openai import OpenAI

    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
    # Without a key, keep the name bound to None so call_llm can test it safely.
    openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
else:
    import google.generativeai as genai

    GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
    if GOOGLE_API_KEY:
        genai.configure(api_key=GOOGLE_API_KEY)
# ========= Page config =========
# NOTE: st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Halassa Lab Literature Chatbot",
    page_icon="🧠",
    layout="wide",
)
# ========= Paths & retrieval/config constants =========
# (pathlib.Path is already imported at the top of the file; the duplicate
# mid-file import was removed.)
BASE_DIR = Path(__file__).resolve().parent  # points to src/
DATA_DIR = BASE_DIR / "data"
VECTOR_PATH = DATA_DIR / "vector_store.index"  # FAISS index built offline
PKL_PATH = DATA_DIR / "data.pkl"               # pickled {"texts", "sources", "meta"}
# Query-embedding model; presumably must match the model used to build the
# index -- confirm against the indexing script.
EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "BAAI/bge-large-en-v1.5")
TOP_K = int(os.getenv("TOP_K", "5"))                              # passages per query
MAX_CONTEXT_CHARS = int(os.getenv("MAX_CONTEXT_CHARS", "12000"))  # LLM context budget
SUGGESTED_Q = int(os.getenv("SUGGESTED_Q", "4"))                  # follow-up questions
# ========= Helpers =========
def load_index_and_data():
    """Load the FAISS index plus the pickled chunk store from DATA_DIR.

    Returns a tuple (index, texts, sources, meta).  If either file is missing,
    or the pickle lacks matching 'texts'/'sources' lists, an error is shown
    and the Streamlit script is stopped.
    """
    if not (VECTOR_PATH.exists() and PKL_PATH.exists()):
        st.error(f"Missing index or data:\n- {VECTOR_PATH}\n- {PKL_PATH}")
        st.stop()
    index = faiss.read_index(str(VECTOR_PATH))
    with PKL_PATH.open("rb") as fh:
        # Trusted local file only -- pickle is unsafe on untrusted data.
        stored = pickle.load(fh)
    texts = stored.get("texts", [])
    sources = stored.get("sources", [])
    meta = stored.get("meta", [None] * len(texts))
    if len(texts) == 0 or len(texts) != len(sources):
        st.error("data.pkl must contain 'texts' and 'sources' of equal length.")
        st.stop()
    return index, texts, sources, meta
@st.cache_resource
def get_embedder():
    """Return the SentenceTransformer used to embed queries.

    Cached with st.cache_resource so the (large) model is loaded once per
    process instead of on every retrieve() call and every Streamlit rerun,
    which is what the undecorated version did.
    """
    return SentenceTransformer(EMBED_MODEL_NAME)
def encode_query(query: str, embedder) -> np.ndarray:
    """Embed a single query string and return it as a float32 array.

    The embedder is called with a one-element batch, so the result is a
    single-row matrix suitable for index.search().
    """
    return np.asarray(embedder.encode([query]), dtype=np.float32)
def retrieve(query: str, index, texts, sources, meta, k=TOP_K):
    """Return the top-k chunks for `query` as dicts with rank/text/source/meta.

    Ranks follow the position in the FAISS result list; ids that fall outside
    the corpus (FAISS pads with -1 when fewer than k hits exist) are skipped.
    """
    qvec = encode_query(query, get_embedder())
    _, ids = index.search(qvec, k)
    hits = []
    for pos, chunk_id in enumerate(ids[0].tolist(), start=1):
        if not (0 <= chunk_id < len(texts)):
            continue
        hits.append({
            "rank": pos,
            "text": texts[chunk_id],
            "source": sources[chunk_id],
            "meta": meta[chunk_id] if meta and chunk_id < len(meta) else None,
        })
    return hits
def build_context(retrieved: List[Dict[str, Any]], max_chars=None) -> str:
    """Concatenate retrieved chunks into one context string for the LLM.

    Each chunk is rendered as "Source: ...\\nContent: ...\\n" and chunks are
    joined with "\\n---\\n".  Accumulation stops at the first chunk that would
    push the running total past the character budget.

    Args:
        retrieved: list of dicts with at least "source" and "text" keys.
        max_chars: optional character budget; defaults to the module-level
            MAX_CONTEXT_CHARS when None (backward-compatible generalization
            of the previously hard-coded constant).
    """
    if max_chars is None:
        max_chars = MAX_CONTEXT_CHARS
    parts: List[str] = []
    total = 0
    for r in retrieved:
        chunk = f"Source: {r['source']}\nContent: {r['text'].strip()}\n"
        if total + len(chunk) > max_chars:
            break  # budget exhausted; drop this and all later (lower-ranked) chunks
        parts.append(chunk)
        total += len(chunk)
    return "\n---\n".join(parts)
def call_llm(system_prompt: str, user_prompt: str) -> str:
    """Send a system+user prompt to the configured LLM backend; return its text.

    Backend selection depends on module-level state set at import time:
    - OpenAI path: USE_OPENAI is set, a key is in the environment, and
      `openai_client` was successfully constructed.
    - Gemini path: USE_OPENAI is unset and GOOGLE_API_KEY is present.
    - Otherwise a stub string is returned so the UI works without any key.

    NOTE: `openai_client` is only defined when USE_OPENAI was true at import;
    the short-circuit `USE_OPENAI and ... and openai_client` is what keeps
    this function from touching the undefined name on the Gemini path.
    """
    # OpenAI path
    if USE_OPENAI and os.getenv("OPENAI_API_KEY") and openai_client:
        resp = openai_client.chat.completions.create(
            model=OPENAI_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=0.2,  # low temperature: favor reproducible, grounded answers
        )
        return resp.choices[0].message.content
    # Gemini path
    if not USE_OPENAI and os.getenv("GOOGLE_API_KEY"):
        model = genai.GenerativeModel(GEMINI_MODEL)
        # Gemini takes a single prompt string here, so the system prompt is folded in.
        resp = model.generate_content(system_prompt + "\n\n" + user_prompt)
        return resp.text
    # Fallback (no key) for UI testing
    return "(LLM disabled) " + user_prompt[:800]
def highlight_terms(text: str, query: str) -> str:
    """Wrap query terms (3+ chars) found in `text` with <mark> tags.

    All terms are matched in a single case-insensitive pass using one combined
    alternation pattern.  The previous per-term loop re-ran re.sub over its own
    output, so a term such as "mark" could match inside an already-inserted
    "<mark>" tag and corrupt the HTML; iteration order over the term set was
    also nondeterministic.  One pass over the original text fixes both.
    """
    terms = {t for t in re.split(r"\W+", query) if len(t) >= 3}
    if not terms:
        return text
    # Longest terms first so overlapping terms prefer the longest match;
    # the secondary alphabetical key makes the pattern deterministic.
    ordered = sorted(terms, key=lambda t: (-len(t), t))
    pattern = "|".join(re.escape(t) for t in ordered)
    return re.sub(f"({pattern})", r"<mark>\1</mark>", text, flags=re.IGNORECASE)
def suggest_questions(last_answer: str, k=SUGGESTED_Q) -> List[str]:
    """Ask the LLM for up to `k` short follow-up questions to `last_answer`."""
    prompt = f"""Generate {k} concise follow-up questions a user might ask next, given the expert answer below.
Each question should be short (max ~12 words) and deepen the discussion.
Answer only with a bulletless list, one question per line.
Expert answer:
{last_answer}
"""
    raw = call_llm(
        system_prompt="You are a helpful assistant that proposes follow-up questions.",
        user_prompt=prompt,
    )
    questions = []
    for line in raw.splitlines():
        if not line.strip():
            continue
        # Strip any leading bullet/number decoration the model added anyway.
        cleaned = re.sub(r"^[\-\*\d\.\)\s]+", "", line).strip()
        if cleaned:
            questions.append(cleaned)
    return questions[:k]
# ========= Load index/data =========
# Runs at import time on every Streamlit rerun; load_index_and_data() stops
# the app with an error message if the index or pickle files are absent.
index, TEXTS, SOURCES, META = load_index_and_data()
# ========= Sidebar =========
with st.sidebar:
    st.title("⚙️ Settings")
    st.write("**Retrieval**")
    # Rebinds the module-level TOP_K so the retrieve() call in answer_query
    # picks up the slider value on this rerun.
    TOP_K = st.slider("Top-K passages", 3, 10, TOP_K)
    st.divider()
    st.write("**Models**")
    st.write(f"Embedding: `{EMBED_MODEL_NAME}`")
    st.write("LLM:", "OpenAI" if USE_OPENAI else "Gemini",
             f"({OPENAI_MODEL if USE_OPENAI else GEMINI_MODEL})")
    st.caption("Switch with env vars: USE_OPENAI, OPENAI_API_KEY, GOOGLE_API_KEY.")
    st.divider()
    st.write("**Files**")
    st.write(f"Index: `{VECTOR_PATH}`")
    st.write(f"Data : `{PKL_PATH}`")
# ========= Main Layout =========
st.title("Halassa Lab Onboarder 🧠📄")
st.caption("Ask questions; see exactly which passages were used.")
# Conversation history must survive Streamlit reruns, hence session_state.
if "chat" not in st.session_state:
    st.session_state.chat = []  # list[dict]: {"role": "user"/"assistant", "content": str, "retrieved": list}
if "last_suggestions" not in st.session_state:
    st.session_state.last_suggestions = []
# Input row
with st.container():
    cols = st.columns([6, 1])  # wide question box + narrow send button
    with cols[0]:
        user_message = st.text_input(
            "Ask your question",
            "",
            placeholder="e.g., How does MD dopamine shape error-driven flexibility?",
        )
    with cols[1]:
        ask = st.button("Send", use_container_width=True)
def answer_query(query: str):
    """Run one full RAG turn for `query` and record it in session state.

    Retrieves TOP_K chunks, builds a bounded context string, prompts the LLM,
    then appends the user turn and the assistant turn (carrying the retrieved
    chunks so the right-hand panel can show sources) to st.session_state.chat.
    Finally refreshes the follow-up suggestions; suggestion failures are
    swallowed so the best-effort feature never breaks the main answer flow.
    """
    retrieved = retrieve(query, index, TEXTS, SOURCES, META, k=TOP_K)
    context_str = build_context(retrieved)
    sys_prompt = (
        "You are an Expert scientist in the Halassa Lab at MIT, expert in computational neuroscience. "
        "Answer thoroughly and clearly. Synthesize from provided context; write in your own words. "
        "If you cite directly from a provided paper, add citations at the end as [filename - Page X]. "
        "If context is partial, add helpful background."
    )
    user_prompt = f"""Context:
---
{context_str}
---
User Question: {query}
Expert Answer:
"""
    answer = call_llm(sys_prompt, user_prompt)
    # Append user first so the transcript renders in conversational order.
    st.session_state.chat.append({"role": "user", "content": query})
    st.session_state.chat.append({"role": "assistant", "content": answer, "retrieved": retrieved})
    try:
        st.session_state.last_suggestions = suggest_questions(answer, k=SUGGESTED_Q)
    except Exception:
        st.session_state.last_suggestions = []
# Trigger on click or Enter
if ask and user_message.strip():
    answer_query(user_message.strip())
elif user_message.strip() and st.session_state.chat == []:
    # allow pressing Enter to submit first question
    # NOTE(review): the chat==[] guard means Enter only submits when the chat
    # is still empty; later questions need the Send button.  This prevents the
    # retained text_input value from re-submitting on every rerun.
    answer_query(user_message.strip())
# Two-column layout
col_chat, col_docs = st.columns([2, 1], gap="large")

# Left: Chat
with col_chat:
    # Render the full transcript; any role other than "user" renders as assistant,
    # matching the original if/else.
    for turn in st.session_state.chat:
        role = "user" if turn["role"] == "user" else "assistant"
        st.chat_message(role).markdown(turn["content"])
    if st.session_state.last_suggestions:
        st.subheader("Try next:")
        sug_cols = st.columns(len(st.session_state.last_suggestions))
        for i, q in enumerate(st.session_state.last_suggestions):
            # Explicit keys: two suggestions with identical text would otherwise
            # raise Streamlit's duplicate-widget error.
            if sug_cols[i].button(q, key=f"suggestion_{i}"):
                answer_query(q)
                # The transcript above was rendered before this click handler
                # appended the new turn; rerun so it shows up immediately
                # instead of on the next interaction.
                st.rerun()
# Right: Relevant chunks (no PDF viewer)
with col_docs:
    st.subheader("Relevant Sources")
    # Most recent assistant turn that carries retrieval results.
    last_assistant = next(
        (t for t in reversed(st.session_state.chat)
         if t.get("role") == "assistant" and "retrieved" in t),
        None,
    )
    if not last_assistant:
        st.info("Ask a question to see relevant passages.")
    else:
        # Find preceding user query for highlighting (most recent user turn).
        query_text = next(
            (t["content"] for t in reversed(st.session_state.chat)
             if t["role"] == "user"),
            "",
        )
        for r in last_assistant["retrieved"]:
            with st.expander(f"#{r['rank']} {r['source']}"):
                st.markdown(highlight_terms(r["text"], query_text),
                            unsafe_allow_html=True)
                # Small utility buttons
                st.download_button(
                    "Download chunk",
                    data=r["text"].encode("utf-8"),
                    file_name=f"chunk_{r['rank']}.txt",
                    use_container_width=True,
                )
            st.divider()
st.caption("Tip: Ensure your `sources` strings match your citation format (e.g., `paper.pdf - Page 12`) so your LLM’s citations are clean.")