sahilursa committed on
Commit 3473bc7 · verified · 1 Parent(s): 97608bf

Delete app.py

Files changed (1)
  1. app.py +0 -277
app.py DELETED
@@ -1,277 +0,0 @@
- import os
- import re
- import pickle
- from pathlib import Path
- from typing import List, Dict, Any
-
- import streamlit as st
- import numpy as np
- import faiss
- from sentence_transformers import SentenceTransformer
-
- # ========= LLM backend config =========
- USE_OPENAI = os.getenv("USE_OPENAI", "0") == "1"
- GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
- OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
-
- if USE_OPENAI:
-     from openai import OpenAI
-     OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
-     if OPENAI_API_KEY:
-         openai_client = OpenAI(api_key=OPENAI_API_KEY)
-     else:
-         openai_client = None
- else:
-     import google.generativeai as genai
-     GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY", "")
-     if GOOGLE_API_KEY:
-         genai.configure(api_key=GOOGLE_API_KEY)
-
- # ========= Page config =========
- st.set_page_config(
-     page_title="Halassa Lab Literature Chatbot",
-     page_icon="🧠",
-     layout="wide",
- )
-
- # ========= Paths & knobs =========
- DATA_DIR = Path(os.getenv("DATA_DIR", "data"))
- VECTOR_PATH = DATA_DIR / "vector_store.index"
- PKL_PATH = DATA_DIR / "data.pkl"
-
- EMBED_MODEL_NAME = os.getenv("EMBED_MODEL_NAME", "BAAI/bge-large-en-v1.5")
- TOP_K = int(os.getenv("TOP_K", "5"))
- MAX_CONTEXT_CHARS = int(os.getenv("MAX_CONTEXT_CHARS", "12000"))
- SUGGESTED_Q = int(os.getenv("SUGGESTED_Q", "4"))
-
- # ========= Helpers =========
- def load_index_and_data():
-     if not VECTOR_PATH.exists() or not PKL_PATH.exists():
-         st.error(f"Missing index or data:\n- {VECTOR_PATH}\n- {PKL_PATH}")
-         st.stop()
-
-     index = faiss.read_index(str(VECTOR_PATH))
-     with open(PKL_PATH, "rb") as f:
-         stored = pickle.load(f)
-
-     texts = stored.get("texts", [])
-     sources = stored.get("sources", [])
-     meta = stored.get("meta", [None] * len(texts))
-
-     if len(texts) == 0 or len(texts) != len(sources):
-         st.error("data.pkl must contain 'texts' and 'sources' of equal length.")
-         st.stop()
-
-     return index, texts, sources, meta
-
- @st.cache_resource(show_spinner=False)
- def get_embedder():
-     return SentenceTransformer(EMBED_MODEL_NAME)
-
- def encode_query(query: str, embedder) -> np.ndarray:
-     vec = embedder.encode([query])
-     return vec.astype(np.float32)
-
- def retrieve(query: str, index, texts, sources, meta, k=TOP_K):
-     embedder = get_embedder()
-     qvec = encode_query(query, embedder)
-     D, I = index.search(qvec, k)
-     results = []
-     for rank, idx in enumerate(I[0].tolist()):
-         if 0 <= idx < len(texts):
-             results.append({
-                 "rank": rank + 1,
-                 "text": texts[idx],
-                 "source": sources[idx],
-                 "meta": meta[idx] if meta and idx < len(meta) else None
-             })
-     return results
-
- def build_context(retrieved: List[Dict[str, Any]]) -> str:
-     parts, total = [], 0
-     for r in retrieved:
-         src = r["source"]
-         txt = r["text"].strip()
-         chunk = f"Source: {src}\nContent: {txt}\n"
-         if total + len(chunk) > MAX_CONTEXT_CHARS:
-             break
-         parts.append(chunk)
-         total += len(chunk)
-     return "\n---\n".join(parts)
-
- def call_llm(system_prompt: str, user_prompt: str) -> str:
-     # OpenAI path
-     if USE_OPENAI and os.getenv("OPENAI_API_KEY") and openai_client:
-         resp = openai_client.chat.completions.create(
-             model=OPENAI_MODEL,
-             messages=[
-                 {"role": "system", "content": system_prompt},
-                 {"role": "user", "content": user_prompt},
-             ],
-             temperature=0.2,
-         )
-         return resp.choices[0].message.content
-
-     # Gemini path
-     if not USE_OPENAI and os.getenv("GOOGLE_API_KEY"):
-         model = genai.GenerativeModel(GEMINI_MODEL)
-         resp = model.generate_content(system_prompt + "\n\n" + user_prompt)
-         return resp.text
-
-     # Fallback (no key) for UI testing
-     return "(LLM disabled) " + user_prompt[:800]
-
- def highlight_terms(text: str, query: str) -> str:
-     # lightweight term highlighter; re is already imported at module level
-     terms = [t for t in re.split(r"\W+", query) if len(t) >= 3]
-     out = text
-     for t in set(terms):
-         out = re.sub(rf"({re.escape(t)})", r"<mark>\1</mark>", out, flags=re.IGNORECASE)
-     return out
-
- def suggest_questions(last_answer: str, k=SUGGESTED_Q) -> List[str]:
-     prompt = f"""Generate {k} concise follow-up questions a user might ask next, given the expert answer below.
- Each question should be short (max ~12 words) and deepen the discussion.
-
- Respond with a plain list, one question per line, with no bullets or numbering.
-
- Expert answer:
- {last_answer}
- """
-     out = call_llm(
-         system_prompt="You are a helpful assistant that proposes follow-up questions.",
-         user_prompt=prompt,
-     )
-     qs = [re.sub(r"^[\-\*\d\.\)\s]+", "", q).strip() for q in out.splitlines() if q.strip()]
-     return [q for q in qs if q][:k]
-
- # ========= Load index/data =========
- index, TEXTS, SOURCES, META = load_index_and_data()
-
- # ========= Sidebar =========
- with st.sidebar:
-     st.title("⚙️ Settings")
-     st.write("**Retrieval**")
-     TOP_K = st.slider("Top-K passages", 3, 10, TOP_K)
-     st.divider()
-     st.write("**Models**")
-     st.write(f"Embedding: `{EMBED_MODEL_NAME}`")
-     st.write("LLM:", "OpenAI" if USE_OPENAI else "Gemini",
-              f"({OPENAI_MODEL if USE_OPENAI else GEMINI_MODEL})")
-     st.caption("Switch with env vars: USE_OPENAI, OPENAI_API_KEY, GOOGLE_API_KEY.")
-     st.divider()
-     st.write("**Files**")
-     st.write(f"Index: `{VECTOR_PATH}`")
-     st.write(f"Data : `{PKL_PATH}`")
-
- # ========= Main Layout =========
- st.title("Halassa Lab Onboarder 🧠📄")
- st.caption("Ask questions; see exactly which passages were used.")
-
- if "chat" not in st.session_state:
-     st.session_state.chat = []  # list[dict]: {"role": "user"/"assistant", "content": str, "retrieved": list}
- if "last_suggestions" not in st.session_state:
-     st.session_state.last_suggestions = []
-
- # Input row
- with st.container():
-     cols = st.columns([6, 1])
-     with cols[0]:
-         user_message = st.text_input(
-             "Ask your question",
-             "",
-             placeholder="e.g., How does MD dopamine shape error-driven flexibility?",
-         )
-     with cols[1]:
-         ask = st.button("Send", use_container_width=True)
-
- def answer_query(query: str):
-     retrieved = retrieve(query, index, TEXTS, SOURCES, META, k=TOP_K)
-     context_str = build_context(retrieved)
-
-     sys_prompt = (
-         "You are an expert computational neuroscientist in the Halassa Lab at MIT. "
-         "Answer thoroughly and clearly. Synthesize from provided context; write in your own words. "
-         "If you cite directly from a provided paper, add citations at the end as [filename - Page X]. "
-         "If context is partial, add helpful background."
-     )
-     user_prompt = f"""Context:
- ---
- {context_str}
- ---
-
- User Question: {query}
-
- Expert Answer:
- """
-     answer = call_llm(sys_prompt, user_prompt)
-
-     st.session_state.chat.append({"role": "user", "content": query})
-     st.session_state.chat.append({"role": "assistant", "content": answer, "retrieved": retrieved})
-
-     try:
-         st.session_state.last_suggestions = suggest_questions(answer, k=SUGGESTED_Q)
-     except Exception:
-         st.session_state.last_suggestions = []
-
- # Trigger on click or Enter
- if ask and user_message.strip():
-     answer_query(user_message.strip())
- elif user_message.strip() and st.session_state.chat == []:
-     # allow pressing Enter to submit the first question
-     answer_query(user_message.strip())
-
- # Two-column layout
- col_chat, col_docs = st.columns([2, 1], gap="large")
-
- # Left: Chat
- with col_chat:
-     for turn in st.session_state.chat:
-         if turn["role"] == "user":
-             st.chat_message("user").markdown(turn["content"])
-         else:
-             st.chat_message("assistant").markdown(turn["content"])
-
-     if st.session_state.last_suggestions:
-         st.subheader("Try next:")
-         sug_cols = st.columns(len(st.session_state.last_suggestions))
-         for i, q in enumerate(st.session_state.last_suggestions):
-             if sug_cols[i].button(q):
-                 answer_query(q)
-
- # Right: Relevant chunks (no PDF viewer)
- with col_docs:
-     st.subheader("Relevant Sources")
-     last_assistant = None
-     for t in reversed(st.session_state.chat):
-         if t.get("role") == "assistant" and "retrieved" in t:
-             last_assistant = t
-             break
-
-     if not last_assistant:
-         st.info("Ask a question to see relevant passages.")
-     else:
-         # Find preceding user query for highlighting
-         query_text = ""
-         for i in range(len(st.session_state.chat) - 1, -1, -1):
-             if st.session_state.chat[i]["role"] == "user":
-                 query_text = st.session_state.chat[i]["content"]
-                 break
-
-         for r in last_assistant["retrieved"]:
-             src = r["source"]
-             with st.expander(f"#{r['rank']} {src}"):
-                 html = highlight_terms(r["text"], query_text)
-                 st.markdown(html, unsafe_allow_html=True)
-
-                 # Small utility buttons
-                 st.download_button(
-                     "Download chunk",
-                     data=r["text"].encode("utf-8"),
-                     file_name=f"chunk_{r['rank']}.txt",
-                     use_container_width=True
-                 )
-
- st.divider()
- st.caption("Tip: Ensure your `sources` strings match your citation format (e.g., `paper.pdf - Page 12`) so your LLM’s citations are clean.")