zakerytclarke committed on
Commit
e4379b8
·
verified ·
1 Parent(s): 42556c6

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +88 -82
src/streamlit_app.py CHANGED
@@ -12,7 +12,6 @@ try:
12
  except:
13
  LangSmithClient = None
14
 
15
-
16
  # =========================
17
  # CONFIG
18
  # =========================
@@ -22,11 +21,7 @@ MAX_NEW_TOKENS = 192
22
  TOP_K_SEARCH = 3
23
  LOGO_URL = "https://teapotai.com/assets/logo.gif"
24
 
25
- st.set_page_config(
26
- page_title="TeapotAI Chat",
27
- page_icon="🫖",
28
- layout="centered"
29
- )
30
 
31
  # =========================
32
  # LOAD MODEL (CACHED)
@@ -36,42 +31,40 @@ def load_model():
36
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
37
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
38
  device = "cuda" if torch.cuda.is_available() else "cpu"
39
- model.to(device)
40
- model.eval()
41
  return tokenizer, model, device
42
 
43
  tokenizer, model, device = load_model()
44
 
45
-
46
  # =========================
47
- # LANGSMITH
48
  # =========================
49
  @st.cache_resource
50
  def get_langsmith():
51
- api_key = os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")
52
- if api_key and LangSmithClient:
53
  return LangSmithClient()
54
  return None
55
 
56
  ls_client = get_langsmith()
57
 
58
-
59
  # =========================
60
  # SESSION STATE
61
  # =========================
62
  if "messages" not in st.session_state:
63
  st.session_state.messages = []
 
 
64
 
65
  # =========================
66
- # HEADER (LOGO RESTORED)
67
  # =========================
68
  col1, col2 = st.columns([1, 6])
69
  with col1:
70
  st.image(LOGO_URL, use_column_width=True)
71
  with col2:
72
  st.markdown("## TeapotAI Chat")
73
- st.caption("Fast, grounded answers with web context")
74
-
75
 
76
  # =========================
77
  # SIDEBAR SETTINGS
@@ -88,32 +81,27 @@ with st.sidebar:
88
  "If the context does not answer the question, reply exactly: "
89
  "'I am sorry but I don't have any information on that'."
90
  ),
91
- height=180
92
  )
93
 
94
- st.markdown("### Extra Context (Optional)")
95
- user_context = st.text_area(
96
- "Paste context to append to web results",
97
- height=150,
98
- placeholder="Add any custom context here..."
99
  )
100
 
101
  use_web = st.checkbox("Use web search", value=True)
102
 
103
-
104
  # =========================
105
- # WEB SEARCH (FAST)
106
  # =========================
107
- def web_search(query: str):
108
  api_key = os.getenv("BRAVE_API_KEY") or st.secrets.get("BRAVE_API_KEY", None)
109
  if not api_key:
110
  return "", 0.0
111
 
112
- headers = {
113
- "X-Subscription-Token": api_key,
114
- "Accept": "application/json"
115
- }
116
-
117
  params = {"q": query, "count": TOP_K_SEARCH}
118
 
119
  t0 = time.perf_counter()
@@ -129,35 +117,37 @@ def web_search(query: str):
129
  return "", 0.0
130
  t1 = time.perf_counter()
131
 
132
- blocks = []
133
- for i, item in enumerate(data.get("web", {}).get("results", [])[:TOP_K_SEARCH], 1):
134
- title = item.get("title", "")
135
- url = item.get("url", "")
136
  desc = item.get("description", "")
137
- desc = desc.replace("<strong>", "").replace("</strong>", "")
138
- blocks.append(f"[{i}] {title}\nURL: {url}\nSnippet: {desc}")
139
-
140
- context = "\n\n".join(blocks)
141
- return context, (t1 - t0)
142
 
 
 
 
143
 
144
  # =========================
145
  # TRUNCATE TO LAST 512 TOKENS
146
  # =========================
147
- def truncate_to_512(context: str, system: str, question: str):
148
- base_prompt = f"\n{system}\n{question}\n"
149
- base_tokens = tokenizer.encode(base_prompt)
 
 
 
 
150
  budget = MAX_INPUT_TOKENS - len(base_tokens)
151
 
152
- ctx_tokens = tokenizer.encode(context)
153
  if len(ctx_tokens) <= budget:
154
- return context
155
 
156
  # Keep MOST RECENT tokens (tail truncation)
157
  truncated = ctx_tokens[-budget:]
158
  return tokenizer.decode(truncated, skip_special_tokens=True)
159
 
160
-
161
  # =========================
162
  # STREAM GENERATION
163
  # =========================
@@ -182,17 +172,12 @@ def stream_generate(prompt: str):
182
  text += chunk
183
  yield text
184
 
185
-
186
  # =========================
187
- # LANGSMITH FEEDBACK HANDLER
188
  # =========================
189
  def handle_feedback(idx: int):
190
  val = st.session_state[f"feedback_{idx}"]
191
  msg = st.session_state.messages[idx]
192
-
193
- if val is None:
194
- return
195
-
196
  msg["feedback"] = val
197
 
198
  if ls_client and msg.get("run_id"):
@@ -204,18 +189,18 @@ def handle_feedback(idx: int):
204
  score=score,
205
  comment="thumbs_up" if score else "thumbs_down",
206
  )
207
- except Exception as e:
208
- print("LangSmith feedback error:", e)
209
-
210
 
211
  # =========================
212
- # RENDER CHAT
213
  # =========================
214
  for i, msg in enumerate(st.session_state.messages):
215
  with st.chat_message(msg["role"]):
216
  st.markdown(msg["content"])
217
 
218
  if msg["role"] == "assistant":
 
219
  st.caption(
220
  f"🔎 {msg['search_time']:.2f}s search • "
221
  f"🧠 {msg['gen_time']:.2f}s generate • "
@@ -223,48 +208,63 @@ for i, msg in enumerate(st.session_state.messages):
223
  f"🧮 {msg['tokens']} tokens"
224
  )
225
 
226
- feedback_key = f"feedback_{i}"
227
- st.session_state.setdefault(feedback_key, msg.get("feedback"))
228
-
 
 
 
 
 
 
 
 
 
229
  st.feedback(
230
  "thumbs",
231
- key=feedback_key,
232
  disabled=msg.get("feedback") is not None,
233
  on_change=handle_feedback,
234
  args=(i,),
235
  )
236
 
237
-
238
  # =========================
239
- # CHAT INPUT
240
  # =========================
241
  query = st.chat_input("Ask a question...")
242
 
243
  if query:
 
244
  st.session_state.messages.append({"role": "user", "content": query})
 
245
 
246
- # ---- WEB SEARCH ----
247
- web_context = ""
 
 
 
 
 
 
 
 
 
 
248
  search_time = 0.0
249
  if use_web:
250
- web_context, search_time = web_search(query)
251
-
252
- # ---- COMBINED CONTEXT (WEB + USER BOX) ----
253
- combined_context = ""
254
- if user_context:
255
- combined_context += user_context.strip() + "\n\n"
256
- if web_context:
257
- combined_context += web_context
258
 
259
- truncated_context = truncate_to_512(
260
- combined_context,
 
 
261
  system_prompt,
262
- query
263
  )
264
 
265
- prompt = f"{truncated_context}\n{system_prompt}\n{query}\n"
266
 
267
- # ---- LANGSMITH RUN ----
268
  run_id = None
269
  if ls_client:
270
  try:
@@ -272,26 +272,28 @@ if query:
272
  name="teapot_chat",
273
  run_type="llm",
274
  inputs={
275
- "context": truncated_context,
 
276
  "system_prompt": system_prompt,
277
  "question": query,
 
278
  },
279
  )
280
  run_id = run.id
281
  except:
282
  pass
283
 
284
- # ---- STREAM OUTPUT ----
285
  with st.chat_message("assistant"):
286
  placeholder = st.empty()
287
-
288
- gen_start = time.perf_counter()
289
  final_text = ""
 
290
  for partial in stream_generate(prompt):
291
  final_text = partial
292
  placeholder.markdown(final_text)
293
- gen_time = time.perf_counter() - gen_start
294
 
 
295
  tokens = len(tokenizer.encode(final_text))
296
  tps = tokens / gen_time if gen_time > 0 else 0.0
297
 
@@ -312,13 +314,17 @@ if query:
312
  {
313
  "role": "assistant",
314
  "content": final_text,
 
 
 
315
  "search_time": search_time,
316
  "gen_time": gen_time,
317
- "tps": tps,
318
  "tokens": tokens,
 
319
  "run_id": run_id,
320
  "feedback": None,
321
  }
322
  )
323
 
 
324
  st.rerun()
 
12
  except:
13
  LangSmithClient = None
14
 
 
15
  # =========================
16
  # CONFIG
17
  # =========================
 
21
  TOP_K_SEARCH = 3
22
  LOGO_URL = "https://teapotai.com/assets/logo.gif"
23
 
24
+ st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")
 
 
 
 
25
 
26
  # =========================
27
  # LOAD MODEL (CACHED)
 
31
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
32
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ model.to(device).eval()
 
35
  return tokenizer, model, device
36
 
37
  tokenizer, model, device = load_model()
38
 
 
39
  # =========================
40
+ # LANGSMITH (OPTIONAL)
41
  # =========================
42
  @st.cache_resource
43
  def get_langsmith():
44
+ key = os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")
45
+ if key and LangSmithClient:
46
  return LangSmithClient()
47
  return None
48
 
49
  ls_client = get_langsmith()
50
 
 
51
  # =========================
52
  # SESSION STATE
53
  # =========================
54
  if "messages" not in st.session_state:
55
  st.session_state.messages = []
56
+ if "pending_response" not in st.session_state:
57
+ st.session_state.pending_response = None
58
 
59
  # =========================
60
+ # HEADER (LOGO)
61
  # =========================
62
  col1, col2 = st.columns([1, 6])
63
  with col1:
64
  st.image(LOGO_URL, use_column_width=True)
65
  with col2:
66
  st.markdown("## TeapotAI Chat")
67
+ st.caption("Fast grounded answers with clean web context")
 
68
 
69
  # =========================
70
  # SIDEBAR SETTINGS
 
81
  "If the context does not answer the question, reply exactly: "
82
  "'I am sorry but I don't have any information on that'."
83
  ),
84
+ height=180,
85
  )
86
 
87
+ st.markdown("### Local Context (Appended)")
88
+ local_context = st.text_area(
89
+ "Paste additional context (optional)",
90
+ height=160,
91
+ placeholder="This will be appended after web content..."
92
  )
93
 
94
  use_web = st.checkbox("Use web search", value=True)
95
 
 
96
  # =========================
97
+ # WEB SEARCH (SNIPPETS ONLY)
98
  # =========================
99
+ def web_search_snippets(query: str):
100
  api_key = os.getenv("BRAVE_API_KEY") or st.secrets.get("BRAVE_API_KEY", None)
101
  if not api_key:
102
  return "", 0.0
103
 
104
+ headers = {"X-Subscription-Token": api_key, "Accept": "application/json"}
 
 
 
 
105
  params = {"q": query, "count": TOP_K_SEARCH}
106
 
107
  t0 = time.perf_counter()
 
117
  return "", 0.0
118
  t1 = time.perf_counter()
119
 
120
+ snippets = []
121
+ for item in data.get("web", {}).get("results", [])[:TOP_K_SEARCH]:
 
 
122
  desc = item.get("description", "")
123
+ desc = desc.replace("<strong>", "").replace("</strong>", "").strip()
124
+ if desc:
125
+ snippets.append(desc)
 
 
126
 
127
+ # Paragraph-separated ONLY (no title, no URL)
128
+ clean_context = "\n\n".join(snippets)
129
+ return clean_context, (t1 - t0)
130
 
131
  # =========================
132
  # TRUNCATE TO LAST 512 TOKENS
133
  # =========================
134
def truncate_context(web_ctx, local_ctx, system, question):
    """Join web and local context and trim the result to the input budget.

    The web context comes first and the local context second, separated by
    a blank line. The budget is ``MAX_INPUT_TOKENS`` minus the token count of
    the system-prompt-plus-question tail that follows the context in the
    final prompt. When the joined context exceeds the budget, only its last
    ``budget`` tokens are kept.

    Args:
        web_ctx: Text gathered from web search; may be empty.
        local_ctx: User-supplied text placed after the web text; may be empty.
        system: System prompt that follows the context in the prompt.
        question: User question that follows the system prompt.

    Returns:
        The context string, truncated to its tail when it is too long, or
        ``""`` when the system prompt and question leave no room for context.
    """
    ordered_context = f"{web_ctx}\n\n{local_ctx}".strip()

    base = f"\n{system}\n{question}\n"
    base_tokens = tokenizer.encode(base)
    budget = MAX_INPUT_TOKENS - len(base_tokens)

    # A non-positive budget must be handled before slicing: ``tokens[-0:]``
    # is the whole list and a negative budget slices from the front, so
    # either would let the context overflow the input limit.
    if budget <= 0:
        return ""

    ctx_tokens = tokenizer.encode(ordered_context)
    if len(ctx_tokens) <= budget:
        return ordered_context

    # Keep MOST RECENT tokens (tail truncation)
    truncated = ctx_tokens[-budget:]
    return tokenizer.decode(truncated, skip_special_tokens=True)
150
 
 
151
  # =========================
152
  # STREAM GENERATION
153
  # =========================
 
172
  text += chunk
173
  yield text
174
 
 
175
  # =========================
176
+ # FEEDBACK HANDLER (Native st.feedback)
177
  # =========================
178
  def handle_feedback(idx: int):
179
  val = st.session_state[f"feedback_{idx}"]
180
  msg = st.session_state.messages[idx]
 
 
 
 
181
  msg["feedback"] = val
182
 
183
  if ls_client and msg.get("run_id"):
 
189
  score=score,
190
  comment="thumbs_up" if score else "thumbs_down",
191
  )
192
+ except:
193
+ pass
 
194
 
195
  # =========================
196
+ # RENDER CHAT HISTORY
197
  # =========================
198
  for i, msg in enumerate(st.session_state.messages):
199
  with st.chat_message(msg["role"]):
200
  st.markdown(msg["content"])
201
 
202
  if msg["role"] == "assistant":
203
+ # Metrics row
204
  st.caption(
205
  f"🔎 {msg['search_time']:.2f}s search • "
206
  f"🧠 {msg['gen_time']:.2f}s generate • "
 
208
  f"🧮 {msg['tokens']} tokens"
209
  )
210
 
211
+ # Inspectable context (clean UX)
212
+ with st.expander("🔍 Inspect Context Used"):
213
+ st.markdown("**Web Content:**")
214
+ st.write(msg["web_context"] or "_None_")
215
+ st.markdown("**Local Context:**")
216
+ st.write(msg["local_context"] or "_None_")
217
+ st.markdown("**Final Truncated Context (512 tokens tail):**")
218
+ st.write(msg["final_context"])
219
+
220
+ # Native thumbs feedback
221
+ key = f"feedback_{i}"
222
+ st.session_state.setdefault(key, msg.get("feedback"))
223
  st.feedback(
224
  "thumbs",
225
+ key=key,
226
  disabled=msg.get("feedback") is not None,
227
  on_change=handle_feedback,
228
  args=(i,),
229
  )
230
 
 
231
  # =========================
232
+ # USER INPUT (FIXED ORDER)
233
  # =========================
234
  query = st.chat_input("Ask a question...")
235
 
236
  if query:
237
+ # 1️⃣ Immediately show user message FIRST (fix streaming race)
238
  st.session_state.messages.append({"role": "user", "content": query})
239
+ st.rerun()
240
 
241
+ # =========================
242
+ # GENERATE AFTER RERUN (Prevents premature streaming)
243
+ # =========================
244
+ if (
245
+ st.session_state.messages
246
+ and st.session_state.messages[-1]["role"] == "user"
247
+ and st.session_state.pending_response is None
248
+ ):
249
+ query = st.session_state.messages[-1]["content"]
250
+
251
+ # --- Web Search ---
252
+ web_ctx = ""
253
  search_time = 0.0
254
  if use_web:
255
+ web_ctx, search_time = web_search_snippets(query)
 
 
 
 
 
 
 
256
 
257
+ # --- Strict Order Context ---
258
+ final_context = truncate_context(
259
+ web_ctx,
260
+ local_context,
261
  system_prompt,
262
+ query,
263
  )
264
 
265
+ prompt = f"{final_context}\n{system_prompt}\n{query}\n"
266
 
267
+ # LangSmith run
268
  run_id = None
269
  if ls_client:
270
  try:
 
272
  name="teapot_chat",
273
  run_type="llm",
274
  inputs={
275
+ "web_content": web_ctx,
276
+ "local_context": local_context,
277
  "system_prompt": system_prompt,
278
  "question": query,
279
+ "final_context": final_context,
280
  },
281
  )
282
  run_id = run.id
283
  except:
284
  pass
285
 
286
+ # --- Stream UI ---
287
  with st.chat_message("assistant"):
288
  placeholder = st.empty()
289
+ start = time.perf_counter()
 
290
  final_text = ""
291
+
292
  for partial in stream_generate(prompt):
293
  final_text = partial
294
  placeholder.markdown(final_text)
 
295
 
296
+ gen_time = time.perf_counter() - start
297
  tokens = len(tokenizer.encode(final_text))
298
  tps = tokens / gen_time if gen_time > 0 else 0.0
299
 
 
314
  {
315
  "role": "assistant",
316
  "content": final_text,
317
+ "web_context": web_ctx,
318
+ "local_context": local_context,
319
+ "final_context": final_context,
320
  "search_time": search_time,
321
  "gen_time": gen_time,
 
322
  "tokens": tokens,
323
+ "tps": tps,
324
  "run_id": run_id,
325
  "feedback": None,
326
  }
327
  )
328
 
329
+ st.session_state.pending_response = None
330
  st.rerun()