zakerytclarke committed on
Commit
3bbf8f3
·
verified ·
1 Parent(s): 0f3b8dd

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +110 -60
src/streamlit_app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import time
3
- import threading
4
  import requests
5
 
6
  import streamlit as st
@@ -50,7 +49,6 @@ tokenizer, model, device = load_model()
50
  # =========================
51
  @st.cache_resource
52
  def get_langsmith():
53
- key = os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY") or os.getenv("LANGCHAIN_TRACING_V2")
54
  if (os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")) and LangSmithClient:
55
  return LangSmithClient()
56
  return None
@@ -59,6 +57,46 @@ def get_langsmith():
59
  ls_client = get_langsmith()
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # =========================
63
  # SESSION STATE
64
  # =========================
@@ -66,18 +104,27 @@ if "messages" not in st.session_state:
66
  st.session_state.messages = []
67
  if "needs_answer" not in st.session_state:
68
  st.session_state.needs_answer = False
 
 
 
 
 
 
 
69
 
70
 
71
  # =========================
72
  # HEADER (prevent logo flash)
73
- # Use a fixed pixel width to avoid layout shift / big flash.
74
  # =========================
75
  col1, col2 = st.columns([1, 7], vertical_alignment="center")
76
  with col1:
77
- st.image(LOGO_URL, width=56) # fixed width prevents "flash huge"
78
  with col2:
79
  st.markdown("## TeapotAI Chat")
80
- st.caption("Teapot is a 77 million parameter LLM designed to generate ")
 
 
 
81
 
82
 
83
  # =========================
@@ -88,13 +135,7 @@ with st.sidebar:
88
 
89
  system_prompt = st.text_area(
90
  "System prompt",
91
- value=(
92
- "You are Teapot, an open-source AI assistant optimized for running on low-end cpu devices, "
93
- "providing short, accurate responses without hallucinating while excelling at "
94
- "information extraction and text summarization. "
95
- "If the context does not answer the question, reply exactly: "
96
- "'I am sorry but I don't have any information on that'."
97
- ),
98
  height=160,
99
  )
100
 
@@ -104,6 +145,21 @@ with st.sidebar:
104
  placeholder="Extra context appended after web snippets…",
105
  )
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  # =========================
109
  # WEB SEARCH (ALWAYS ON)
@@ -162,9 +218,9 @@ def count_tokens(text: str) -> int:
162
 
163
  # =========================
164
  # LANGSMITH-TRACED ANSWER FUNCTION
165
- # (signature exactly: context, system_prompt, question -> answer)
166
  # =========================
167
  if traceable:
 
168
  @traceable(name="teapot_answer")
169
  def traced_answer(context: str, system_prompt: str, question: str) -> str:
170
  prompt = f"{context}\n{system_prompt}\n{question}\n"
@@ -176,9 +232,10 @@ if traceable:
176
  do_sample=False,
177
  num_beams=1,
178
  )
179
- text = tokenizer.decode(out[0], skip_special_tokens=True)
180
- return text
181
  else:
 
182
  def traced_answer(context: str, system_prompt: str, question: str) -> str:
183
  prompt = f"{context}\n{system_prompt}\n{question}\n"
184
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
@@ -193,7 +250,6 @@ else:
193
 
194
 
195
  def get_trace_id_if_available() -> str | None:
196
- # Works when running inside a @traceable function call
197
  if not get_current_run_tree:
198
  return None
199
  try:
@@ -217,7 +273,6 @@ def handle_feedback(idx: int):
217
  if ls_client and trace_id:
218
  score = 1 if val == "👍" else 0
219
  try:
220
- # LangSmith SDK supports trace_id= for feedback association
221
  ls_client.create_feedback(
222
  trace_id=trace_id,
223
  key="thumb_rating",
@@ -230,6 +285,8 @@ def handle_feedback(idx: int):
230
 
231
  # =========================
232
  # RENDER HISTORY
 
 
233
  # =========================
234
  for i, msg in enumerate(st.session_state.messages):
235
  with st.chat_message(msg["role"]):
@@ -237,17 +294,11 @@ for i, msg in enumerate(st.session_state.messages):
237
  st.markdown(msg["content"])
238
  continue
239
 
240
- # Assistant
241
-
242
-
243
- # Info icon popover with full prompt/context
244
- # (st.popover is stable in your Streamlit range; no rerun on open/close)
245
- c1, c2 = st.columns([1, 12], vertical_alignment="center")
246
- with c1:
247
- st.markdown(msg["content"])
248
-
249
- with c2:
250
-
251
  key = f"fb_{i}"
252
  st.session_state.setdefault(key, msg.get("feedback"))
253
  st.feedback(
@@ -258,18 +309,10 @@ for i, msg in enumerate(st.session_state.messages):
258
  args=(i,),
259
  )
260
 
261
-
 
262
 
263
- c3, c4 = st.columns([1, 12], vertical_alignment="center")
264
- with c3:
265
- st.caption(
266
- f"🔎 {msg['search_time']:.2f}s (search)"
267
- f"🧠 {msg['gen_time']:.2f}s (generation) "
268
- f"⚡ {msg['tps']:.1f} tok/s "
269
- f"🧾 {msg['input_tokens']} input tokens • {msg['output_tokens']} output tokens"
270
- )
271
-
272
- with c4:
273
  with st.popover("ℹ️", help="Inspect"):
274
  st.markdown("**Context**")
275
  st.code(msg.get("context", ""), language="text")
@@ -277,11 +320,16 @@ for i, msg in enumerate(st.session_state.messages):
277
  st.code(msg.get("system_prompt", ""), language="text")
278
  st.markdown("**Question**")
279
  st.code(msg.get("question", ""), language="text")
280
-
 
281
 
282
-
283
-
284
-
 
 
 
 
285
 
286
 
287
  # =========================
@@ -313,32 +361,40 @@ if (
313
  prompt = f"{context}\n{system_prompt}\n{question}\n"
314
  input_tokens = count_tokens(prompt)
315
 
316
- # Run traced answer (returns answer; trace_id obtained from current run tree)
317
  with st.chat_message("assistant"):
318
- placeholder = st.empty()
 
 
 
 
 
319
 
320
  start = time.perf_counter()
321
-
322
- # Generate full answer first (traced), then "stream" it to UI quickly.
323
- # This keeps LangSmith tracing simple/reliable while still giving a streaming UX.
324
  answer = traced_answer(context, system_prompt, question)
325
  trace_id = get_trace_id_if_available()
326
 
327
- # Typewriter-ish stream (fast, looks normal)
328
  buf = ""
329
  for ch in answer:
330
  buf += ch
331
  placeholder.markdown(buf)
332
- # small delay; tune if you want faster/slower
333
  time.sleep(0.002)
334
 
335
  gen_time = time.perf_counter() - start
336
  output_tokens = count_tokens(answer)
337
  tps = output_tokens / gen_time if gen_time > 0 else 0.0
338
 
339
- # Metrics + info popover for this live message
340
- c1, c2 = st.columns([1, 12], vertical_alignment="center")
341
- with c1:
 
 
 
 
 
 
 
342
  with st.popover("ℹ️", help="Inspect"):
343
  st.markdown("**Context**")
344
  st.code(context, language="text")
@@ -348,13 +404,7 @@ if (
348
  st.code(question, language="text")
349
  st.markdown("**Prompt**")
350
  st.code(prompt, language="text")
351
- with c2:
352
- st.caption(
353
- f"🔎 {search_time:.2f}s (search) "
354
- f"🧠 {gen_time:.2f}s (generation) "
355
- f"⚡ {tps:.1f} tok/s "
356
- f"🧾 {input_tokens} input tokens • {output_tokens} output tokens"
357
- )
358
 
359
  # Persist assistant message
360
  st.session_state.messages.append(
 
1
  import os
2
  import time
 
3
  import requests
4
 
5
  import streamlit as st
 
49
  # =========================
50
  @st.cache_resource
51
  def get_langsmith():
 
52
  if (os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")) and LangSmithClient:
53
  return LangSmithClient()
54
  return None
 
57
  ls_client = get_langsmith()
58
 
59
 
60
+ # =========================
61
+ # SAMPLE SEED (with full debug fields)
62
+ # =========================
63
+ SAMPLE_QUESTION = "who are you"
64
+
65
+ DEFAULT_SYSTEM_PROMPT = (
66
+ "You are Teapot, an open-source AI assistant optimized for running on low-end cpu devices, "
67
+ "providing short, accurate responses without hallucinating while excelling at "
68
+ "information extraction and text summarization. "
69
+ "If the context does not answer the question, reply exactly: "
70
+ "'I am sorry but I don't have any information on that'."
71
+ )
72
+
73
+ SAMPLE_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
74
+
75
+ SAMPLE_CONTEXT = (
76
+ "Teapot is an open-source AI assistant optimized for running on low-end cpu devices."
77
+ )
78
+
79
+ SAMPLE_ANSWER = "I am Teapot, an open-source AI assistant optimized for running on low-end cpu devices."
80
+ SAMPLE_PROMPT = f"{SAMPLE_CONTEXT}\n{SAMPLE_SYSTEM_PROMPT}\n{SAMPLE_QUESTION}\n"
81
+
82
+ SAMPLE_USER_MSG = {"role": "user", "content": SAMPLE_QUESTION}
83
+ SAMPLE_ASSISTANT_MSG = {
84
+ "role": "assistant",
85
+ "content": SAMPLE_ANSWER,
86
+ "context": SAMPLE_CONTEXT,
87
+ "system_prompt": SAMPLE_SYSTEM_PROMPT,
88
+ "question": SAMPLE_QUESTION,
89
+ "prompt": SAMPLE_PROMPT,
90
+ "search_time": 0.37,
91
+ "gen_time": 0.67,
92
+ "input_tokens": 245,
93
+ "output_tokens": 24,
94
+ "tps": 35.9,
95
+ "trace_id": None,
96
+ "feedback": None,
97
+ }
98
+
99
+
100
  # =========================
101
  # SESSION STATE
102
  # =========================
 
104
  st.session_state.messages = []
105
  if "needs_answer" not in st.session_state:
106
  st.session_state.needs_answer = False
107
+ if "seeded" not in st.session_state:
108
+ st.session_state.seeded = False
109
+
110
+ # Seed exactly once on first load
111
+ if (not st.session_state.seeded) and (len(st.session_state.messages) == 0):
112
+ st.session_state.messages = [SAMPLE_USER_MSG, SAMPLE_ASSISTANT_MSG]
113
+ st.session_state.seeded = True
114
 
115
 
116
  # =========================
117
  # HEADER (prevent logo flash)
 
118
  # =========================
119
  col1, col2 = st.columns([1, 7], vertical_alignment="center")
120
  with col1:
121
+ st.image(LOGO_URL, width=56)
122
  with col2:
123
  st.markdown("## TeapotAI Chat")
124
+ st.caption(
125
+ "Teapot is a 77M-parameter LLM optimized for fast CPU inference that only generates answers "
126
+ "from the provided context to minimize hallucinations."
127
+ )
128
 
129
 
130
  # =========================
 
135
 
136
  system_prompt = st.text_area(
137
  "System prompt",
138
+ value=DEFAULT_SYSTEM_PROMPT,
 
 
 
 
 
 
139
  height=160,
140
  )
141
 
 
145
  placeholder="Extra context appended after web snippets…",
146
  )
147
 
148
+ st.markdown("### Conversation")
149
+ c1, c2 = st.columns(2)
150
+ with c1:
151
+ if st.button("Load sample"):
152
+ st.session_state.messages = [SAMPLE_USER_MSG, SAMPLE_ASSISTANT_MSG]
153
+ st.session_state.needs_answer = False
154
+ st.session_state.seeded = True
155
+ st.rerun()
156
+ with c2:
157
+ if st.button("Clear chat"):
158
+ st.session_state.messages = []
159
+ st.session_state.needs_answer = False
160
+ st.session_state.seeded = True
161
+ st.rerun()
162
+
163
 
164
  # =========================
165
  # WEB SEARCH (ALWAYS ON)
 
218
 
219
  # =========================
220
  # LANGSMITH-TRACED ANSWER FUNCTION
 
221
  # =========================
222
  if traceable:
223
+
224
  @traceable(name="teapot_answer")
225
  def traced_answer(context: str, system_prompt: str, question: str) -> str:
226
  prompt = f"{context}\n{system_prompt}\n{question}\n"
 
232
  do_sample=False,
233
  num_beams=1,
234
  )
235
+ return tokenizer.decode(out[0], skip_special_tokens=True)
236
+
237
  else:
238
+
239
  def traced_answer(context: str, system_prompt: str, question: str) -> str:
240
  prompt = f"{context}\n{system_prompt}\n{question}\n"
241
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
250
 
251
 
252
  def get_trace_id_if_available() -> str | None:
 
253
  if not get_current_run_tree:
254
  return None
255
  try:
 
273
  if ls_client and trace_id:
274
  score = 1 if val == "👍" else 0
275
  try:
 
276
  ls_client.create_feedback(
277
  trace_id=trace_id,
278
  key="thumb_rating",
 
285
 
286
  # =========================
287
  # RENDER HISTORY
288
+ # Row 1: message + feedback
289
+ # Row 2: inspect + debug metrics
290
  # =========================
291
  for i, msg in enumerate(st.session_state.messages):
292
  with st.chat_message(msg["role"]):
 
294
  st.markdown(msg["content"])
295
  continue
296
 
297
+ # Row 1
298
+ msg_col, fb_col = st.columns([14, 1], vertical_alignment="center")
299
+ with msg_col:
300
+ st.markdown(msg.get("content", ""))
301
+ with fb_col:
 
 
 
 
 
 
302
  key = f"fb_{i}"
303
  st.session_state.setdefault(key, msg.get("feedback"))
304
  st.feedback(
 
309
  args=(i,),
310
  )
311
 
312
+ # Row 2
313
+ inspect_col, metrics_col = st.columns([1, 12], vertical_alignment="center")
314
 
315
+ with inspect_col:
 
 
 
 
 
 
 
 
 
316
  with st.popover("ℹ️", help="Inspect"):
317
  st.markdown("**Context**")
318
  st.code(msg.get("context", ""), language="text")
 
320
  st.code(msg.get("system_prompt", ""), language="text")
321
  st.markdown("**Question**")
322
  st.code(msg.get("question", ""), language="text")
323
+ st.markdown("**Prompt**")
324
+ st.code(msg.get("prompt", ""), language="text")
325
 
326
+ with metrics_col:
327
+ st.caption(
328
+ f"🔎 {msg.get('search_time', 0.0):.2f}s (search) "
329
+ f"🧠 {msg.get('gen_time', 0.0):.2f}s (generation) "
330
+ f"⚡ {msg.get('tps', 0.0):.1f} tok/s "
331
+ f"🧾 {msg.get('input_tokens', 0)} input tokens • {msg.get('output_tokens', 0)} output tokens"
332
+ )
333
 
334
 
335
  # =========================
 
361
  prompt = f"{context}\n{system_prompt}\n{question}\n"
362
  input_tokens = count_tokens(prompt)
363
 
364
+ # Run traced answer
365
  with st.chat_message("assistant"):
366
+ # Row 1: message + feedback (feedback disabled until persisted)
367
+ msg_col, fb_col = st.columns([14, 1], vertical_alignment="center")
368
+ with msg_col:
369
+ placeholder = st.empty()
370
+ with fb_col:
371
+ st.feedback("thumbs", key="live_fb", disabled=True)
372
 
373
  start = time.perf_counter()
 
 
 
374
  answer = traced_answer(context, system_prompt, question)
375
  trace_id = get_trace_id_if_available()
376
 
377
+ # Stream into the message column
378
  buf = ""
379
  for ch in answer:
380
  buf += ch
381
  placeholder.markdown(buf)
 
382
  time.sleep(0.002)
383
 
384
  gen_time = time.perf_counter() - start
385
  output_tokens = count_tokens(answer)
386
  tps = output_tokens / gen_time if gen_time > 0 else 0.0
387
 
388
+ # Row 2: inspect + metrics
389
+ inspect_col, metrics_col = st.columns([12, 1], vertical_alignment="center")
390
+ with inspect_col:
391
+ st.caption(
392
+ f"🔎 {search_time:.2f}s (search) "
393
+ f"🧠 {gen_time:.2f}s (generation) "
394
+ f"⚡ {tps:.1f} tok/s "
395
+ f"🧾 {input_tokens} input tokens • {output_tokens} output tokens"
396
+ )
397
+ with metrics_col:
398
  with st.popover("ℹ️", help="Inspect"):
399
  st.markdown("**Context**")
400
  st.code(context, language="text")
 
404
  st.code(question, language="text")
405
  st.markdown("**Prompt**")
406
  st.code(prompt, language="text")
407
+
 
 
 
 
 
 
408
 
409
  # Persist assistant message
410
  st.session_state.messages.append(