zakerytclarke commited on
Commit
c73781a
·
verified ·
1 Parent(s): ad255be

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +96 -107
src/streamlit_app.py CHANGED
@@ -43,11 +43,11 @@ st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centere
43
  # =========================
44
  @st.cache_resource
45
  def load_model():
46
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
47
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
48
- device = "cuda" if torch.cuda.is_available() else "cpu"
49
- model.to(device).eval()
50
- return tokenizer, model, device
51
 
52
 
53
  tokenizer, model, device = load_model()
@@ -72,30 +72,27 @@ ls_client = get_langsmith()
72
  # =========================
73
  if "messages" not in st.session_state:
74
  st.session_state.messages = []
75
- if "pending_response" not in st.session_state:
76
- st.session_state.pending_response = None
77
 
78
 
79
  # =========================
80
- # HEADER (SAFE IMAGE CALL)
81
  # =========================
82
  col1, col2 = st.columns([1, 6])
83
  with col1:
84
- # IMPORTANT: use_column_width=True (works on your Streamlit version)
85
  st.image(LOGO_URL, use_column_width=True)
86
  with col2:
87
  st.markdown("## TeapotAI Chat")
88
- st.caption("Fast grounded answers with clean web context")
89
 
90
 
91
  # =========================
92
- # SIDEBAR SETTINGS
93
  # =========================
94
  with st.sidebar:
95
  st.markdown("### Settings")
96
 
97
  system_prompt = st.text_area(
98
- "System Prompt",
99
  value=(
100
  "You are Teapot, an open-source AI assistant optimized for low-end devices, "
101
  "providing short, accurate responses without hallucinating while excelling at "
@@ -103,21 +100,20 @@ with st.sidebar:
103
  "If the context does not answer the question, reply exactly: "
104
  "'I am sorry but I don't have any information on that'."
105
  ),
106
- height=180,
107
  )
108
 
109
- st.markdown("### Local Context (Text)")
110
  local_context_text = st.text_area(
111
- "Paste additional context (optional)",
112
- height=140,
113
- placeholder="This will be appended after web content...",
114
  )
115
 
116
- st.markdown("### Local Context (File Upload)")
117
  uploaded_files = st.file_uploader(
118
- "Upload files (pdf, txt, csv, md, json, etc.)",
119
  accept_multiple_files=True,
120
  type=None,
 
121
  )
122
 
123
 
@@ -137,20 +133,18 @@ def parse_file_to_text(file) -> str:
137
  name = (file.name or "").lower()
138
  raw = file.getvalue()
139
 
140
- # PDF
141
  if name.endswith(".pdf") and PdfReader:
142
  try:
143
  reader = PdfReader(io.BytesIO(raw))
144
- pages = []
145
- for page in reader.pages:
146
- txt = page.extract_text() or ""
147
- if txt.strip():
148
- pages.append(txt.strip())
149
- return "\n\n".join(pages)
150
  except Exception as e:
151
  return f"[PDF parse error: {e}]"
152
 
153
- # CSV
154
  if name.endswith(".csv") and pd:
155
  try:
156
  df = pd.read_csv(io.BytesIO(raw))
@@ -158,13 +152,11 @@ def parse_file_to_text(file) -> str:
158
  except Exception:
159
  return safe_decode(raw)
160
 
161
- # TXT / MD / JSON / fallback
162
- return safe_decode(raw)
163
 
164
 
165
  def build_local_context(text_block: str, files) -> str:
166
  chunks = []
167
-
168
  if text_block and text_block.strip():
169
  chunks.append(text_block.strip())
170
 
@@ -172,7 +164,7 @@ def build_local_context(text_block: str, files) -> str:
172
  for f in files:
173
  parsed = parse_file_to_text(f)
174
  if parsed and parsed.strip():
175
- chunks.append(f"\n\n--- FILE: {f.name} ---\n{parsed.strip()}")
176
 
177
  return "\n\n".join(chunks).strip()
178
 
@@ -206,8 +198,7 @@ def web_search_snippets(query: str):
206
 
207
  snippets = []
208
  for item in data.get("web", {}).get("results", [])[:TOP_K_SEARCH]:
209
- desc = item.get("description", "")
210
- desc = desc.replace("<strong>", "").replace("</strong>", "").strip()
211
  if desc:
212
  snippets.append(desc)
213
 
@@ -217,22 +208,24 @@ def web_search_snippets(query: str):
217
  # =========================
218
  # CONTEXT TRUNCATION (TAIL)
219
  # =========================
220
- def truncate_context(web_ctx, local_ctx, system, question):
221
- ordered_context = f"{web_ctx}\n\n{local_ctx}".strip()
222
 
223
  base = f"\n{system}\n{question}\n"
224
  base_tokens = tokenizer.encode(base)
225
  budget = MAX_INPUT_TOKENS - len(base_tokens)
226
-
227
  if budget <= 0:
228
  return ""
229
 
230
- ctx_tokens = tokenizer.encode(ordered_context) if ordered_context else []
231
  if len(ctx_tokens) <= budget:
232
- return ordered_context
 
 
 
233
 
234
- truncated = ctx_tokens[-budget:]
235
- return tokenizer.decode(truncated, skip_special_tokens=True)
236
 
237
 
238
  # =========================
@@ -251,23 +244,22 @@ def stream_generate(prompt: str):
251
  streamer=streamer,
252
  )
253
 
254
- thread = threading.Thread(target=run)
255
- thread.start()
256
 
257
- text = ""
258
  for chunk in streamer:
259
- text += chunk
260
- yield text
261
 
262
 
263
  # =========================
264
  # FEEDBACK HANDLER
265
  # =========================
266
  def handle_feedback(idx: int):
267
- val = st.session_state[f"feedback_{idx}"]
268
- msg = st.session_state.messages[idx]
269
- msg["feedback"] = val
270
 
 
271
  if ls_client and msg.get("run_id"):
272
  score = 1 if val == "👍" else 0
273
  try:
@@ -282,41 +274,41 @@ def handle_feedback(idx: int):
282
 
283
 
284
  # =========================
285
- # RENDER CHAT HISTORY
286
  # =========================
287
  for i, msg in enumerate(st.session_state.messages):
288
  with st.chat_message(msg["role"]):
289
- if msg["role"] == "user":
290
- st.markdown(msg["content"])
291
- continue
292
-
293
- # Entire response as collapsed dropdown (less visible inspector)
294
- with st.expander("🫖 Assistant response (click to expand)", expanded=False):
295
- st.markdown(msg["content"])
296
 
 
 
297
  st.caption(
298
- f"🔎 {msg['search_time']:.2f}s search • "
299
- f"🧠 {msg['gen_time']:.2f}s generate • "
300
- f"⚡ {msg['tps']:.1f} tok/s • "
301
- f"🧾 in={msg['input_tokens']} • out={msg['output_tokens']}"
302
  )
303
 
304
- st.markdown("### Exact Model Input (Prompt)")
305
- st.code(msg["prompt"], language="text")
306
-
307
- key = f"feedback_{i}"
308
- st.session_state.setdefault(key, msg.get("feedback"))
309
- st.feedback(
310
- "thumbs",
311
- key=key,
312
- disabled=msg.get("feedback") is not None,
313
- on_change=handle_feedback,
314
- args=(i,),
315
- )
 
 
 
 
 
 
316
 
317
 
318
  # =========================
319
- # USER INPUT
320
  # =========================
321
  query = st.chat_input("Ask a question...")
322
 
@@ -326,29 +318,21 @@ if query:
326
 
327
 
328
  # =========================
329
- # GENERATE AFTER RERUN
330
  # =========================
331
- if (
332
- st.session_state.messages
333
- and st.session_state.messages[-1]["role"] == "user"
334
- and st.session_state.pending_response is None
335
- ):
336
  question = st.session_state.messages[-1]["content"]
337
 
338
- # Always do web search
339
  web_ctx, search_time = web_search_snippets(question)
340
 
341
- final_context = truncate_context(
342
- web_ctx,
343
- local_context,
344
- system_prompt,
345
- question,
346
- )
347
 
348
- # EXACT prompt passed to model
349
  prompt = f"{final_context}\n{system_prompt}\n{question}\n"
350
 
351
- input_tokens = len(tokenizer.encode(prompt))
352
 
353
  # LangSmith run (optional)
354
  run_id = None
@@ -358,36 +342,40 @@ if (
358
  name="teapot_chat",
359
  run_type="llm",
360
  inputs={
361
- "prompt": prompt,
362
  "question": question,
 
363
  },
364
  )
365
  run_id = run.id
366
  except Exception:
367
  pass
368
 
 
369
  with st.chat_message("assistant"):
370
- with st.expander("🫖 Assistant response (click to expand)", expanded=False):
371
- placeholder = st.empty()
372
- start = time.perf_counter()
373
- final_text = ""
374
 
375
- for partial in stream_generate(prompt):
376
- final_text = partial
377
- placeholder.markdown(final_text)
378
 
379
- gen_time = time.perf_counter() - start
380
- output_tokens = len(tokenizer.encode(final_text))
381
- tps = output_tokens / gen_time if gen_time > 0 else 0.0
382
 
383
- st.caption(
384
- f"🔎 {search_time:.2f}s search • "
385
- f"🧠 {gen_time:.2f}s generate • "
386
- f"⚡ {tps:.1f} tok/s • "
387
- f"🧾 in={input_tokens} • out={output_tokens}"
388
- )
389
 
390
- st.markdown("### Exact Model Input (Prompt)")
 
 
 
 
 
391
  st.code(prompt, language="text")
392
 
393
  if ls_client and run_id:
@@ -400,6 +388,8 @@ if (
400
  {
401
  "role": "assistant",
402
  "content": final_text,
 
 
403
  "prompt": prompt,
404
  "search_time": search_time,
405
  "gen_time": gen_time,
@@ -411,5 +401,4 @@ if (
411
  }
412
  )
413
 
414
- st.session_state.pending_response = None
415
  st.rerun()
 
43
  # =========================
44
  @st.cache_resource
45
  def load_model():
46
+ tok = AutoTokenizer.from_pretrained(MODEL_NAME)
47
+ mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
48
+ dev = "cuda" if torch.cuda.is_available() else "cpu"
49
+ mdl.to(dev).eval()
50
+ return tok, mdl, dev
51
 
52
 
53
  tokenizer, model, device = load_model()
 
72
  # =========================
73
  if "messages" not in st.session_state:
74
  st.session_state.messages = []
 
 
75
 
76
 
77
  # =========================
78
+ # HEADER
79
  # =========================
80
  col1, col2 = st.columns([1, 6])
81
  with col1:
 
82
  st.image(LOGO_URL, use_column_width=True)
83
  with col2:
84
  st.markdown("## TeapotAI Chat")
85
+ st.caption("Grounded answers with web context")
86
 
87
 
88
  # =========================
89
+ # SIDEBAR
90
  # =========================
91
  with st.sidebar:
92
  st.markdown("### Settings")
93
 
94
  system_prompt = st.text_area(
95
+ "System prompt",
96
  value=(
97
  "You are Teapot, an open-source AI assistant optimized for low-end devices, "
98
  "providing short, accurate responses without hallucinating while excelling at "
 
100
  "If the context does not answer the question, reply exactly: "
101
  "'I am sorry but I don't have any information on that'."
102
  ),
103
+ height=160,
104
  )
105
 
 
106
  local_context_text = st.text_area(
107
+ "Local context (optional)",
108
+ height=120,
109
+ placeholder="Extra context to append after web snippets…",
110
  )
111
 
 
112
  uploaded_files = st.file_uploader(
113
+ "Upload context files",
114
  accept_multiple_files=True,
115
  type=None,
116
+ help="PDF, TXT, CSV, MD, JSON, etc.",
117
  )
118
 
119
 
 
133
  name = (file.name or "").lower()
134
  raw = file.getvalue()
135
 
 
136
  if name.endswith(".pdf") and PdfReader:
137
  try:
138
  reader = PdfReader(io.BytesIO(raw))
139
+ parts = []
140
+ for p in reader.pages:
141
+ t = (p.extract_text() or "").strip()
142
+ if t:
143
+ parts.append(t)
144
+ return "\n\n".join(parts).strip()
145
  except Exception as e:
146
  return f"[PDF parse error: {e}]"
147
 
 
148
  if name.endswith(".csv") and pd:
149
  try:
150
  df = pd.read_csv(io.BytesIO(raw))
 
152
  except Exception:
153
  return safe_decode(raw)
154
 
155
+ return safe_decode(raw).strip()
 
156
 
157
 
158
  def build_local_context(text_block: str, files) -> str:
159
  chunks = []
 
160
  if text_block and text_block.strip():
161
  chunks.append(text_block.strip())
162
 
 
164
  for f in files:
165
  parsed = parse_file_to_text(f)
166
  if parsed and parsed.strip():
167
+ chunks.append(f"\n\n--- {f.name} ---\n{parsed.strip()}")
168
 
169
  return "\n\n".join(chunks).strip()
170
 
 
198
 
199
  snippets = []
200
  for item in data.get("web", {}).get("results", [])[:TOP_K_SEARCH]:
201
+ desc = (item.get("description") or "").replace("<strong>", "").replace("</strong>", "").strip()
 
202
  if desc:
203
  snippets.append(desc)
204
 
 
208
  # =========================
209
  # CONTEXT TRUNCATION (TAIL)
210
  # =========================
211
+ def truncate_context(web_ctx: str, local_ctx: str, system: str, question: str) -> str:
212
+ ctx = f"{web_ctx}\n\n{local_ctx}".strip()
213
 
214
  base = f"\n{system}\n{question}\n"
215
  base_tokens = tokenizer.encode(base)
216
  budget = MAX_INPUT_TOKENS - len(base_tokens)
 
217
  if budget <= 0:
218
  return ""
219
 
220
+ ctx_tokens = tokenizer.encode(ctx) if ctx else []
221
  if len(ctx_tokens) <= budget:
222
+ return ctx
223
+
224
+ return tokenizer.decode(ctx_tokens[-budget:], skip_special_tokens=True)
225
+
226
 
227
+ def count_tokens(text: str) -> int:
228
+ return len(tokenizer.encode(text)) if text else 0
229
 
230
 
231
  # =========================
 
244
  streamer=streamer,
245
  )
246
 
247
+ threading.Thread(target=run, daemon=True).start()
 
248
 
249
+ acc = ""
250
  for chunk in streamer:
251
+ acc += chunk
252
+ yield acc
253
 
254
 
255
  # =========================
256
  # FEEDBACK HANDLER
257
  # =========================
258
  def handle_feedback(idx: int):
259
+ val = st.session_state.get(f"fb_{idx}")
260
+ st.session_state.messages[idx]["feedback"] = val
 
261
 
262
+ msg = st.session_state.messages[idx]
263
  if ls_client and msg.get("run_id"):
264
  score = 1 if val == "👍" else 0
265
  try:
 
274
 
275
 
276
  # =========================
277
+ # RENDER HISTORY
278
  # =========================
279
  for i, msg in enumerate(st.session_state.messages):
280
  with st.chat_message(msg["role"]):
281
+ st.markdown(msg["content"])
 
 
 
 
 
 
282
 
283
+ if msg["role"] == "assistant":
284
+ # Light, normal-looking stats
285
  st.caption(
286
+ f"{msg['search_time']:.2f}s search • {msg['gen_time']:.2f}s gen • "
287
+ f"{msg['tps']:.1f} tok/s in {msg['input_tokens']} out {msg['output_tokens']}"
 
 
288
  )
289
 
290
+ # Small inspector (collapsed)
291
+ with st.expander("Inspect context"):
292
+ st.markdown("**System**")
293
+ st.code(msg.get("system_prompt", ""), language="text")
294
+ st.markdown("**Question**")
295
+ st.code(msg.get("question", ""), language="text")
296
+ st.markdown("**Prompt (sent to model)**")
297
+ st.code(msg.get("prompt", ""), language="text")
298
+
299
+ key = f"fb_{i}"
300
+ st.session_state.setdefault(key, msg.get("feedback"))
301
+ st.feedback(
302
+ "thumbs",
303
+ key=key,
304
+ disabled=msg.get("feedback") is not None,
305
+ on_change=handle_feedback,
306
+ args=(i,),
307
+ )
308
 
309
 
310
  # =========================
311
+ # INPUT
312
  # =========================
313
  query = st.chat_input("Ask a question...")
314
 
 
318
 
319
 
320
  # =========================
321
+ # GENERATE
322
  # =========================
323
+ if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
 
 
 
 
324
  question = st.session_state.messages[-1]["content"]
325
 
326
+ # web search
327
  web_ctx, search_time = web_search_snippets(question)
328
 
329
+ # truncate final context
330
+ final_context = truncate_context(web_ctx, local_context, system_prompt, question)
 
 
 
 
331
 
332
+ # prompt sent to model
333
  prompt = f"{final_context}\n{system_prompt}\n{question}\n"
334
 
335
+ input_tokens = count_tokens(prompt)
336
 
337
  # LangSmith run (optional)
338
  run_id = None
 
342
  name="teapot_chat",
343
  run_type="llm",
344
  inputs={
345
+ "system_prompt": system_prompt,
346
  "question": question,
347
+ "prompt": prompt,
348
  },
349
  )
350
  run_id = run.id
351
  except Exception:
352
  pass
353
 
354
+ # stream normally in chat
355
  with st.chat_message("assistant"):
356
+ placeholder = st.empty()
357
+ start = time.perf_counter()
358
+ final_text = ""
 
359
 
360
+ for partial in stream_generate(prompt):
361
+ final_text = partial
362
+ placeholder.markdown(final_text)
363
 
364
+ gen_time = time.perf_counter() - start
365
+ output_tokens = count_tokens(final_text)
366
+ tps = output_tokens / gen_time if gen_time > 0 else 0.0
367
 
368
+ st.caption(
369
+ f"{search_time:.2f}s search • {gen_time:.2f}s gen • "
370
+ f"{tps:.1f} tok/s in {input_tokens} out {output_tokens}"
371
+ )
 
 
372
 
373
+ with st.expander("Inspect context"):
374
+ st.markdown("**System**")
375
+ st.code(system_prompt, language="text")
376
+ st.markdown("**Question**")
377
+ st.code(question, language="text")
378
+ st.markdown("**Prompt (sent to model)**")
379
  st.code(prompt, language="text")
380
 
381
  if ls_client and run_id:
 
388
  {
389
  "role": "assistant",
390
  "content": final_text,
391
+ "system_prompt": system_prompt,
392
+ "question": question,
393
  "prompt": prompt,
394
  "search_time": search_time,
395
  "gen_time": gen_time,
 
401
  }
402
  )
403
 
 
404
  st.rerun()