zakerytclarke commited on
Commit
ad255be
·
verified ·
1 Parent(s): 5b512b2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +53 -91
src/streamlit_app.py CHANGED
@@ -8,7 +8,7 @@ import streamlit as st
8
  import torch
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
10
 
11
- # Optional parsing libs (best-effort)
12
  try:
13
  from pypdf import PdfReader # pip install pypdf
14
  except Exception:
@@ -25,6 +25,7 @@ try:
25
  except Exception:
26
  LangSmithClient = None
27
 
 
28
  # =========================
29
  # CONFIG
30
  # =========================
@@ -76,11 +77,12 @@ if "pending_response" not in st.session_state:
76
 
77
 
78
  # =========================
79
- # HEADER (LOGO)
80
  # =========================
81
  col1, col2 = st.columns([1, 6])
82
  with col1:
83
- st.image(LOGO_URL, use_container_width=True)
 
84
  with col2:
85
  st.markdown("## TeapotAI Chat")
86
  st.caption("Fast grounded answers with clean web context")
@@ -104,25 +106,25 @@ with st.sidebar:
104
  height=180,
105
  )
106
 
107
- st.markdown("### Local Context")
108
  local_context_text = st.text_area(
109
  "Paste additional context (optional)",
110
  height=140,
111
  placeholder="This will be appended after web content...",
112
  )
113
 
 
114
  uploaded_files = st.file_uploader(
115
- "Upload files to add to Local Context (pdf, txt, csv, md, json, etc.)",
116
- type=None,
117
  accept_multiple_files=True,
 
118
  )
119
 
120
 
121
  # =========================
122
- # FILE PARSING -> STRING
123
  # =========================
124
- def _safe_decode(b: bytes) -> str:
125
- # best effort decode without throwing
126
  for enc in ("utf-8", "utf-16", "latin-1"):
127
  try:
128
  return b.decode(enc)
@@ -131,58 +133,46 @@ def _safe_decode(b: bytes) -> str:
131
  return b.decode("utf-8", errors="ignore")
132
 
133
 
134
- def parse_uploaded_file_to_text(file) -> str:
135
  name = (file.name or "").lower()
136
  raw = file.getvalue()
137
 
138
  # PDF
139
- if name.endswith(".pdf"):
140
- if not PdfReader:
141
- return (
142
- f"[{file.name}] PDF parsing not available (install pypdf). "
143
- f"Raw bytes={len(raw)}"
144
- )
145
  try:
146
  reader = PdfReader(io.BytesIO(raw))
147
- parts = []
148
- for i, page in enumerate(reader.pages):
149
  txt = page.extract_text() or ""
150
- txt = txt.strip()
151
- if txt:
152
- parts.append(txt)
153
- return "\n\n".join(parts).strip()
154
  except Exception as e:
155
- return f"[{file.name}] PDF parse error: {e}"
156
 
157
  # CSV
158
- if name.endswith(".csv"):
159
- if not pd:
160
- return (
161
- f"[{file.name}] CSV parsing not available (install pandas). "
162
- f"Raw bytes={len(raw)}"
163
- )
164
  try:
165
  df = pd.read_csv(io.BytesIO(raw))
166
- # Keep it compact but readable
167
  return df.to_csv(index=False)
168
- except Exception as e:
169
- # fallback: raw text
170
- return f"[{file.name}] CSV parse error ({e}). Raw:\n{_safe_decode(raw)}"
171
 
172
- # JSON / TXT / MD / others -> decode
173
- return _safe_decode(raw).strip()
174
 
175
 
176
- def build_local_context(text_area: str, files) -> str:
177
  chunks = []
178
- if text_area.strip():
179
- chunks.append(text_area.strip())
 
180
 
181
  if files:
182
  for f in files:
183
- parsed = parse_uploaded_file_to_text(f).strip()
184
- if parsed:
185
- chunks.append(f"\n\n--- FILE: {f.name} ---\n{parsed}")
186
 
187
  return "\n\n".join(chunks).strip()
188
 
@@ -191,7 +181,7 @@ local_context = build_local_context(local_context_text, uploaded_files)
191
 
192
 
193
  # =========================
194
- # WEB SEARCH (SNIPPETS ONLY) - ALWAYS ON
195
  # =========================
196
  def web_search_snippets(query: str):
197
  api_key = os.getenv("BRAVE_API_KEY") or st.secrets.get("BRAVE_API_KEY", None)
@@ -221,14 +211,13 @@ def web_search_snippets(query: str):
221
  if desc:
222
  snippets.append(desc)
223
 
224
- clean_context = "\n\n".join(snippets) # paragraph-separated only
225
- return clean_context, (t1 - t0)
226
 
227
 
228
  # =========================
229
- # TRUNCATE TO LAST 512 TOKENS (TAIL)
230
  # =========================
231
- def truncate_context(web_ctx: str, local_ctx: str, system: str, question: str) -> str:
232
  ordered_context = f"{web_ctx}\n\n{local_ctx}".strip()
233
 
234
  base = f"\n{system}\n{question}\n"
@@ -236,13 +225,13 @@ def truncate_context(web_ctx: str, local_ctx: str, system: str, question: str) -
236
  budget = MAX_INPUT_TOKENS - len(base_tokens)
237
 
238
  if budget <= 0:
239
- return "" # system+question already consume budget
240
 
241
  ctx_tokens = tokenizer.encode(ordered_context) if ordered_context else []
242
  if len(ctx_tokens) <= budget:
243
  return ordered_context
244
 
245
- truncated = ctx_tokens[-budget:] # keep MOST RECENT tokens
246
  return tokenizer.decode(truncated, skip_special_tokens=True)
247
 
248
 
@@ -262,7 +251,7 @@ def stream_generate(prompt: str):
262
  streamer=streamer,
263
  )
264
 
265
- thread = threading.Thread(target=run, daemon=True)
266
  thread.start()
267
 
268
  text = ""
@@ -272,7 +261,7 @@ def stream_generate(prompt: str):
272
 
273
 
274
  # =========================
275
- # FEEDBACK HANDLER (Native st.feedback)
276
  # =========================
277
  def handle_feedback(idx: int):
278
  val = st.session_state[f"feedback_{idx}"]
@@ -301,7 +290,7 @@ for i, msg in enumerate(st.session_state.messages):
301
  st.markdown(msg["content"])
302
  continue
303
 
304
- # Assistant messages: collapsed-by-default expander = "whole message response be the dropdown"
305
  with st.expander("🫖 Assistant response (click to expand)", expanded=False):
306
  st.markdown(msg["content"])
307
 
@@ -312,17 +301,9 @@ for i, msg in enumerate(st.session_state.messages):
312
  f"🧾 in={msg['input_tokens']} • out={msg['output_tokens']}"
313
  )
314
 
315
- # Show EXACT prompt passed into the model (and the parts)
316
- st.markdown("---")
317
- st.markdown("#### Prompt & Inputs (exactly what was passed to the model)")
318
- st.markdown("**System prompt:**")
319
- st.code(msg.get("system_prompt", ""), language="text")
320
- st.markdown("**Question:**")
321
- st.code(msg.get("question", ""), language="text")
322
- st.markdown("**Full model input (prompt):**")
323
- st.code(msg.get("prompt", ""), language="text")
324
-
325
- # Native thumbs feedback (outside expander so it's still reachable)
326
  key = f"feedback_{i}"
327
  st.session_state.setdefault(key, msg.get("feedback"))
328
  st.feedback(
@@ -340,7 +321,6 @@ for i, msg in enumerate(st.session_state.messages):
340
  query = st.chat_input("Ask a question...")
341
 
342
  if query:
343
- # show user message first
344
  st.session_state.messages.append({"role": "user", "content": query})
345
  st.rerun()
346
 
@@ -355,24 +335,22 @@ if (
355
  ):
356
  question = st.session_state.messages[-1]["content"]
357
 
358
- # --- Web Search (always on) ---
359
  web_ctx, search_time = web_search_snippets(question)
360
 
361
- # --- Strict Order Context ---
362
  final_context = truncate_context(
363
- web_ctx=web_ctx,
364
- local_ctx=local_context,
365
- system=system_prompt,
366
- question=question,
367
  )
368
 
369
- # IMPORTANT: prompt is EXACTLY what we pass to the model
370
- prompt = f"{final_context}\n{system_prompt}\n{question}\n".strip() + "\n"
371
 
372
- # Token accounting (split input vs output)
373
  input_tokens = len(tokenizer.encode(prompt))
374
 
375
- # LangSmith run
376
  run_id = None
377
  if ls_client:
378
  try:
@@ -380,19 +358,14 @@ if (
380
  name="teapot_chat",
381
  run_type="llm",
382
  inputs={
383
- "web_content": web_ctx,
384
- "local_context": local_context,
385
- "system_prompt": system_prompt,
386
- "question": question,
387
- "final_context": final_context,
388
  "prompt": prompt,
 
389
  },
390
  )
391
  run_id = run.id
392
  except Exception:
393
  pass
394
 
395
- # --- Stream UI: assistant response itself is a dropdown ---
396
  with st.chat_message("assistant"):
397
  with st.expander("🫖 Assistant response (click to expand)", expanded=False):
398
  placeholder = st.empty()
@@ -414,13 +387,7 @@ if (
414
  f"🧾 in={input_tokens} • out={output_tokens}"
415
  )
416
 
417
- st.markdown("---")
418
- st.markdown("#### Prompt & Inputs (exactly what was passed to the model)")
419
- st.markdown("**System prompt:**")
420
- st.code(system_prompt, language="text")
421
- st.markdown("**Question:**")
422
- st.code(question, language="text")
423
- st.markdown("**Full model input (prompt):**")
424
  st.code(prompt, language="text")
425
 
426
  if ls_client and run_id:
@@ -433,11 +400,6 @@ if (
433
  {
434
  "role": "assistant",
435
  "content": final_text,
436
- "system_prompt": system_prompt,
437
- "question": question,
438
- "web_context": web_ctx,
439
- "local_context": local_context,
440
- "final_context": final_context,
441
  "prompt": prompt,
442
  "search_time": search_time,
443
  "gen_time": gen_time,
 
8
  import torch
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
10
 
11
+ # Optional parsing libs (safe fallbacks)
12
  try:
13
  from pypdf import PdfReader # pip install pypdf
14
  except Exception:
 
25
  except Exception:
26
  LangSmithClient = None
27
 
28
+
29
  # =========================
30
  # CONFIG
31
  # =========================
 
77
 
78
 
79
  # =========================
80
+ # HEADER (SAFE IMAGE CALL)
81
  # =========================
82
  col1, col2 = st.columns([1, 6])
83
  with col1:
84
+ # IMPORTANT: use_column_width=True (works on your Streamlit version)
85
+ st.image(LOGO_URL, use_column_width=True)
86
  with col2:
87
  st.markdown("## TeapotAI Chat")
88
  st.caption("Fast grounded answers with clean web context")
 
106
  height=180,
107
  )
108
 
109
+ st.markdown("### Local Context (Text)")
110
  local_context_text = st.text_area(
111
  "Paste additional context (optional)",
112
  height=140,
113
  placeholder="This will be appended after web content...",
114
  )
115
 
116
+ st.markdown("### Local Context (File Upload)")
117
  uploaded_files = st.file_uploader(
118
+ "Upload files (pdf, txt, csv, md, json, etc.)",
 
119
  accept_multiple_files=True,
120
+ type=None,
121
  )
122
 
123
 
124
  # =========================
125
+ # FILE PARSING
126
  # =========================
127
+ def safe_decode(b: bytes) -> str:
 
128
  for enc in ("utf-8", "utf-16", "latin-1"):
129
  try:
130
  return b.decode(enc)
 
133
  return b.decode("utf-8", errors="ignore")
134
 
135
 
136
+ def parse_file_to_text(file) -> str:
137
  name = (file.name or "").lower()
138
  raw = file.getvalue()
139
 
140
  # PDF
141
+ if name.endswith(".pdf") and PdfReader:
 
 
 
 
 
142
  try:
143
  reader = PdfReader(io.BytesIO(raw))
144
+ pages = []
145
+ for page in reader.pages:
146
  txt = page.extract_text() or ""
147
+ if txt.strip():
148
+ pages.append(txt.strip())
149
+ return "\n\n".join(pages)
 
150
  except Exception as e:
151
+ return f"[PDF parse error: {e}]"
152
 
153
  # CSV
154
+ if name.endswith(".csv") and pd:
 
 
 
 
 
155
  try:
156
  df = pd.read_csv(io.BytesIO(raw))
 
157
  return df.to_csv(index=False)
158
+ except Exception:
159
+ return safe_decode(raw)
 
160
 
161
+ # TXT / MD / JSON / fallback
162
+ return safe_decode(raw)
163
 
164
 
165
+ def build_local_context(text_block: str, files) -> str:
166
  chunks = []
167
+
168
+ if text_block and text_block.strip():
169
+ chunks.append(text_block.strip())
170
 
171
  if files:
172
  for f in files:
173
+ parsed = parse_file_to_text(f)
174
+ if parsed and parsed.strip():
175
+ chunks.append(f"\n\n--- FILE: {f.name} ---\n{parsed.strip()}")
176
 
177
  return "\n\n".join(chunks).strip()
178
 
 
181
 
182
 
183
  # =========================
184
+ # WEB SEARCH (ALWAYS ON)
185
  # =========================
186
  def web_search_snippets(query: str):
187
  api_key = os.getenv("BRAVE_API_KEY") or st.secrets.get("BRAVE_API_KEY", None)
 
211
  if desc:
212
  snippets.append(desc)
213
 
214
+ return "\n\n".join(snippets), (t1 - t0)
 
215
 
216
 
217
  # =========================
218
+ # CONTEXT TRUNCATION (TAIL)
219
  # =========================
220
+ def truncate_context(web_ctx, local_ctx, system, question):
221
  ordered_context = f"{web_ctx}\n\n{local_ctx}".strip()
222
 
223
  base = f"\n{system}\n{question}\n"
 
225
  budget = MAX_INPUT_TOKENS - len(base_tokens)
226
 
227
  if budget <= 0:
228
+ return ""
229
 
230
  ctx_tokens = tokenizer.encode(ordered_context) if ordered_context else []
231
  if len(ctx_tokens) <= budget:
232
  return ordered_context
233
 
234
+ truncated = ctx_tokens[-budget:]
235
  return tokenizer.decode(truncated, skip_special_tokens=True)
236
 
237
 
 
251
  streamer=streamer,
252
  )
253
 
254
+ thread = threading.Thread(target=run)
255
  thread.start()
256
 
257
  text = ""
 
261
 
262
 
263
  # =========================
264
+ # FEEDBACK HANDLER
265
  # =========================
266
  def handle_feedback(idx: int):
267
  val = st.session_state[f"feedback_{idx}"]
 
290
  st.markdown(msg["content"])
291
  continue
292
 
293
+ # Entire response as collapsed dropdown (less visible inspector)
294
  with st.expander("🫖 Assistant response (click to expand)", expanded=False):
295
  st.markdown(msg["content"])
296
 
 
301
  f"🧾 in={msg['input_tokens']} • out={msg['output_tokens']}"
302
  )
303
 
304
+ st.markdown("### Exact Model Input (Prompt)")
305
+ st.code(msg["prompt"], language="text")
306
+
 
 
 
 
 
 
 
 
307
  key = f"feedback_{i}"
308
  st.session_state.setdefault(key, msg.get("feedback"))
309
  st.feedback(
 
321
  query = st.chat_input("Ask a question...")
322
 
323
  if query:
 
324
  st.session_state.messages.append({"role": "user", "content": query})
325
  st.rerun()
326
 
 
335
  ):
336
  question = st.session_state.messages[-1]["content"]
337
 
338
+ # Always do web search
339
  web_ctx, search_time = web_search_snippets(question)
340
 
 
341
  final_context = truncate_context(
342
+ web_ctx,
343
+ local_context,
344
+ system_prompt,
345
+ question,
346
  )
347
 
348
+ # EXACT prompt passed to model
349
+ prompt = f"{final_context}\n{system_prompt}\n{question}\n"
350
 
 
351
  input_tokens = len(tokenizer.encode(prompt))
352
 
353
+ # LangSmith run (optional)
354
  run_id = None
355
  if ls_client:
356
  try:
 
358
  name="teapot_chat",
359
  run_type="llm",
360
  inputs={
 
 
 
 
 
361
  "prompt": prompt,
362
+ "question": question,
363
  },
364
  )
365
  run_id = run.id
366
  except Exception:
367
  pass
368
 
 
369
  with st.chat_message("assistant"):
370
  with st.expander("🫖 Assistant response (click to expand)", expanded=False):
371
  placeholder = st.empty()
 
387
  f"🧾 in={input_tokens} • out={output_tokens}"
388
  )
389
 
390
+ st.markdown("### Exact Model Input (Prompt)")
 
 
 
 
 
 
391
  st.code(prompt, language="text")
392
 
393
  if ls_client and run_id:
 
400
  {
401
  "role": "assistant",
402
  "content": final_text,
 
 
 
 
 
403
  "prompt": prompt,
404
  "search_time": search_time,
405
  "gen_time": gen_time,