sahilursa commited on
Commit
1075316
·
verified ·
1 Parent(s): 7b1b136

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +270 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,277 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
- # Welcome to Streamlit!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ import re
3
+ import pickle
4
+ from pathlib import Path
5
+ from typing import List, Dict, Any
6
+
7
  import streamlit as st
8
+ import numpy as np
9
+ import faiss
10
+ from sentence_transformers import SentenceTransformer
11
+
12
+ # ========= LLM backend config =========
13
+ USE_OPENAI = os.getenv("USE_OPENAI", "0") == "1"
14
+ GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
15
+ OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
16
+
17
+ if USE_OPENAI:
18
+ from openai import OpenAI
19
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
20
+ if OPENAI_API_KEY:
21
+ openai_client = OpenAI(api_key=OPENAI_API_KEY)
22
+ else:
23
+ openai_client = None
24
+ else:
25
+ import google.generativeai as genai
26
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
27
+ if GOOGLE_API_KEY:
28
+ genai.configure(api_key=GOOGLE_API_KEY)
29
+
30
+ # ========= Page config =========
31
+ st.set_page_config(
32
+ page_title="Halassa Lab Literature Chatbot",
33
+ page_icon="🧠",
34
+ layout="wide",
35
+ )
36
+
37
+ # ========= Paths & knobs =========
38
+ DATA_DIR = Path(os.getenv("DATA_DIR", "data"))
39
+ VECTOR_PATH = DATA_DIR / "vector_store.index"
40
+ PKL_PATH = DATA_DIR / "data.pkl"
41
+
42
+ EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "BAAI/bge-large-en-v1.5")
43
+ TOP_K = int(os.getenv("TOP_K", "5"))
44
+ MAX_CONTEXT_CHARS = int(os.getenv("MAX_CONTEXT_CHARS", "12000"))
45
+ SUGGESTED_Q = int(os.getenv("SUGGESTED_Q", "4"))
46
+
47
+ # ========= Helpers =========
48
+ def load_index_and_data():
49
+ if not VECTOR_PATH.exists() or not PKL_PATH.exists():
50
+ st.error(f"Missing index or data:\n- {VECTOR_PATH}\n- {PKL_PATH}")
51
+ st.stop()
52
+
53
+ index = faiss.read_index(str(VECTOR_PATH))
54
+ with open(PKL_PATH, "rb") as f:
55
+ stored = pickle.load(f)
56
+
57
+ texts = stored.get("texts", [])
58
+ sources = stored.get("sources", [])
59
+ meta = stored.get("meta", [None] * len(texts))
60
+
61
+ if len(texts) == 0 or len(texts) != len(sources):
62
+ st.error("data.pkl must contain 'texts' and 'sources' of equal length.")
63
+ st.stop()
64
+
65
+ return index, texts, sources, meta
66
+
67
+ @st.cache_resource(show_spinner=False)
68
+ def get_embedder():
69
+ return SentenceTransformer(EMBED_MODEL_NAME)
70
+
71
+ def encode_query(query: str, embedder) -> np.ndarray:
72
+ vec = embedder.encode([query])
73
+ return vec.astype(np.float32)
74
+
75
+ def retrieve(query: str, index, texts, sources, meta, k=TOP_K):
76
+ embedder = get_embedder()
77
+ qvec = encode_query(query, embedder)
78
+ D, I = index.search(qvec, k)
79
+ results = []
80
+ for rank, idx in enumerate(I[0].tolist()):
81
+ if 0 <= idx < len(texts):
82
+ results.append({
83
+ "rank": rank + 1,
84
+ "text": texts[idx],
85
+ "source": sources[idx],
86
+ "meta": meta[idx] if meta and idx < len(meta) else None
87
+ })
88
+ return results
89
+
90
+ def build_context(retrieved: List[Dict[str, Any]]) -> str:
91
+ parts, total = [], 0
92
+ for r in retrieved:
93
+ src = r["source"]
94
+ txt = r["text"].strip()
95
+ chunk = f"Source: {src}\nContent: {txt}\n"
96
+ if total + len(chunk) > MAX_CONTEXT_CHARS:
97
+ break
98
+ parts.append(chunk)
99
+ total += len(chunk)
100
+ return "\n---\n".join(parts)
101
+
102
+ def call_llm(system_prompt: str, user_prompt: str) -> str:
103
+ # OpenAI path
104
+ if USE_OPENAI and os.getenv("OPENAI_API_KEY") and openai_client:
105
+ resp = openai_client.chat.completions.create(
106
+ model=OPENAI_MODEL,
107
+ messages=[
108
+ {"role": "system", "content": system_prompt},
109
+ {"role": "user", "content": user_prompt},
110
+ ],
111
+ temperature=0.2,
112
+ )
113
+ return resp.choices[0].message.content
114
+
115
+ # Gemini path
116
+ if not USE_OPENAI and os.getenv("GOOGLE_API_KEY"):
117
+ model = genai.GenerativeModel(GEMINI_MODEL)
118
+ resp = model.generate_content(system_prompt + "\n\n" + user_prompt)
119
+ return resp.text
120
+
121
+ # Fallback (no key) for UI testing
122
+ return "(LLM disabled) " + user_prompt[:800]
123
 
124
+ def highlight_terms(text: str, query: str) -> str:
125
+ # lightweight term highlighter
126
+ import re
127
+ terms = [t for t in re.split(r"\W+", query) if len(t) >= 3]
128
+ out = text
129
+ for t in set(terms):
130
+ out = re.sub(rf"({re.escape(t)})", r"<mark>\1</mark>", out, flags=re.IGNORECASE)
131
+ return out
132
+
133
+ def suggest_questions(last_answer: str, k=SUGGESTED_Q) -> List[str]:
134
+ prompt = f"""Generate {k} concise follow-up questions a user might ask next, given the expert answer below.
135
+ Each question should be short (max ~12 words) and deepen the discussion.
136
+
137
+ Answer only with a bulletless list, one question per line.
138
+
139
+ Expert answer:
140
+ {last_answer}
141
  """
142
+ out = call_llm(
143
+ system_prompt="You are a helpful assistant that proposes follow-up questions.",
144
+ user_prompt=prompt,
145
+ )
146
+ qs = [re.sub(r"^[\-\*\d\.\)\s]+", "", q).strip() for q in out.splitlines() if q.strip()]
147
+ return [q for q in qs if q][:k]
148
+
149
+ # ========= Load index/data =========
150
+ index, TEXTS, SOURCES, META = load_index_and_data()
151
+
152
+ # ========= Sidebar =========
153
+ with st.sidebar:
154
+ st.title("⚙️ Settings")
155
+ st.write("**Retrieval**")
156
+ TOP_K = st.slider("Top-K passages", 3, 10, TOP_K)
157
+ st.divider()
158
+ st.write("**Models**")
159
+ st.write(f"Embedding: `{EMBED_MODEL_NAME}`")
160
+ st.write("LLM:", "OpenAI" if USE_OPENAI else "Gemini",
161
+ f"({OPENAI_MODEL if USE_OPENAI else GEMINI_MODEL})")
162
+ st.caption("Switch with env vars: USE_OPENAI, OPENAI_API_KEY, GOOGLE_API_KEY.")
163
+ st.divider()
164
+ st.write("**Files**")
165
+ st.write(f"Index: `{VECTOR_PATH}`")
166
+ st.write(f"Data : `{PKL_PATH}`")
167
+
168
+ # ========= Main Layout =========
169
+ st.title("Halassa Lab Onboarder 🧠📄")
170
+ st.caption("Ask questions; see exactly which passages were used.")
171
+
172
+ if "chat" not in st.session_state:
173
+ st.session_state.chat = [] # list[dict]: {"role": "user"/"assistant", "content": str, "retrieved": list}
174
+ if "last_suggestions" not in st.session_state:
175
+ st.session_state.last_suggestions = []
176
+
177
+ # Input row
178
+ with st.container():
179
+ cols = st.columns([6, 1])
180
+ with cols[0]:
181
+ user_message = st.text_input(
182
+ "Ask your question",
183
+ "",
184
+ placeholder="e.g., How does MD dopamine shape error-driven flexibility?",
185
+ )
186
+ with cols[1]:
187
+ ask = st.button("Send", use_container_width=True)
188
+
189
+ def answer_query(query: str):
190
+ retrieved = retrieve(query, index, TEXTS, SOURCES, META, k=TOP_K)
191
+ context_str = build_context(retrieved)
192
+
193
+ sys_prompt = (
194
+ "You are an Expert scientist in the Halassa Lab at MIT, expert in computational neuroscience. "
195
+ "Answer thoroughly and clearly. Synthesize from provided context; write in your own words. "
196
+ "If you cite directly from a provided paper, add citations at the end as [filename - Page X]. "
197
+ "If context is partial, add helpful background."
198
+ )
199
+ user_prompt = f"""Context:
200
+ ---
201
+ {context_str}
202
+ ---
203
 
204
+ User Question: {query}
 
 
205
 
206
+ Expert Answer:
207
  """
208
+ answer = call_llm(sys_prompt, user_prompt)
209
+
210
+ st.session_state.chat.append({"role": "user", "content": query})
211
+ st.session_state.chat.append({"role": "assistant", "content": answer, "retrieved": retrieved})
212
+
213
+ try:
214
+ st.session_state.last_suggestions = suggest_questions(answer, k=SUGGESTED_Q)
215
+ except Exception:
216
+ st.session_state.last_suggestions = []
217
+
218
+ # Trigger on click or Enter
219
+ if ask and user_message.strip():
220
+ answer_query(user_message.strip())
221
+ elif user_message.strip() and st.session_state.chat == []:
222
+ # allow pressing Enter to submit first question
223
+ answer_query(user_message.strip())
224
+
225
+ # Two-column layout
226
+ col_chat, col_docs = st.columns([2, 1], gap="large")
227
+
228
+ # Left: Chat
229
+ with col_chat:
230
+ for turn in st.session_state.chat:
231
+ if turn["role"] == "user":
232
+ st.chat_message("user").markdown(turn["content"])
233
+ else:
234
+ st.chat_message("assistant").markdown(turn["content"])
235
+
236
+ if st.session_state.last_suggestions:
237
+ st.subheader("Try next:")
238
+ sug_cols = st.columns(len(st.session_state.last_suggestions))
239
+ for i, q in enumerate(st.session_state.last_suggestions):
240
+ if sug_cols[i].button(q):
241
+ answer_query(q)
242
+
243
+ # Right: Relevant chunks (no PDF viewer)
244
+ with col_docs:
245
+ st.subheader("Relevant Sources")
246
+ last_assistant = None
247
+ for t in reversed(st.session_state.chat):
248
+ if t.get("role") == "assistant" and "retrieved" in t:
249
+ last_assistant = t
250
+ break
251
+
252
+ if not last_assistant:
253
+ st.info("Ask a question to see relevant passages.")
254
+ else:
255
+ # Find preceding user query for highlighting
256
+ query_text = ""
257
+ for i in range(len(st.session_state.chat)-1, -1, -1):
258
+ if st.session_state.chat[i]["role"] == "user":
259
+ query_text = st.session_state.chat[i]["content"]
260
+ break
261
+
262
+ for r in last_assistant["retrieved"]:
263
+ src = r["source"]
264
+ with st.expander(f"#{r['rank']} {src}"):
265
+ html = highlight_terms(r["text"], query_text)
266
+ st.markdown(html, unsafe_allow_html=True)
267
+
268
+ # Small utility buttons
269
+ st.download_button(
270
+ "Download chunk",
271
+ data=r["text"].encode("utf-8"),
272
+ file_name=f"chunk_{r['rank']}.txt",
273
+ use_container_width=True
274
+ )
275
 
276
+ st.divider()
277
+ st.caption("Tip: Ensure your `sources` strings match your citation format (e.g., `paper.pdf - Page 12`) so your LLM’s citations are clean.")