zakerytclarke commited on
Commit
4c83d38
·
verified ·
1 Parent(s): 8f8133e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +334 -108
src/streamlit_app.py CHANGED
@@ -1,28 +1,26 @@
1
- # streamlit_app.py
2
  import os
3
  import re
4
  import time
5
- import warnings
6
- from typing import List, Dict, Optional
7
 
8
  import requests
9
  import streamlit as st
10
  import torch
11
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
12
 
13
- from teapotai import TeapotAI
 
 
 
 
14
 
15
 
16
  # -----------------------
17
- # Optional: quiet noisy warnings from deps
18
  # -----------------------
19
- warnings.filterwarnings("ignore", message="pkg_resources is deprecated as an API.*")
20
- warnings.filterwarnings("ignore", message='Field name "schema" in "TeapotTool" shadows.*')
21
-
22
 
23
- # -----------------------
24
- # Config
25
- # -----------------------
26
  TEAPOT_LOGO_GIF = "https://teapotai.com/assets/logo.gif"
27
 
28
  SUGGESTED_QUERIES = [
@@ -41,46 +39,54 @@ DEFAULT_SYSTEM_PROMPT = (
41
  "'I am sorry but I don't have any information on that'."
42
  )
43
 
44
- DEFAULT_DOCUMENTS = [
45
- """Teapot (Tiny Teapot) is an open-source small language model (~77 million parameters) fine-tuned on synthetic data and optimized to run locally on resource-constrained devices such as smartphones and CPUs. Teapot is trained to only answer using context from documents, reducing hallucinations. Teapot can perform a variety of tasks, including hallucination-resistant Question Answering (QnA), Retrieval-Augmented Generation (RAG), and JSON extraction. TeapotLLM is a fine tune of flan-t5-large that was trained on synthetic data generated by Deepseek v3 TeapotLLM can be hosted on low-power devices with as little as 2GB of CPU RAM such as a Raspberry Pi. Teapot is a model built by and for the community."""
46
- ]
47
-
48
- # Brave Search
49
- BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
50
  TOP_K = 3
51
  TIMEOUT_SECS = 15
52
 
53
-
54
- # -----------------------
55
- # Streamlit setup (no custom theming)
56
- # -----------------------
57
- st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")
58
 
59
 
60
  # -----------------------
61
- # Helpers
62
  # -----------------------
63
  def st_image_full_width(img_url: str):
64
- # Streamlit API varies across builds
65
  try:
66
  st.image(img_url, use_container_width=True)
67
  except TypeError:
68
  st.image(img_url, use_column_width=True)
69
 
70
 
71
- def get_brave_key() -> Optional[str]:
72
- # HF Spaces secrets are commonly env vars; support st.secrets too
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  return os.getenv("BRAVE_API_KEY") or (st.secrets.get("BRAVE_API_KEY") if hasattr(st, "secrets") else None)
74
 
75
 
76
- def brave_search_snippets(query: str, top_k: int = 3) -> List[Dict[str, str]]:
77
- key = get_brave_key()
78
  if not key:
79
- raise RuntimeError("Missing BRAVE_API_KEY (set as a Space secret / env var).")
80
 
81
  headers = {"Accept": "application/json", "X-Subscription-Token": key}
82
  params = {"q": query, "count": top_k}
83
- r = requests.get(BRAVE_ENDPOINT, headers=headers, params=params, timeout=TIMEOUT_SECS)
84
  r.raise_for_status()
85
  data = r.json()
86
 
@@ -97,6 +103,9 @@ def brave_search_snippets(query: str, top_k: int = 3) -> List[Dict[str, str]]:
97
 
98
 
99
  def format_context_from_results(results: List[Dict[str, str]]) -> str:
 
 
 
100
  if not results:
101
  return ""
102
 
@@ -106,83 +115,194 @@ def format_context_from_results(results: List[Dict[str, str]]) -> str:
106
  url = re.sub(r"\s+", " ", r.get("url", "")).strip()
107
  snippet = re.sub(r"\s+", " ", r.get("snippet", "")).strip()
108
 
109
- # per your requirement: strip <strong> tags
110
  title = title.replace("<strong>", "").replace("</strong>", "")
111
  snippet = snippet.replace("<strong>", "").replace("</strong>", "")
112
 
113
  blocks.append(f"[{i}] {title}\nURL: {url}\nSnippet: {snippet}")
114
-
115
  return "\n\n".join(blocks)
116
 
117
 
118
  def count_tokens(tokenizer: AutoTokenizer, text: str) -> int:
119
  if not text:
120
  return 0
121
- try:
122
- return len(tokenizer.encode(text))
123
- except Exception:
124
- return 0
125
 
126
 
127
- def render_sources_popover(sources: List[Dict[str, str]], context: str):
 
 
 
 
 
 
 
 
 
 
 
128
  """
129
- Renders ℹ️ popover if available; otherwise uses expander.
 
130
  """
131
- def _body():
132
- st.markdown("**Sources**")
133
- if sources:
134
- for j, s in enumerate(sources, start=1):
135
- title = (s.get("title") or "").strip() or f"Result {j}"
136
- url = (s.get("url") or "").strip()
137
- snippet = (s.get("snippet") or "").strip()
138
- if url:
139
- st.markdown(f"- [{title}]({url})")
140
- else:
141
- st.markdown(f"- {title}")
142
- if snippet:
143
- st.caption(snippet)
144
- else:
145
- st.caption("(No sources returned.)")
146
 
147
- st.markdown("**Full context**")
148
- if context.strip():
149
- st.code(context)
150
- else:
151
- st.caption("(Empty context.)")
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  try:
154
- with st.popover("ℹ️"):
155
- _body()
156
  except Exception:
157
- with st.expander("ℹ️ Sources / Context"):
158
- _body()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
  # -----------------------
162
- # Load model + TeapotAI (cached)
163
  # -----------------------
164
  @st.cache_resource
165
- def load_teapot_ai_and_tokenizer():
166
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
167
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
168
 
169
  device = "cuda" if torch.cuda.is_available() else "cpu"
170
  model.to(device)
171
  model.eval()
 
172
 
173
- teapot_ai = TeapotAI(
174
- tokenizer=tokenizer,
175
- model=model,
176
- documents=DEFAULT_DOCUMENTS,
177
- )
178
- return teapot_ai, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
 
181
  # -----------------------
182
  # Session state
183
  # -----------------------
184
  if "messages" not in st.session_state:
185
- # Each assistant message includes: sources/context + timing/tokens
 
 
186
  st.session_state.messages = []
187
  if "pending_query" not in st.session_state:
188
  st.session_state.pending_query = None
@@ -200,7 +320,7 @@ with c2:
200
 
201
 
202
  # -----------------------
203
- # Sidebar (ONLY: system prompt + web search toggle)
204
  # -----------------------
205
  with st.sidebar:
206
  st.markdown("### Settings")
@@ -208,12 +328,15 @@ with st.sidebar:
208
  use_web_search = st.checkbox("Use web search", value=True)
209
 
210
 
211
- # Load tiny model on startup
212
- teapot_ai, hf_tokenizer = load_teapot_ai_and_tokenizer()
 
 
 
213
 
214
 
215
  # -----------------------
216
- # Suggested queries on empty chat
217
  # -----------------------
218
  if len(st.session_state.messages) == 0 and st.session_state.pending_query is None:
219
  st.markdown("#### Suggested")
@@ -228,7 +351,38 @@ if len(st.session_state.messages) == 0 and st.session_state.pending_query is Non
228
  # -----------------------
229
  # Render chat history
230
  # -----------------------
231
- for m in st.session_state.messages:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  if m["role"] == "user":
233
  with st.chat_message("user"):
234
  st.markdown(m["content"])
@@ -236,25 +390,33 @@ for m in st.session_state.messages:
236
  with st.chat_message("assistant"):
237
  st.markdown(m["content"])
238
 
239
- # metadata row
240
- meta_cols = st.columns([1, 3, 3, 5])
241
- with meta_cols[0]:
242
  render_sources_popover(m.get("sources", []), m.get("context", ""))
243
 
244
- # tokens/sec, and token counts
245
- tps = m.get("tps", None)
246
- out_toks = m.get("output_tokens", None)
247
- secs = m.get("seconds", None)
248
-
249
- with meta_cols[1]:
250
- if tps is not None:
251
- st.caption(f"⚡ {tps:.1f} tokens/s")
252
- with meta_cols[2]:
253
- if out_toks is not None:
254
- st.caption(f"🧮 {out_toks} output tokens")
255
- with meta_cols[3]:
256
- if secs is not None:
257
- st.caption(f"⏱️ {secs:.2f}s")
 
 
 
 
 
 
 
 
 
258
 
259
 
260
  # -----------------------
@@ -267,41 +429,105 @@ if st.session_state.pending_query and not user_input:
267
  st.session_state.pending_query = None
268
 
269
  if user_input:
 
270
  st.session_state.messages.append({"role": "user", "content": user_input})
271
 
 
272
  sources: List[Dict[str, str]] = []
273
- context = ""
274
 
275
  if use_web_search:
276
  try:
277
- sources = brave_search_snippets(user_input, top_k=TOP_K)
278
- context = format_context_from_results(sources)
279
  except Exception:
280
  sources = []
281
- context = ""
282
 
283
- # Teapot inference + timing
284
- t0 = time.perf_counter()
285
- answer = teapot_ai.query(
286
- query=user_input,
 
 
 
 
 
 
 
 
 
287
  context=context,
288
  system_prompt=system_prompt,
 
 
289
  )
290
- t1 = time.perf_counter()
291
 
292
- elapsed = max(t1 - t0, 1e-6)
293
- output_tokens = count_tokens(hf_tokenizer, answer)
294
- tps = output_tokens / elapsed if output_tokens > 0 else 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
 
296
  st.session_state.messages.append(
297
  {
298
  "role": "assistant",
299
- "content": answer,
300
  "sources": sources,
301
  "context": context,
 
302
  "seconds": elapsed,
303
- "output_tokens": output_tokens,
304
  "tps": tps,
 
305
  }
306
  )
 
307
  st.rerun()
 
 
1
  import os
2
  import re
3
  import time
4
+ import threading
5
+ from typing import List, Dict, Optional, Iterable, Tuple
6
 
7
  import requests
8
  import streamlit as st
9
  import torch
10
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
11
 
12
+ # LangSmith
13
+ try:
14
+ from langsmith import Client as LangSmithClient
15
+ except Exception:
16
+ LangSmithClient = None
17
 
18
 
19
  # -----------------------
20
+ # App config
21
  # -----------------------
22
+ st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")
 
 
23
 
 
 
 
24
  TEAPOT_LOGO_GIF = "https://teapotai.com/assets/logo.gif"
25
 
26
  SUGGESTED_QUERIES = [
 
39
  "'I am sorry but I don't have any information on that'."
40
  )
41
 
42
+ # Search provider (kept internal; UI says “web search”)
43
+ SEARCH_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
 
 
 
 
44
  TOP_K = 3
45
  TIMEOUT_SECS = 15
46
 
47
+ # Model input budget
48
+ MAX_INPUT_TOKENS = 512
49
+ MAX_NEW_TOKENS = 192 # output cap
 
 
50
 
51
 
52
  # -----------------------
53
+ # Utilities
54
  # -----------------------
55
  def st_image_full_width(img_url: str):
 
56
  try:
57
  st.image(img_url, use_container_width=True)
58
  except TypeError:
59
  st.image(img_url, use_column_width=True)
60
 
61
 
62
+ def autoscroll_to_bottom():
63
+ st.markdown(
64
+ """
65
+ <script>
66
+ (function() {
67
+ const doc = window.parent.document;
68
+ const el = doc.documentElement || doc.body;
69
+ el.scrollTo({ top: el.scrollHeight, behavior: "smooth" });
70
+ })();
71
+ </script>
72
+ """,
73
+ unsafe_allow_html=True,
74
+ )
75
+
76
+
77
+ def get_search_key() -> Optional[str]:
78
+ # Reuse the existing BRAVE_API_KEY secret name for backward compatibility
79
  return os.getenv("BRAVE_API_KEY") or (st.secrets.get("BRAVE_API_KEY") if hasattr(st, "secrets") else None)
80
 
81
 
82
+ def search_top_snippets(query: str, top_k: int = 3) -> List[Dict[str, str]]:
83
+ key = get_search_key()
84
  if not key:
85
+ raise RuntimeError("Missing BRAVE_API_KEY (Space secret / env var).")
86
 
87
  headers = {"Accept": "application/json", "X-Subscription-Token": key}
88
  params = {"q": query, "count": top_k}
89
+ r = requests.get(SEARCH_ENDPOINT, headers=headers, params=params, timeout=TIMEOUT_SECS)
90
  r.raise_for_status()
91
  data = r.json()
92
 
 
103
 
104
 
105
  def format_context_from_results(results: List[Dict[str, str]]) -> str:
106
+ """
107
+ Stable formatting + strip <strong> tags.
108
+ """
109
  if not results:
110
  return ""
111
 
 
115
  url = re.sub(r"\s+", " ", r.get("url", "")).strip()
116
  snippet = re.sub(r"\s+", " ", r.get("snippet", "")).strip()
117
 
 
118
  title = title.replace("<strong>", "").replace("</strong>", "")
119
  snippet = snippet.replace("<strong>", "").replace("</strong>", "")
120
 
121
  blocks.append(f"[{i}] {title}\nURL: {url}\nSnippet: {snippet}")
 
122
  return "\n\n".join(blocks)
123
 
124
 
125
  def count_tokens(tokenizer: AutoTokenizer, text: str) -> int:
126
  if not text:
127
  return 0
128
+ return len(tokenizer.encode(text))
 
 
 
129
 
130
 
131
+ def build_prompt(context: str, system_prompt: str, question: str) -> str:
132
+ # Exact prompt format relied on by the model: context, system prompt, question, each newline-terminated
133
+ return f"{context}\n{system_prompt}\n{question}\n"
134
+
135
+
136
+ def truncate_context_to_fit(
137
+ tokenizer: AutoTokenizer,
138
+ context: str,
139
+ system_prompt: str,
140
+ question: str,
141
+ max_input_tokens: int = 512,
142
+ ) -> str:
143
  """
144
+ Keep the *most recent* context while ensuring total prompt <= max_input_tokens.
145
+ We right-truncate by tokens (keep tail).
146
  """
147
+ # Tokenize fixed parts (system + question + newlines)
148
+ fixed_prompt = build_prompt("", system_prompt, question)
149
+ fixed_tokens = tokenizer.encode(fixed_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ # Reserve at least 0 for context
152
+ budget = max_input_tokens - len(fixed_tokens)
153
+ if budget <= 0:
154
+ return "" # no room for context at all
 
155
 
156
+ ctx_tokens = tokenizer.encode(context)
157
+ if len(ctx_tokens) <= budget:
158
+ return context
159
+
160
+ # Keep the most recent tokens (tail)
161
+ kept = ctx_tokens[-budget:]
162
+ truncated = tokenizer.decode(kept, skip_special_tokens=True)
163
+ return truncated
164
+
165
+
166
+ # -----------------------
167
+ # LangSmith integration
168
+ # -----------------------
169
+ @st.cache_resource
170
+ def get_langsmith_client() -> Optional["LangSmithClient"]:
171
+ if LangSmithClient is None:
172
+ return None
173
+
174
+ # LangSmith typically uses these env vars; if no key, no-op.
175
+ api_key = os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")
176
+ if not api_key:
177
+ return None
178
  try:
179
+ return LangSmithClient()
 
180
  except Exception:
181
+ return None
182
+
183
+
184
+ def ls_create_run(
185
+ client: Optional["LangSmithClient"],
186
+ *,
187
+ context: str,
188
+ system_prompt: str,
189
+ question: str,
190
+ model_name: str,
191
+ ) -> Optional[str]:
192
+ if client is None:
193
+ return None
194
+
195
+ project = os.getenv("LANGCHAIN_PROJECT") or "teapot-chat"
196
+ try:
197
+ run = client.create_run(
198
+ name="teapot_chat_turn",
199
+ run_type="llm",
200
+ project_name=project,
201
+ inputs={
202
+ "context": context,
203
+ "system_prompt": system_prompt,
204
+ "question": question,
205
+ "model": model_name,
206
+ },
207
+ tags=["teapot", "streamlit"],
208
+ )
209
+ # create_run returns a Run-like object; the id property name can vary
210
+ return getattr(run, "id", None) or getattr(run, "run_id", None)
211
+ except Exception:
212
+ return None
213
+
214
+
215
+ def ls_end_run(
216
+ client: Optional["LangSmithClient"],
217
+ run_id: Optional[str],
218
+ *,
219
+ answer: str,
220
+ meta: Dict[str, object],
221
+ ):
222
+ if client is None or not run_id:
223
+ return
224
+ try:
225
+ client.update_run(
226
+ run_id,
227
+ outputs={"answer": answer, **meta},
228
+ )
229
+ except Exception:
230
+ pass
231
+
232
+
233
+ def ls_feedback(
234
+ client: Optional["LangSmithClient"],
235
+ run_id: Optional[str],
236
+ *,
237
+ score: int,
238
+ comment: str = "",
239
+ ):
240
+ if client is None or not run_id:
241
+ return
242
+ try:
243
+ client.create_feedback(
244
+ run_id=run_id,
245
+ key="user_feedback",
246
+ score=float(score),
247
+ comment=comment or None,
248
+ )
249
+ except Exception:
250
+ pass
251
 
252
 
253
  # -----------------------
254
+ # Model loading
255
  # -----------------------
256
  @st.cache_resource
257
+ def load_model_and_tokenizer():
258
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
259
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
260
 
261
  device = "cuda" if torch.cuda.is_available() else "cpu"
262
  model.to(device)
263
  model.eval()
264
+ return tokenizer, model, device
265
 
266
+
267
+ def generate_stream(
268
+ tokenizer: AutoTokenizer,
269
+ model: AutoModelForSeq2SeqLM,
270
+ device: str,
271
+ prompt: str,
272
+ max_new_tokens: int = 192,
273
+ ) -> Iterable[str]:
274
+ """
275
+ True streaming via TextIteratorStreamer.
276
+ Yields progressively longer partial outputs.
277
+ """
278
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
279
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
280
+
281
+ def _run():
282
+ model.generate(
283
+ **inputs,
284
+ do_sample=False,
285
+ num_beams=1,
286
+ max_new_tokens=int(max_new_tokens),
287
+ streamer=streamer,
288
+ )
289
+
290
+ t = threading.Thread(target=_run, daemon=True)
291
+ t.start()
292
+
293
+ partial = ""
294
+ for piece in streamer:
295
+ partial += piece
296
+ yield partial
297
 
298
 
299
  # -----------------------
300
  # Session state
301
  # -----------------------
302
  if "messages" not in st.session_state:
303
+ # message schema:
304
+ # user: {"role":"user","content":...}
305
+ # assistant: {"role":"assistant","content":..., "sources":[...], "context":..., "run_id":..., "tps":..., "output_tokens":..., "seconds":..., "feedback": None/1/-1}
306
  st.session_state.messages = []
307
  if "pending_query" not in st.session_state:
308
  st.session_state.pending_query = None
 
320
 
321
 
322
  # -----------------------
323
+ # Sidebar (ONLY system prompt + web search)
324
  # -----------------------
325
  with st.sidebar:
326
  st.markdown("### Settings")
 
328
  use_web_search = st.checkbox("Use web search", value=True)
329
 
330
 
331
+ # Load model
332
+ tokenizer, model, device = load_model_and_tokenizer()
333
+
334
+ # LangSmith client (optional)
335
+ ls_client = get_langsmith_client()
336
 
337
 
338
  # -----------------------
339
+ # Suggested queries when empty
340
  # -----------------------
341
  if len(st.session_state.messages) == 0 and st.session_state.pending_query is None:
342
  st.markdown("#### Suggested")
 
351
  # -----------------------
352
  # Render chat history
353
  # -----------------------
354
+ def render_sources_popover(sources: List[Dict[str, str]], context: str):
355
+ def _body():
356
+ st.markdown("**Sources**")
357
+ if sources:
358
+ for j, s in enumerate(sources, start=1):
359
+ title = (s.get("title") or "").strip() or f"Result {j}"
360
+ url = (s.get("url") or "").strip()
361
+ snippet = (s.get("snippet") or "").strip()
362
+ if url:
363
+ st.markdown(f"- [{title}]({url})")
364
+ else:
365
+ st.markdown(f"- {title}")
366
+ if snippet:
367
+ st.caption(snippet)
368
+ else:
369
+ st.caption("(No sources returned.)")
370
+
371
+ st.markdown("**Full context**")
372
+ if context.strip():
373
+ st.code(context)
374
+ else:
375
+ st.caption("(Empty context.)")
376
+
377
+ try:
378
+ with st.popover("ℹ️"):
379
+ _body()
380
+ except Exception:
381
+ with st.expander("ℹ️ Sources / Context"):
382
+ _body()
383
+
384
+
385
+ for idx, m in enumerate(st.session_state.messages):
386
  if m["role"] == "user":
387
  with st.chat_message("user"):
388
  st.markdown(m["content"])
 
390
  with st.chat_message("assistant"):
391
  st.markdown(m["content"])
392
 
393
+ meta = st.columns([1, 2.2, 2.2, 2.2, 2.2])
394
+ with meta[0]:
 
395
  render_sources_popover(m.get("sources", []), m.get("context", ""))
396
 
397
+ with meta[1]:
398
+ st.caption(f"⚡ {m.get('tps', 0.0):.1f} tok/s")
399
+ with meta[2]:
400
+ st.caption(f"🧮 {m.get('output_tokens', 0)} toks")
401
+ with meta[3]:
402
+ st.caption(f"⏱️ {m.get('seconds', 0.0):.2f}s")
403
+
404
+ # Feedback buttons wired to LangSmith
405
+ feedback = m.get("feedback", None)
406
+ run_id = m.get("run_id", None)
407
+ btn_cols = st.columns([1, 1, 6])
408
+ with btn_cols[0]:
409
+ up_disabled = feedback is not None
410
+ if st.button("👍", key=f"fb_up_{idx}", disabled=up_disabled):
411
+ st.session_state.messages[idx]["feedback"] = 1
412
+ ls_feedback(ls_client, run_id, score=1)
413
+ st.rerun()
414
+ with btn_cols[1]:
415
+ down_disabled = feedback is not None
416
+ if st.button("👎", key=f"fb_down_{idx}", disabled=down_disabled):
417
+ st.session_state.messages[idx]["feedback"] = -1
418
+ ls_feedback(ls_client, run_id, score=-1)
419
+ st.rerun()
420
 
421
 
422
  # -----------------------
 
429
  st.session_state.pending_query = None
430
 
431
  if user_input:
432
+ # Add user message
433
  st.session_state.messages.append({"role": "user", "content": user_input})
434
 
435
+ # Build context (optional web search)
436
  sources: List[Dict[str, str]] = []
437
+ raw_context = ""
438
 
439
  if use_web_search:
440
  try:
441
+ sources = search_top_snippets(user_input, top_k=TOP_K)
442
+ raw_context = format_context_from_results(sources)
443
  except Exception:
444
  sources = []
445
+ raw_context = ""
446
 
447
+ # Truncate context to fit 512 tokens total prompt, keeping most recent
448
+ context = truncate_context_to_fit(
449
+ tokenizer=tokenizer,
450
+ context=raw_context,
451
+ system_prompt=system_prompt,
452
+ question=user_input,
453
+ max_input_tokens=MAX_INPUT_TOKENS,
454
+ )
455
+ prompt = build_prompt(context, system_prompt, user_input)
456
+
457
+ # Create LangSmith run now (inputs)
458
+ run_id = ls_create_run(
459
+ ls_client,
460
  context=context,
461
  system_prompt=system_prompt,
462
+ question=user_input,
463
+ model_name=MODEL_NAME,
464
  )
 
465
 
466
+ # Stream generation into the UI
467
+ with st.chat_message("assistant"):
468
+ placeholder = st.empty()
469
+
470
+ t0 = time.perf_counter()
471
+ final_text = ""
472
+
473
+ for partial in generate_stream(tokenizer, model, device, prompt, max_new_tokens=MAX_NEW_TOKENS):
474
+ final_text = partial
475
+ placeholder.markdown(final_text)
476
+ autoscroll_to_bottom()
477
+
478
+ t1 = time.perf_counter()
479
+ elapsed = max(t1 - t0, 1e-6)
480
+
481
+ out_tokens = count_tokens(tokenizer, final_text)
482
+ tps = (out_tokens / elapsed) if out_tokens > 0 else 0.0
483
+
484
+ # Metadata row + feedback buttons (live)
485
+ meta = st.columns([1, 2.2, 2.2, 2.2, 2.2])
486
+ with meta[0]:
487
+ render_sources_popover(sources, context)
488
+ with meta[1]:
489
+ st.caption(f"⚡ {tps:.1f} tok/s")
490
+ with meta[2]:
491
+ st.caption(f"🧮 {out_tokens} toks")
492
+ with meta[3]:
493
+ st.caption(f"⏱️ {elapsed:.2f}s")
494
+
495
+ btn_cols = st.columns([1, 1, 6])
496
+ with btn_cols[0]:
497
+ if st.button("👍", key=f"fb_up_live_{len(st.session_state.messages)}"):
498
+ ls_feedback(ls_client, run_id, score=1)
499
+ with btn_cols[1]:
500
+ if st.button("👎", key=f"fb_down_live_{len(st.session_state.messages)}"):
501
+ ls_feedback(ls_client, run_id, score=-1)
502
+
503
+ # End LangSmith run (outputs)
504
+ ls_end_run(
505
+ ls_client,
506
+ run_id,
507
+ answer=final_text,
508
+ meta={
509
+ "seconds": elapsed,
510
+ "output_tokens": out_tokens,
511
+ "tokens_per_second": tps,
512
+ "used_web_search": bool(use_web_search),
513
+ "max_input_tokens": MAX_INPUT_TOKENS,
514
+ "max_new_tokens": MAX_NEW_TOKENS,
515
+ },
516
+ )
517
 
518
+ # Persist assistant message for history (feedback state stored)
519
  st.session_state.messages.append(
520
  {
521
  "role": "assistant",
522
+ "content": final_text,
523
  "sources": sources,
524
  "context": context,
525
+ "run_id": run_id,
526
  "seconds": elapsed,
527
+ "output_tokens": out_tokens,
528
  "tps": tps,
529
+ "feedback": None,
530
  }
531
  )
532
+
533
  st.rerun()