zakerytclarke committed on
Commit
b00bb52
·
verified ·
1 Parent(s): e4379b8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +204 -81
src/streamlit_app.py CHANGED
@@ -2,14 +2,27 @@ import os
2
  import time
3
  import threading
4
  import requests
 
 
5
  import streamlit as st
6
  import torch
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
8
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Optional LangSmith
10
  try:
11
  from langsmith import Client as LangSmithClient
12
- except:
13
  LangSmithClient = None
14
 
15
  # =========================
@@ -23,6 +36,7 @@ LOGO_URL = "https://teapotai.com/assets/logo.gif"
23
 
24
  st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")
25
 
 
26
  # =========================
27
  # LOAD MODEL (CACHED)
28
  # =========================
@@ -34,8 +48,10 @@ def load_model():
34
  model.to(device).eval()
35
  return tokenizer, model, device
36
 
 
37
  tokenizer, model, device = load_model()
38
 
 
39
  # =========================
40
  # LANGSMITH (OPTIONAL)
41
  # =========================
@@ -46,8 +62,10 @@ def get_langsmith():
46
  return LangSmithClient()
47
  return None
48
 
 
49
  ls_client = get_langsmith()
50
 
 
51
  # =========================
52
  # SESSION STATE
53
  # =========================
@@ -56,16 +74,18 @@ if "messages" not in st.session_state:
56
  if "pending_response" not in st.session_state:
57
  st.session_state.pending_response = None
58
 
 
59
  # =========================
60
  # HEADER (LOGO)
61
  # =========================
62
  col1, col2 = st.columns([1, 6])
63
  with col1:
64
- st.image(LOGO_URL, use_column_width=True)
65
  with col2:
66
  st.markdown("## TeapotAI Chat")
67
  st.caption("Fast grounded answers with clean web context")
68
 
 
69
  # =========================
70
  # SIDEBAR SETTINGS
71
  # =========================
@@ -84,17 +104,94 @@ with st.sidebar:
84
  height=180,
85
  )
86
 
87
- st.markdown("### Local Context (Appended)")
88
- local_context = st.text_area(
89
  "Paste additional context (optional)",
90
- height=160,
91
- placeholder="This will be appended after web content..."
 
 
 
 
 
 
92
  )
93
 
94
- use_web = st.checkbox("Use web search", value=True)
95
 
96
  # =========================
97
- # WEB SEARCH (SNIPPETS ONLY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  # =========================
99
  def web_search_snippets(query: str):
100
  api_key = os.getenv("BRAVE_API_KEY") or st.secrets.get("BRAVE_API_KEY", None)
@@ -113,7 +210,7 @@ def web_search_snippets(query: str):
113
  timeout=6,
114
  )
115
  data = r.json()
116
- except:
117
  return "", 0.0
118
  t1 = time.perf_counter()
119
 
@@ -124,30 +221,31 @@ def web_search_snippets(query: str):
124
  if desc:
125
  snippets.append(desc)
126
 
127
- # Paragraph-separated ONLY (no title, no URL)
128
- clean_context = "\n\n".join(snippets)
129
  return clean_context, (t1 - t0)
130
 
 
131
  # =========================
132
- # TRUNCATE TO LAST 512 TOKENS
133
  # =========================
134
- def truncate_context(web_ctx, local_ctx, system, question):
135
- ordered_context = (
136
- f"{web_ctx}\n\n{local_ctx}".strip()
137
- )
138
 
139
  base = f"\n{system}\n{question}\n"
140
  base_tokens = tokenizer.encode(base)
141
  budget = MAX_INPUT_TOKENS - len(base_tokens)
142
 
143
- ctx_tokens = tokenizer.encode(ordered_context)
 
 
 
144
  if len(ctx_tokens) <= budget:
145
  return ordered_context
146
 
147
- # Keep MOST RECENT tokens (tail truncation)
148
- truncated = ctx_tokens[-budget:]
149
  return tokenizer.decode(truncated, skip_special_tokens=True)
150
 
 
151
  # =========================
152
  # STREAM GENERATION
153
  # =========================
@@ -164,7 +262,7 @@ def stream_generate(prompt: str):
164
  streamer=streamer,
165
  )
166
 
167
- thread = threading.Thread(target=run)
168
  thread.start()
169
 
170
  text = ""
@@ -172,6 +270,7 @@ def stream_generate(prompt: str):
172
  text += chunk
173
  yield text
174
 
 
175
  # =========================
176
  # FEEDBACK HANDLER (Native st.feedback)
177
  # =========================
@@ -189,80 +288,89 @@ def handle_feedback(idx: int):
189
  score=score,
190
  comment="thumbs_up" if score else "thumbs_down",
191
  )
192
- except:
193
  pass
194
 
 
195
  # =========================
196
  # RENDER CHAT HISTORY
197
  # =========================
198
  for i, msg in enumerate(st.session_state.messages):
199
  with st.chat_message(msg["role"]):
200
- st.markdown(msg["content"])
 
 
 
 
 
 
201
 
202
- if msg["role"] == "assistant":
203
- # Metrics row
204
  st.caption(
205
  f"🔎 {msg['search_time']:.2f}s search • "
206
  f"🧠 {msg['gen_time']:.2f}s generate • "
207
  f"⚡ {msg['tps']:.1f} tok/s • "
208
- f"🧮 {msg['tokens']} tokens"
209
  )
210
 
211
- # Inspectable context (clean UX)
212
- with st.expander("🔍 Inspect Context Used"):
213
- st.markdown("**Web Content:**")
214
- st.write(msg["web_context"] or "_None_")
215
- st.markdown("**Local Context:**")
216
- st.write(msg["local_context"] or "_None_")
217
- st.markdown("**Final Truncated Context (512 tokens tail):**")
218
- st.write(msg["final_context"])
219
-
220
- # Native thumbs feedback
221
- key = f"feedback_{i}"
222
- st.session_state.setdefault(key, msg.get("feedback"))
223
- st.feedback(
224
- "thumbs",
225
- key=key,
226
- disabled=msg.get("feedback") is not None,
227
- on_change=handle_feedback,
228
- args=(i,),
229
- )
 
 
230
 
231
  # =========================
232
- # USER INPUT (FIXED ORDER)
233
  # =========================
234
  query = st.chat_input("Ask a question...")
235
 
236
  if query:
237
- # 1️⃣ Immediately show user message FIRST (fix streaming race)
238
  st.session_state.messages.append({"role": "user", "content": query})
239
  st.rerun()
240
 
 
241
  # =========================
242
- # GENERATE AFTER RERUN (Prevents premature streaming)
243
  # =========================
244
  if (
245
  st.session_state.messages
246
  and st.session_state.messages[-1]["role"] == "user"
247
  and st.session_state.pending_response is None
248
  ):
249
- query = st.session_state.messages[-1]["content"]
250
 
251
- # --- Web Search ---
252
- web_ctx = ""
253
- search_time = 0.0
254
- if use_web:
255
- web_ctx, search_time = web_search_snippets(query)
256
 
257
  # --- Strict Order Context ---
258
  final_context = truncate_context(
259
- web_ctx,
260
- local_context,
261
- system_prompt,
262
- query,
263
  )
264
 
265
- prompt = f"{final_context}\n{system_prompt}\n{query}\n"
 
 
 
 
266
 
267
  # LangSmith run
268
  run_id = None
@@ -275,51 +383,66 @@ if (
275
  "web_content": web_ctx,
276
  "local_context": local_context,
277
  "system_prompt": system_prompt,
278
- "question": query,
279
  "final_context": final_context,
 
280
  },
281
  )
282
  run_id = run.id
283
- except:
284
  pass
285
 
286
- # --- Stream UI ---
287
  with st.chat_message("assistant"):
288
- placeholder = st.empty()
289
- start = time.perf_counter()
290
- final_text = ""
291
-
292
- for partial in stream_generate(prompt):
293
- final_text = partial
294
- placeholder.markdown(final_text)
295
-
296
- gen_time = time.perf_counter() - start
297
- tokens = len(tokenizer.encode(final_text))
298
- tps = tokens / gen_time if gen_time > 0 else 0.0
299
-
300
- st.caption(
301
- f"🔎 {search_time:.2f}s search • "
302
- f"🧠 {gen_time:.2f}s generate • "
303
- f" {tps:.1f} tok/s • "
304
- f"🧮 {tokens} tokens"
305
- )
 
 
 
 
 
 
 
 
 
 
306
 
307
  if ls_client and run_id:
308
  try:
309
  ls_client.update_run(run_id, outputs={"answer": final_text})
310
- except:
311
  pass
312
 
313
  st.session_state.messages.append(
314
  {
315
  "role": "assistant",
316
  "content": final_text,
 
 
317
  "web_context": web_ctx,
318
  "local_context": local_context,
319
  "final_context": final_context,
 
320
  "search_time": search_time,
321
  "gen_time": gen_time,
322
- "tokens": tokens,
 
323
  "tps": tps,
324
  "run_id": run_id,
325
  "feedback": None,
 
2
  import time
3
  import threading
4
  import requests
5
+ import io
6
+
7
  import streamlit as st
8
  import torch
9
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
10
 
11
+ # Optional parsing libs (best-effort)
12
+ try:
13
+ from pypdf import PdfReader # pip install pypdf
14
+ except Exception:
15
+ PdfReader = None
16
+
17
+ try:
18
+ import pandas as pd # pip install pandas
19
+ except Exception:
20
+ pd = None
21
+
22
  # Optional LangSmith
23
  try:
24
  from langsmith import Client as LangSmithClient
25
+ except Exception:
26
  LangSmithClient = None
27
 
28
  # =========================
 
36
 
37
  st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")
38
 
39
+
40
  # =========================
41
  # LOAD MODEL (CACHED)
42
  # =========================
 
48
  model.to(device).eval()
49
  return tokenizer, model, device
50
 
51
+
52
  tokenizer, model, device = load_model()
53
 
54
+
55
  # =========================
56
  # LANGSMITH (OPTIONAL)
57
  # =========================
 
62
  return LangSmithClient()
63
  return None
64
 
65
+
66
  ls_client = get_langsmith()
67
 
68
+
69
  # =========================
70
  # SESSION STATE
71
  # =========================
 
74
  if "pending_response" not in st.session_state:
75
  st.session_state.pending_response = None
76
 
77
+
78
  # =========================
79
  # HEADER (LOGO)
80
  # =========================
81
  col1, col2 = st.columns([1, 6])
82
  with col1:
83
+ st.image(LOGO_URL, use_container_width=True)
84
  with col2:
85
  st.markdown("## TeapotAI Chat")
86
  st.caption("Fast grounded answers with clean web context")
87
 
88
+
89
  # =========================
90
  # SIDEBAR SETTINGS
91
  # =========================
 
104
  height=180,
105
  )
106
 
107
+ st.markdown("### Local Context")
108
+ local_context_text = st.text_area(
109
  "Paste additional context (optional)",
110
+ height=140,
111
+ placeholder="This will be appended after web content...",
112
+ )
113
+
114
+ uploaded_files = st.file_uploader(
115
+ "Upload files to add to Local Context (pdf, txt, csv, md, json, etc.)",
116
+ type=None,
117
+ accept_multiple_files=True,
118
  )
119
 
 
120
 
121
  # =========================
122
+ # FILE PARSING -> STRING
123
+ # =========================
124
+ def _safe_decode(b: bytes) -> str:
125
+ # best effort decode without throwing
126
+ for enc in ("utf-8", "utf-16", "latin-1"):
127
+ try:
128
+ return b.decode(enc)
129
+ except Exception:
130
+ pass
131
+ return b.decode("utf-8", errors="ignore")
132
+
133
+
134
def parse_uploaded_file_to_text(file) -> str:
    """Convert an uploaded file into plain text for the local context.

    Supports PDF (via ``pypdf``, if installed) and CSV (via ``pandas``,
    if installed); anything else is decoded as text on a best-effort
    basis.  Parse failures are reported inline as a bracketed message
    rather than raised, so a bad upload never crashes the app.

    Args:
        file: Streamlit ``UploadedFile``-like object exposing ``.name``
            and ``.getvalue()``.

    Returns:
        Extracted text, or an ``[filename] ... error`` message on failure.
    """
    name = (file.name or "").lower()
    raw = file.getvalue()

    # PDF: extract text page by page, skipping pages with no text layer.
    if name.endswith(".pdf"):
        if not PdfReader:
            return (
                f"[{file.name}] PDF parsing not available (install pypdf). "
                f"Raw bytes={len(raw)}"
            )
        try:
            reader = PdfReader(io.BytesIO(raw))
            parts = []
            for page in reader.pages:
                txt = (page.extract_text() or "").strip()
                if txt:
                    parts.append(txt)
            return "\n\n".join(parts).strip()
        except Exception as e:
            return f"[{file.name}] PDF parse error: {e}"

    # CSV: round-trip through pandas to normalize the content.
    if name.endswith(".csv"):
        if not pd:
            return (
                f"[{file.name}] CSV parsing not available (install pandas). "
                f"Raw bytes={len(raw)}"
            )
        try:
            df = pd.read_csv(io.BytesIO(raw))
            # Keep it compact but readable
            return df.to_csv(index=False)
        except Exception as e:
            # fallback: raw text
            return f"[{file.name}] CSV parse error ({e}). Raw:\n{_safe_decode(raw)}"

    # JSON / TXT / MD / others -> best-effort decode
    return _safe_decode(raw).strip()
174
+
175
+
176
def build_local_context(text_area: str, files) -> str:
    """Combine the pasted text and all uploaded files into one context string.

    The free-text box comes first, followed by each file's parsed content
    under a ``--- FILE: name ---`` banner.  Empty pieces are dropped.
    """
    pieces = []

    pasted = text_area.strip()
    if pasted:
        pieces.append(pasted)

    for upload in files or []:
        body = parse_uploaded_file_to_text(upload).strip()
        if body:
            pieces.append(f"\n\n--- FILE: {upload.name} ---\n{body}")

    return "\n\n".join(pieces).strip()
188
+
189
+
190
+ local_context = build_local_context(local_context_text, uploaded_files)
191
+
192
+
193
+ # =========================
194
+ # WEB SEARCH (SNIPPETS ONLY) - ALWAYS ON
195
  # =========================
196
  def web_search_snippets(query: str):
197
  api_key = os.getenv("BRAVE_API_KEY") or st.secrets.get("BRAVE_API_KEY", None)
 
210
  timeout=6,
211
  )
212
  data = r.json()
213
+ except Exception:
214
  return "", 0.0
215
  t1 = time.perf_counter()
216
 
 
221
  if desc:
222
  snippets.append(desc)
223
 
224
+ clean_context = "\n\n".join(snippets) # paragraph-separated only
 
225
  return clean_context, (t1 - t0)
226
 
227
+
228
  # =========================
229
+ # TRUNCATE TO LAST 512 TOKENS (TAIL)
230
  # =========================
231
def truncate_context(web_ctx: str, local_ctx: str, system: str, question: str) -> str:
    """Fit web + local context into the model's token budget (tail-kept).

    Web content is placed before local content; if the combined context
    exceeds MAX_INPUT_TOKENS minus what the system prompt and question
    consume, only the most recent (trailing) tokens are kept.  Returns ""
    when the system prompt and question alone exhaust the budget.
    """
    combined = f"{web_ctx}\n\n{local_ctx}".strip()

    # Tokens already spoken for by the fixed parts of the prompt.
    reserved = len(tokenizer.encode(f"\n{system}\n{question}\n"))
    remaining = MAX_INPUT_TOKENS - reserved
    if remaining <= 0:
        return ""

    if not combined:
        return combined

    tokens = tokenizer.encode(combined)
    if len(tokens) <= remaining:
        return combined

    # Tail truncation: drop the oldest tokens, keep the newest.
    return tokenizer.decode(tokens[-remaining:], skip_special_tokens=True)
247
 
248
+
249
  # =========================
250
  # STREAM GENERATION
251
  # =========================
 
262
  streamer=streamer,
263
  )
264
 
265
+ thread = threading.Thread(target=run, daemon=True)
266
  thread.start()
267
 
268
  text = ""
 
270
  text += chunk
271
  yield text
272
 
273
+
274
  # =========================
275
  # FEEDBACK HANDLER (Native st.feedback)
276
  # =========================
 
288
  score=score,
289
  comment="thumbs_up" if score else "thumbs_down",
290
  )
291
+ except Exception:
292
  pass
293
 
294
+
295
  # =========================
296
  # RENDER CHAT HISTORY
297
  # =========================
298
  for i, msg in enumerate(st.session_state.messages):
299
  with st.chat_message(msg["role"]):
300
+ if msg["role"] == "user":
301
+ st.markdown(msg["content"])
302
+ continue
303
+
304
+ # Assistant messages: collapsed-by-default expander = "whole message response be the dropdown"
305
+ with st.expander("🫖 Assistant response (click to expand)", expanded=False):
306
+ st.markdown(msg["content"])
307
 
 
 
308
  st.caption(
309
  f"🔎 {msg['search_time']:.2f}s search • "
310
  f"🧠 {msg['gen_time']:.2f}s generate • "
311
  f"⚡ {msg['tps']:.1f} tok/s • "
312
+ f"🧾 in={msg['input_tokens']} • out={msg['output_tokens']}"
313
  )
314
 
315
+ # Show EXACT prompt passed into the model (and the parts)
316
+ st.markdown("---")
317
+ st.markdown("#### Prompt & Inputs (exactly what was passed to the model)")
318
+ st.markdown("**System prompt:**")
319
+ st.code(msg.get("system_prompt", ""), language="text")
320
+ st.markdown("**Question:**")
321
+ st.code(msg.get("question", ""), language="text")
322
+ st.markdown("**Full model input (prompt):**")
323
+ st.code(msg.get("prompt", ""), language="text")
324
+
325
+ # Native thumbs feedback (outside expander so it's still reachable)
326
+ key = f"feedback_{i}"
327
+ st.session_state.setdefault(key, msg.get("feedback"))
328
+ st.feedback(
329
+ "thumbs",
330
+ key=key,
331
+ disabled=msg.get("feedback") is not None,
332
+ on_change=handle_feedback,
333
+ args=(i,),
334
+ )
335
+
336
 
337
  # =========================
338
+ # USER INPUT
339
  # =========================
340
  query = st.chat_input("Ask a question...")
341
 
342
  if query:
343
+ # show user message first
344
  st.session_state.messages.append({"role": "user", "content": query})
345
  st.rerun()
346
 
347
+
348
  # =========================
349
+ # GENERATE AFTER RERUN
350
  # =========================
351
  if (
352
  st.session_state.messages
353
  and st.session_state.messages[-1]["role"] == "user"
354
  and st.session_state.pending_response is None
355
  ):
356
+ question = st.session_state.messages[-1]["content"]
357
 
358
+ # --- Web Search (always on) ---
359
+ web_ctx, search_time = web_search_snippets(question)
 
 
 
360
 
361
  # --- Strict Order Context ---
362
  final_context = truncate_context(
363
+ web_ctx=web_ctx,
364
+ local_ctx=local_context,
365
+ system=system_prompt,
366
+ question=question,
367
  )
368
 
369
+ # IMPORTANT: prompt is EXACTLY what we pass to the model
370
+ prompt = f"{final_context}\n{system_prompt}\n{question}\n".strip() + "\n"
371
+
372
+ # Token accounting (split input vs output)
373
+ input_tokens = len(tokenizer.encode(prompt))
374
 
375
  # LangSmith run
376
  run_id = None
 
383
  "web_content": web_ctx,
384
  "local_context": local_context,
385
  "system_prompt": system_prompt,
386
+ "question": question,
387
  "final_context": final_context,
388
+ "prompt": prompt,
389
  },
390
  )
391
  run_id = run.id
392
+ except Exception:
393
  pass
394
 
395
+ # --- Stream UI: assistant response itself is a dropdown ---
396
  with st.chat_message("assistant"):
397
+ with st.expander("🫖 Assistant response (click to expand)", expanded=False):
398
+ placeholder = st.empty()
399
+ start = time.perf_counter()
400
+ final_text = ""
401
+
402
+ for partial in stream_generate(prompt):
403
+ final_text = partial
404
+ placeholder.markdown(final_text)
405
+
406
+ gen_time = time.perf_counter() - start
407
+ output_tokens = len(tokenizer.encode(final_text))
408
+ tps = output_tokens / gen_time if gen_time > 0 else 0.0
409
+
410
+ st.caption(
411
+ f"🔎 {search_time:.2f}s search • "
412
+ f"🧠 {gen_time:.2f}s generate • "
413
+ f" {tps:.1f} tok/s • "
414
+ f"🧾 in={input_tokens} • out={output_tokens}"
415
+ )
416
+
417
+ st.markdown("---")
418
+ st.markdown("#### Prompt & Inputs (exactly what was passed to the model)")
419
+ st.markdown("**System prompt:**")
420
+ st.code(system_prompt, language="text")
421
+ st.markdown("**Question:**")
422
+ st.code(question, language="text")
423
+ st.markdown("**Full model input (prompt):**")
424
+ st.code(prompt, language="text")
425
 
426
  if ls_client and run_id:
427
  try:
428
  ls_client.update_run(run_id, outputs={"answer": final_text})
429
+ except Exception:
430
  pass
431
 
432
  st.session_state.messages.append(
433
  {
434
  "role": "assistant",
435
  "content": final_text,
436
+ "system_prompt": system_prompt,
437
+ "question": question,
438
  "web_context": web_ctx,
439
  "local_context": local_context,
440
  "final_context": final_context,
441
+ "prompt": prompt,
442
  "search_time": search_time,
443
  "gen_time": gen_time,
444
+ "input_tokens": input_tokens,
445
+ "output_tokens": output_tokens,
446
  "tps": tps,
447
  "run_id": run_id,
448
  "feedback": None,