Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +110 -60
src/streamlit_app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import os
|
| 2 |
import time
|
| 3 |
-
import threading
|
| 4 |
import requests
|
| 5 |
|
| 6 |
import streamlit as st
|
|
@@ -50,7 +49,6 @@ tokenizer, model, device = load_model()
|
|
| 50 |
# =========================
|
| 51 |
@st.cache_resource
|
| 52 |
def get_langsmith():
|
| 53 |
-
key = os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY") or os.getenv("LANGCHAIN_TRACING_V2")
|
| 54 |
if (os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")) and LangSmithClient:
|
| 55 |
return LangSmithClient()
|
| 56 |
return None
|
|
@@ -59,6 +57,46 @@ def get_langsmith():
|
|
| 59 |
ls_client = get_langsmith()
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
# =========================
|
| 63 |
# SESSION STATE
|
| 64 |
# =========================
|
|
@@ -66,18 +104,27 @@ if "messages" not in st.session_state:
|
|
| 66 |
st.session_state.messages = []
|
| 67 |
if "needs_answer" not in st.session_state:
|
| 68 |
st.session_state.needs_answer = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
# =========================
|
| 72 |
# HEADER (prevent logo flash)
|
| 73 |
-
# Use a fixed pixel width to avoid layout shift / big flash.
|
| 74 |
# =========================
|
| 75 |
col1, col2 = st.columns([1, 7], vertical_alignment="center")
|
| 76 |
with col1:
|
| 77 |
-
st.image(LOGO_URL, width=56)
|
| 78 |
with col2:
|
| 79 |
st.markdown("## TeapotAI Chat")
|
| 80 |
-
st.caption(
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
# =========================
|
|
@@ -88,13 +135,7 @@ with st.sidebar:
|
|
| 88 |
|
| 89 |
system_prompt = st.text_area(
|
| 90 |
"System prompt",
|
| 91 |
-
value=
|
| 92 |
-
"You are Teapot, an open-source AI assistant optimized for running on low-end cpu devices, "
|
| 93 |
-
"providing short, accurate responses without hallucinating while excelling at "
|
| 94 |
-
"information extraction and text summarization. "
|
| 95 |
-
"If the context does not answer the question, reply exactly: "
|
| 96 |
-
"'I am sorry but I don't have any information on that'."
|
| 97 |
-
),
|
| 98 |
height=160,
|
| 99 |
)
|
| 100 |
|
|
@@ -104,6 +145,21 @@ with st.sidebar:
|
|
| 104 |
placeholder="Extra context appended after web snippets…",
|
| 105 |
)
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# =========================
|
| 109 |
# WEB SEARCH (ALWAYS ON)
|
|
@@ -162,9 +218,9 @@ def count_tokens(text: str) -> int:
|
|
| 162 |
|
| 163 |
# =========================
|
| 164 |
# LANGSMITH-TRACED ANSWER FUNCTION
|
| 165 |
-
# (signature exactly: context, system_prompt, question -> answer)
|
| 166 |
# =========================
|
| 167 |
if traceable:
|
|
|
|
| 168 |
@traceable(name="teapot_answer")
|
| 169 |
def traced_answer(context: str, system_prompt: str, question: str) -> str:
|
| 170 |
prompt = f"{context}\n{system_prompt}\n{question}\n"
|
|
@@ -176,9 +232,10 @@ if traceable:
|
|
| 176 |
do_sample=False,
|
| 177 |
num_beams=1,
|
| 178 |
)
|
| 179 |
-
|
| 180 |
-
|
| 181 |
else:
|
|
|
|
| 182 |
def traced_answer(context: str, system_prompt: str, question: str) -> str:
|
| 183 |
prompt = f"{context}\n{system_prompt}\n{question}\n"
|
| 184 |
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
|
@@ -193,7 +250,6 @@ else:
|
|
| 193 |
|
| 194 |
|
| 195 |
def get_trace_id_if_available() -> str | None:
|
| 196 |
-
# Works when running inside a @traceable function call
|
| 197 |
if not get_current_run_tree:
|
| 198 |
return None
|
| 199 |
try:
|
|
@@ -217,7 +273,6 @@ def handle_feedback(idx: int):
|
|
| 217 |
if ls_client and trace_id:
|
| 218 |
score = 1 if val == "👍" else 0
|
| 219 |
try:
|
| 220 |
-
# LangSmith SDK supports trace_id= for feedback association
|
| 221 |
ls_client.create_feedback(
|
| 222 |
trace_id=trace_id,
|
| 223 |
key="thumb_rating",
|
|
@@ -230,6 +285,8 @@ def handle_feedback(idx: int):
|
|
| 230 |
|
| 231 |
# =========================
|
| 232 |
# RENDER HISTORY
|
|
|
|
|
|
|
| 233 |
# =========================
|
| 234 |
for i, msg in enumerate(st.session_state.messages):
|
| 235 |
with st.chat_message(msg["role"]):
|
|
@@ -237,17 +294,11 @@ for i, msg in enumerate(st.session_state.messages):
|
|
| 237 |
st.markdown(msg["content"])
|
| 238 |
continue
|
| 239 |
|
| 240 |
-
#
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
c1, c2 = st.columns([1, 12], vertical_alignment="center")
|
| 246 |
-
with c1:
|
| 247 |
-
st.markdown(msg["content"])
|
| 248 |
-
|
| 249 |
-
with c2:
|
| 250 |
-
|
| 251 |
key = f"fb_{i}"
|
| 252 |
st.session_state.setdefault(key, msg.get("feedback"))
|
| 253 |
st.feedback(
|
|
@@ -258,18 +309,10 @@ for i, msg in enumerate(st.session_state.messages):
|
|
| 258 |
args=(i,),
|
| 259 |
)
|
| 260 |
|
| 261 |
-
|
|
|
|
| 262 |
|
| 263 |
-
|
| 264 |
-
with c3:
|
| 265 |
-
st.caption(
|
| 266 |
-
f"🔎 {msg['search_time']:.2f}s (search)"
|
| 267 |
-
f"🧠 {msg['gen_time']:.2f}s (generation) "
|
| 268 |
-
f"⚡ {msg['tps']:.1f} tok/s "
|
| 269 |
-
f"🧾 {msg['input_tokens']} input tokens • {msg['output_tokens']} output tokens"
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
with c4:
|
| 273 |
with st.popover("ℹ️", help="Inspect"):
|
| 274 |
st.markdown("**Context**")
|
| 275 |
st.code(msg.get("context", ""), language="text")
|
|
@@ -277,11 +320,16 @@ for i, msg in enumerate(st.session_state.messages):
|
|
| 277 |
st.code(msg.get("system_prompt", ""), language="text")
|
| 278 |
st.markdown("**Question**")
|
| 279 |
st.code(msg.get("question", ""), language="text")
|
| 280 |
-
|
|
|
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
|
| 287 |
# =========================
|
|
@@ -313,32 +361,40 @@ if (
|
|
| 313 |
prompt = f"{context}\n{system_prompt}\n{question}\n"
|
| 314 |
input_tokens = count_tokens(prompt)
|
| 315 |
|
| 316 |
-
# Run traced answer
|
| 317 |
with st.chat_message("assistant"):
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
start = time.perf_counter()
|
| 321 |
-
|
| 322 |
-
# Generate full answer first (traced), then "stream" it to UI quickly.
|
| 323 |
-
# This keeps LangSmith tracing simple/reliable while still giving a streaming UX.
|
| 324 |
answer = traced_answer(context, system_prompt, question)
|
| 325 |
trace_id = get_trace_id_if_available()
|
| 326 |
|
| 327 |
-
#
|
| 328 |
buf = ""
|
| 329 |
for ch in answer:
|
| 330 |
buf += ch
|
| 331 |
placeholder.markdown(buf)
|
| 332 |
-
# small delay; tune if you want faster/slower
|
| 333 |
time.sleep(0.002)
|
| 334 |
|
| 335 |
gen_time = time.perf_counter() - start
|
| 336 |
output_tokens = count_tokens(answer)
|
| 337 |
tps = output_tokens / gen_time if gen_time > 0 else 0.0
|
| 338 |
|
| 339 |
-
#
|
| 340 |
-
|
| 341 |
-
with
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
with st.popover("ℹ️", help="Inspect"):
|
| 343 |
st.markdown("**Context**")
|
| 344 |
st.code(context, language="text")
|
|
@@ -348,13 +404,7 @@ if (
|
|
| 348 |
st.code(question, language="text")
|
| 349 |
st.markdown("**Prompt**")
|
| 350 |
st.code(prompt, language="text")
|
| 351 |
-
|
| 352 |
-
st.caption(
|
| 353 |
-
f"🔎 {search_time:.2f}s (search) "
|
| 354 |
-
f"🧠 {gen_time:.2f}s (generation) "
|
| 355 |
-
f"⚡ {tps:.1f} tok/s "
|
| 356 |
-
f"🧾 {input_tokens} input tokens • {output_tokens} output tokens"
|
| 357 |
-
)
|
| 358 |
|
| 359 |
# Persist assistant message
|
| 360 |
st.session_state.messages.append(
|
|
|
|
| 1 |
import os
|
| 2 |
import time
|
|
|
|
| 3 |
import requests
|
| 4 |
|
| 5 |
import streamlit as st
|
|
|
|
| 49 |
# =========================
|
| 50 |
@st.cache_resource
|
| 51 |
def get_langsmith():
|
|
|
|
| 52 |
if (os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")) and LangSmithClient:
|
| 53 |
return LangSmithClient()
|
| 54 |
return None
|
|
|
|
| 57 |
ls_client = get_langsmith()
|
| 58 |
|
| 59 |
|
| 60 |
+
# =========================
|
| 61 |
+
# SAMPLE SEED (with full debug fields)
|
| 62 |
+
# =========================
|
| 63 |
+
SAMPLE_QUESTION = "who are you"
|
| 64 |
+
|
| 65 |
+
DEFAULT_SYSTEM_PROMPT = (
|
| 66 |
+
"You are Teapot, an open-source AI assistant optimized for running on low-end cpu devices, "
|
| 67 |
+
"providing short, accurate responses without hallucinating while excelling at "
|
| 68 |
+
"information extraction and text summarization. "
|
| 69 |
+
"If the context does not answer the question, reply exactly: "
|
| 70 |
+
"'I am sorry but I don't have any information on that'."
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
SAMPLE_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
|
| 74 |
+
|
| 75 |
+
SAMPLE_CONTEXT = (
|
| 76 |
+
"Teapot is an open-source AI assistant optimized for running on low-end cpu devices."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
SAMPLE_ANSWER = "I am Teapot, an open-source AI assistant optimized for running on low-end cpu devices."
|
| 80 |
+
SAMPLE_PROMPT = f"{SAMPLE_CONTEXT}\n{SAMPLE_SYSTEM_PROMPT}\n{SAMPLE_QUESTION}\n"
|
| 81 |
+
|
| 82 |
+
SAMPLE_USER_MSG = {"role": "user", "content": SAMPLE_QUESTION}
|
| 83 |
+
SAMPLE_ASSISTANT_MSG = {
|
| 84 |
+
"role": "assistant",
|
| 85 |
+
"content": SAMPLE_ANSWER,
|
| 86 |
+
"context": SAMPLE_CONTEXT,
|
| 87 |
+
"system_prompt": SAMPLE_SYSTEM_PROMPT,
|
| 88 |
+
"question": SAMPLE_QUESTION,
|
| 89 |
+
"prompt": SAMPLE_PROMPT,
|
| 90 |
+
"search_time": 0.37,
|
| 91 |
+
"gen_time": 0.67,
|
| 92 |
+
"input_tokens": 245,
|
| 93 |
+
"output_tokens": 24,
|
| 94 |
+
"tps": 35.9,
|
| 95 |
+
"trace_id": None,
|
| 96 |
+
"feedback": None,
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
# =========================
|
| 101 |
# SESSION STATE
|
| 102 |
# =========================
|
|
|
|
| 104 |
st.session_state.messages = []
|
| 105 |
if "needs_answer" not in st.session_state:
|
| 106 |
st.session_state.needs_answer = False
|
| 107 |
+
if "seeded" not in st.session_state:
|
| 108 |
+
st.session_state.seeded = False
|
| 109 |
+
|
| 110 |
+
# Seed exactly once on first load
|
| 111 |
+
if (not st.session_state.seeded) and (len(st.session_state.messages) == 0):
|
| 112 |
+
st.session_state.messages = [SAMPLE_USER_MSG, SAMPLE_ASSISTANT_MSG]
|
| 113 |
+
st.session_state.seeded = True
|
| 114 |
|
| 115 |
|
| 116 |
# =========================
|
| 117 |
# HEADER (prevent logo flash)
|
|
|
|
| 118 |
# =========================
|
| 119 |
col1, col2 = st.columns([1, 7], vertical_alignment="center")
|
| 120 |
with col1:
|
| 121 |
+
st.image(LOGO_URL, width=56)
|
| 122 |
with col2:
|
| 123 |
st.markdown("## TeapotAI Chat")
|
| 124 |
+
st.caption(
|
| 125 |
+
"Teapot is a 77M-parameter LLM optimized for fast CPU inference that only generates answers "
|
| 126 |
+
"from the provided context to minimize hallucinations."
|
| 127 |
+
)
|
| 128 |
|
| 129 |
|
| 130 |
# =========================
|
|
|
|
| 135 |
|
| 136 |
system_prompt = st.text_area(
|
| 137 |
"System prompt",
|
| 138 |
+
value=DEFAULT_SYSTEM_PROMPT,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
height=160,
|
| 140 |
)
|
| 141 |
|
|
|
|
| 145 |
placeholder="Extra context appended after web snippets…",
|
| 146 |
)
|
| 147 |
|
| 148 |
+
st.markdown("### Conversation")
|
| 149 |
+
c1, c2 = st.columns(2)
|
| 150 |
+
with c1:
|
| 151 |
+
if st.button("Load sample"):
|
| 152 |
+
st.session_state.messages = [SAMPLE_USER_MSG, SAMPLE_ASSISTANT_MSG]
|
| 153 |
+
st.session_state.needs_answer = False
|
| 154 |
+
st.session_state.seeded = True
|
| 155 |
+
st.rerun()
|
| 156 |
+
with c2:
|
| 157 |
+
if st.button("Clear chat"):
|
| 158 |
+
st.session_state.messages = []
|
| 159 |
+
st.session_state.needs_answer = False
|
| 160 |
+
st.session_state.seeded = True
|
| 161 |
+
st.rerun()
|
| 162 |
+
|
| 163 |
|
| 164 |
# =========================
|
| 165 |
# WEB SEARCH (ALWAYS ON)
|
|
|
|
| 218 |
|
| 219 |
# =========================
|
| 220 |
# LANGSMITH-TRACED ANSWER FUNCTION
|
|
|
|
| 221 |
# =========================
|
| 222 |
if traceable:
|
| 223 |
+
|
| 224 |
@traceable(name="teapot_answer")
|
| 225 |
def traced_answer(context: str, system_prompt: str, question: str) -> str:
|
| 226 |
prompt = f"{context}\n{system_prompt}\n{question}\n"
|
|
|
|
| 232 |
do_sample=False,
|
| 233 |
num_beams=1,
|
| 234 |
)
|
| 235 |
+
return tokenizer.decode(out[0], skip_special_tokens=True)
|
| 236 |
+
|
| 237 |
else:
|
| 238 |
+
|
| 239 |
def traced_answer(context: str, system_prompt: str, question: str) -> str:
|
| 240 |
prompt = f"{context}\n{system_prompt}\n{question}\n"
|
| 241 |
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
|
|
|
| 250 |
|
| 251 |
|
| 252 |
def get_trace_id_if_available() -> str | None:
|
|
|
|
| 253 |
if not get_current_run_tree:
|
| 254 |
return None
|
| 255 |
try:
|
|
|
|
| 273 |
if ls_client and trace_id:
|
| 274 |
score = 1 if val == "👍" else 0
|
| 275 |
try:
|
|
|
|
| 276 |
ls_client.create_feedback(
|
| 277 |
trace_id=trace_id,
|
| 278 |
key="thumb_rating",
|
|
|
|
| 285 |
|
| 286 |
# =========================
|
| 287 |
# RENDER HISTORY
|
| 288 |
+
# Row 1: message + feedback
|
| 289 |
+
# Row 2: inspect + debug metrics
|
| 290 |
# =========================
|
| 291 |
for i, msg in enumerate(st.session_state.messages):
|
| 292 |
with st.chat_message(msg["role"]):
|
|
|
|
| 294 |
st.markdown(msg["content"])
|
| 295 |
continue
|
| 296 |
|
| 297 |
+
# Row 1
|
| 298 |
+
msg_col, fb_col = st.columns([14, 1], vertical_alignment="center")
|
| 299 |
+
with msg_col:
|
| 300 |
+
st.markdown(msg.get("content", ""))
|
| 301 |
+
with fb_col:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
key = f"fb_{i}"
|
| 303 |
st.session_state.setdefault(key, msg.get("feedback"))
|
| 304 |
st.feedback(
|
|
|
|
| 309 |
args=(i,),
|
| 310 |
)
|
| 311 |
|
| 312 |
+
# Row 2
|
| 313 |
+
inspect_col, metrics_col = st.columns([1, 12], vertical_alignment="center")
|
| 314 |
|
| 315 |
+
with inspect_col:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
with st.popover("ℹ️", help="Inspect"):
|
| 317 |
st.markdown("**Context**")
|
| 318 |
st.code(msg.get("context", ""), language="text")
|
|
|
|
| 320 |
st.code(msg.get("system_prompt", ""), language="text")
|
| 321 |
st.markdown("**Question**")
|
| 322 |
st.code(msg.get("question", ""), language="text")
|
| 323 |
+
st.markdown("**Prompt**")
|
| 324 |
+
st.code(msg.get("prompt", ""), language="text")
|
| 325 |
|
| 326 |
+
with metrics_col:
|
| 327 |
+
st.caption(
|
| 328 |
+
f"🔎 {msg.get('search_time', 0.0):.2f}s (search) "
|
| 329 |
+
f"🧠 {msg.get('gen_time', 0.0):.2f}s (generation) "
|
| 330 |
+
f"⚡ {msg.get('tps', 0.0):.1f} tok/s "
|
| 331 |
+
f"🧾 {msg.get('input_tokens', 0)} input tokens • {msg.get('output_tokens', 0)} output tokens"
|
| 332 |
+
)
|
| 333 |
|
| 334 |
|
| 335 |
# =========================
|
|
|
|
| 361 |
prompt = f"{context}\n{system_prompt}\n{question}\n"
|
| 362 |
input_tokens = count_tokens(prompt)
|
| 363 |
|
| 364 |
+
# Run traced answer
|
| 365 |
with st.chat_message("assistant"):
|
| 366 |
+
# Row 1: message + feedback (feedback disabled until persisted)
|
| 367 |
+
msg_col, fb_col = st.columns([14, 1], vertical_alignment="center")
|
| 368 |
+
with msg_col:
|
| 369 |
+
placeholder = st.empty()
|
| 370 |
+
with fb_col:
|
| 371 |
+
st.feedback("thumbs", key="live_fb", disabled=True)
|
| 372 |
|
| 373 |
start = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
| 374 |
answer = traced_answer(context, system_prompt, question)
|
| 375 |
trace_id = get_trace_id_if_available()
|
| 376 |
|
| 377 |
+
# Stream into the message column
|
| 378 |
buf = ""
|
| 379 |
for ch in answer:
|
| 380 |
buf += ch
|
| 381 |
placeholder.markdown(buf)
|
|
|
|
| 382 |
time.sleep(0.002)
|
| 383 |
|
| 384 |
gen_time = time.perf_counter() - start
|
| 385 |
output_tokens = count_tokens(answer)
|
| 386 |
tps = output_tokens / gen_time if gen_time > 0 else 0.0
|
| 387 |
|
| 388 |
+
# Row 2: inspect + metrics
|
| 389 |
+
inspect_col, metrics_col = st.columns([12, 1], vertical_alignment="center")
|
| 390 |
+
with inspect_col:
|
| 391 |
+
st.caption(
|
| 392 |
+
f"🔎 {search_time:.2f}s (search) "
|
| 393 |
+
f"🧠 {gen_time:.2f}s (generation) "
|
| 394 |
+
f"⚡ {tps:.1f} tok/s "
|
| 395 |
+
f"🧾 {input_tokens} input tokens • {output_tokens} output tokens"
|
| 396 |
+
)
|
| 397 |
+
with metrics_col:
|
| 398 |
with st.popover("ℹ️", help="Inspect"):
|
| 399 |
st.markdown("**Context**")
|
| 400 |
st.code(context, language="text")
|
|
|
|
| 404 |
st.code(question, language="text")
|
| 405 |
st.markdown("**Prompt**")
|
| 406 |
st.code(prompt, language="text")
|
| 407 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
# Persist assistant message
|
| 410 |
st.session_state.messages.append(
|