# src/streamlit_app.py — TeapotAI chat demo (tinyteapot model, Streamlit UI).
import os
import re
import threading
import time

import requests
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
# Optional LangSmith (trace + feedback).
# The app must run without the langsmith package installed, so import failures
# degrade to None sentinels that the rest of the file feature-detects against.
try:
    from langsmith import Client as LangSmithClient
    from langsmith import traceable
    from langsmith.run_helpers import get_current_run_tree
except Exception:
    # langsmith missing or broken — tracing and feedback become no-ops.
    LangSmithClient = None
    traceable = None
    get_current_run_tree = None
# =========================
# CONFIG
# =========================
MODEL_NAME = "teapotai/tinyteapot"  # HF hub id of the seq2seq model
MAX_INPUT_TOKENS = 512              # prompt token budget (context is tail-truncated to fit)
MAX_NEW_TOKENS = 192                # generation length cap
TOP_K_SEARCH = 3                    # number of Brave web-search snippets to keep
LOGO_URL = "https://teapotai.com/assets/logo.gif"
# Must be the first Streamlit call in the script.
st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")
# =========================
# LOAD MODEL (CACHED)
# =========================
@st.cache_resource
def load_model():
    """Load the tokenizer and model once per process.

    Returns a ``(tokenizer, model, device)`` triple; the model is moved to
    CUDA when available (CPU otherwise) and switched to eval mode.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    seq2seq = seq2seq.to(device_name)
    seq2seq.eval()
    return tok, seq2seq, device_name
tokenizer, model, device = load_model()
# =========================
# LANGSMITH (OPTIONAL)
# =========================
@st.cache_resource
def get_langsmith():
    """Return a cached LangSmith client, or None when not configured.

    Requires both an API key (LANGCHAIN_API_KEY or LANGSMITH_API_KEY) and a
    successfully imported langsmith SDK.
    """
    api_key = os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")
    if not api_key or not LangSmithClient:
        return None
    return LangSmithClient()
ls_client = get_langsmith()
# =========================
# SAMPLE SEED (with full debug fields)
# =========================
# A complete example exchange used to seed the chat on first load, carrying
# every debug field the assistant-message renderer expects.
SAMPLE_QUESTION = "Who are you?"
DEFAULT_SYSTEM_PROMPT = (
    "You are Teapot, an open-source AI assistant optimized for running on low-end cpu devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)
SAMPLE_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
SAMPLE_CONTEXT = "Teapot is an open-source AI assistant optimized for running on low-end cpu devices."
SAMPLE_ANSWER = "I am Teapot, an open-source AI assistant optimized for running on low-end cpu devices."
SAMPLE_PROMPT = f"{SAMPLE_CONTEXT}\n{SAMPLE_SYSTEM_PROMPT}\n{SAMPLE_QUESTION}\n"
SAMPLE_USER_MSG = {"role": "user", "content": SAMPLE_QUESTION}
SAMPLE_ASSISTANT_MSG = {
    "role": "assistant",
    "content": SAMPLE_ANSWER,
    "context": SAMPLE_CONTEXT,
    "system_prompt": SAMPLE_SYSTEM_PROMPT,
    "question": SAMPLE_QUESTION,
    "prompt": SAMPLE_PROMPT,
    # Representative metrics for the seeded turn (not measured live).
    "search_time": 0.37,
    "gen_time": 0.67,
    "input_tokens": 245,
    "output_tokens": 24,
    "tps": 35.9,
    "trace_id": None,  # seeded example has no LangSmith trace
    "feedback": None,
}
# =========================
# SESSION STATE
# =========================
# Per-session chat history: list of {"role": ..., "content": ..., ...} dicts.
if "messages" not in st.session_state:
    st.session_state.messages = []
# Tracks whether the sample conversation has already been inserted.
if "seeded" not in st.session_state:
    st.session_state.seeded = False
# Seed exactly once on first load
if (not st.session_state.seeded) and (len(st.session_state.messages) == 0):
    st.session_state.messages = [SAMPLE_USER_MSG, SAMPLE_ASSISTANT_MSG]
    st.session_state.seeded = True
# =========================
# HEADER
# =========================
# Header row: logo on the left, app title on the right, then a tagline.
col1, col2 = st.columns([1, 7], vertical_alignment="center")
with col1:
    st.image(LOGO_URL, width=56)
with col2:
    st.markdown("## TeapotAI Chat")
st.caption(
    "Teapot is a 77M-parameter LLM optimized for fast CPU inference that only generates answers "
    "from the provided context to minimize hallucinations."
)
# =========================
# SIDEBAR
# =========================
with st.sidebar:
    st.markdown("### Settings")
    # Editable system prompt, defaulting to the anti-hallucination instructions.
    system_prompt = st.text_area(
        "System prompt",
        value=DEFAULT_SYSTEM_PROMPT,
        height=160,
    )
    # Free-form context appended after any web-search snippets.
    local_context = st.text_area(
        "Local context (optional)",
        height=140,
        placeholder="Extra context appended after web snippets…",
    )
    if st.button("Clear chat"):
        st.session_state.messages = []
        # Keep seeded=True so the sample exchange is not re-inserted.
        st.session_state.seeded = True
        st.rerun()
# =========================
# WEB SEARCH (ALWAYS ON)
# =========================
def web_search_snippets(query: str):
    """Fetch up to TOP_K_SEARCH result snippets from the Brave Search API.

    Args:
        query: the user's question, sent verbatim as the search query.

    Returns:
        A ``(snippets_text, elapsed_seconds)`` tuple. ``snippets_text`` is a
        blank-line-joined string of result descriptions ("" when no key is
        configured or the request fails); ``elapsed_seconds`` is the HTTP
        round-trip time (0.0 when no request was made).
    """
    api_key = os.getenv("BRAVE_API_KEY")
    if not api_key:
        # st.secrets raises when no secrets.toml exists at all; treat that as
        # "no key configured" instead of crashing the app.
        try:
            api_key = st.secrets.get("BRAVE_API_KEY", None)
        except Exception:
            api_key = None
    if not api_key:
        return "", 0.0
    headers = {"X-Subscription-Token": api_key, "Accept": "application/json"}
    params = {"q": query, "count": TOP_K_SEARCH}
    t0 = time.perf_counter()
    try:
        r = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            headers=headers,
            params=params,
            timeout=6,
        )
        # Surface HTTP errors (401/429/...) instead of parsing an error body.
        r.raise_for_status()
        data = r.json()
    except Exception:
        # Best-effort: search failures degrade to "no web context", but the
        # time actually spent is still reported for the metrics row.
        return "", time.perf_counter() - t0
    elapsed = time.perf_counter() - t0
    snippets = []
    for item in data.get("web", {}).get("results", [])[:TOP_K_SEARCH]:
        # Brave wraps match terms in inline HTML (e.g. <strong>); strip all tags.
        desc = re.sub(r"<[^>]+>", "", item.get("description") or "").strip()
        if desc:
            snippets.append(desc)
    return "\n\n".join(snippets), elapsed
# =========================
# CONTEXT TRUNCATION (TAIL)
# =========================
def truncate_context(web_ctx: str, local_ctx: str, system: str, question: str) -> str:
    """Tail-truncate the combined context so the full prompt fits the budget.

    Web and local context are joined, then trimmed from the *front* (keeping
    the most recent tail tokens) so that context + system + question stays
    within MAX_INPUT_TOKENS. Returns "" when there is no room or no context.
    """
    fixed_part = f"\n{system}\n{question}\n"
    budget = MAX_INPUT_TOKENS - len(tokenizer.encode(fixed_part))
    combined = f"{web_ctx}\n\n{local_ctx}".strip()
    if budget <= 0 or not combined:
        return ""
    token_ids = tokenizer.encode(combined)
    if len(token_ids) > budget:
        # Keep the last `budget` tokens — the tail is assumed most relevant.
        return tokenizer.decode(token_ids[-budget:], skip_special_tokens=True)
    return combined
def count_tokens(text: str) -> int:
    """Return the tokenizer token count of *text* (0 for empty/None)."""
    if not text:
        return 0
    return len(tokenizer.encode(text))
def get_trace_id_if_available() -> str | None:
    """Return the current LangSmith run id as a string, or None.

    None is returned when langsmith is not installed, there is no active run
    tree, or the run carries no id.
    """
    if not get_current_run_tree:
        return None
    try:
        active_run = get_current_run_tree()
        if active_run and getattr(active_run, "id", None):
            return str(active_run.id)
    except Exception:
        pass
    return None
# =========================
# FEEDBACK HANDLER (attached to trace_id)
# =========================
def handle_feedback(idx: int):
    """on_change callback for the thumbs widget attached to message *idx*.

    Persists the rating on the message dict and, when LangSmith is configured
    and the message has a trace_id, records the rating on that trace.
    """
    val = st.session_state.get(f"fb_{idx}")
    st.session_state.messages[idx]["feedback"] = val
    msg = st.session_state.messages[idx]
    trace_id = msg.get("trace_id")
    # Skip cleared ratings (None) — only send feedback for an actual vote.
    if ls_client and trace_id and val is not None:
        # BUGFIX: st.feedback("thumbs") returns an int index (1 = thumbs up,
        # 0 = thumbs down), never the "👍" string — the old comparison
        # (val == "👍") scored every vote as thumbs-down.
        score = 1 if val == 1 else 0
        try:
            ls_client.create_feedback(
                trace_id=trace_id,
                key="thumb_rating",
                score=score,
                comment="thumbs_up" if score else "thumbs_down",
            )
        except Exception:
            # Feedback delivery is best-effort; never break the UI.
            pass
# =========================
# STREAMING + LANGSMITH FIX
# - We do NOT trace a generator.
# - We stream to UI while returning a SINGLE final string.
# =========================
# Per-request UI hook: the active request stores its st.empty() placeholder
# here so _generate_with_streamer can push streamed chunks into it.
_UI_STREAM = {"placeholder": None}  # set per-request
def _generate_with_streamer(prompt: str) -> str:
    """
    Runs model.generate with a TextIteratorStreamer and updates a Streamlit placeholder
    as chunks arrive. Returns the final full text.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Streamer yields decoded text pieces as generate() emits tokens.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,  # greedy decoding for deterministic answers
        num_beams=1,
        streamer=streamer,
    )
    # generate() blocks, so run it on a daemon thread and consume the streamer
    # on this (UI) thread.
    # NOTE(review): if generate() raises before closing the streamer, the loop
    # below could block indefinitely — consider a streamer timeout; confirm.
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    t.start()
    buf = ""
    ph = _UI_STREAM.get("placeholder")
    if ph is not None:
        ph.markdown("")  # ensure element exists before first chunk
    for chunk in streamer:
        buf += chunk
        if ph is not None:
            ph.markdown(buf)  # re-render the growing answer in place
    return buf
def _answer_streaming(context: str, system_prompt: str, question: str) -> str:
    """Build the Teapot prompt (context, system, question) and stream an answer.

    Returns the complete generated text; streaming to the UI happens as a
    side effect inside _generate_with_streamer.
    """
    prompt = f"{context}\n{system_prompt}\n{question}\n"
    return _generate_with_streamer(prompt)

# Wrap with LangSmith tracing when available. Defining the function once and
# decorating conditionally removes the duplicated body of the previous version
# while keeping the public name and signature identical.
if traceable:
    traced_answer_streaming = traceable(name="teapot_answer")(_answer_streaming)
else:
    traced_answer_streaming = _answer_streaming
# =========================
# INPUT FIRST (so latest user msg renders immediately)
# =========================
# Read input before rendering history so a just-submitted user message is
# appended first and therefore rendered within the same script run.
query = st.chat_input("Ask a question...")
if query:
    st.session_state.messages.append({"role": "user", "content": query})
# =========================
# RENDER HISTORY
# Row 1: message + feedback
# Row 2: inspect + debug metrics
# =========================
# Render the conversation. User messages are plain markdown; assistant messages
# get two rows: (1) content + thumbs feedback, (2) latency/token metrics plus
# an inspect popover showing the debug fields stored on the message.
for i, msg in enumerate(st.session_state.messages):
    with st.chat_message(msg["role"]):
        if msg["role"] == "user":
            st.markdown(msg["content"])
            continue
        # Row 1
        msg_col, fb_col = st.columns([12, 1], vertical_alignment="center")
        with msg_col:
            st.markdown(msg.get("content", ""))
        with fb_col:
            key = f"fb_{i}"
            # Pre-seed widget state from the stored rating so a past vote shows.
            st.session_state.setdefault(key, msg.get("feedback"))
            st.feedback(
                "thumbs",
                key=key,
                disabled=msg.get("feedback") is not None,  # one vote per message
                on_change=handle_feedback,
                args=(i,),
            )
        # Row 2
        inspect_col, metrics_col = st.columns([12, 1], vertical_alignment="center")
        with inspect_col:
            st.caption(
                f"🔎 {msg.get('search_time', 0.0):.2f}s (search) "
                f"🧠 {msg.get('gen_time', 0.0):.2f}s (generation) "
                f"⚡ {msg.get('tps', 0.0):.1f} tok/s "
                f"🧾 {msg.get('input_tokens', 0)} input tokens • {msg.get('output_tokens', 0)} output tokens"
            )
        with metrics_col:
            with st.popover("ℹ️", help="Inspect"):
                st.markdown("**Context**")
                st.code(msg.get("context", ""), language="text")
                st.markdown("**System Prompt**")
                st.code(msg.get("system_prompt", ""), language="text")
                st.markdown("**Question**")
                st.code(msg.get("question", ""), language="text")
# =========================
# GENERATE ONLY IF THIS RUN RECEIVED A NEW QUERY
# =========================
# Generate a reply only when this rerun carried a fresh user query.
if query:
    question = query
    # Web search (always attempted; no-op without an API key).
    web_ctx, search_time = web_search_snippets(question)
    # Merge web + sidebar context and tail-truncate to the token budget.
    context = truncate_context(web_ctx, local_context, system_prompt, question)
    prompt = f"{context}\n{system_prompt}\n{question}\n"
    input_tokens = count_tokens(prompt)
    # Assistant response (stream to UI, return full string for LangSmith)
    with st.chat_message("assistant"):
        # Row 1: message + feedback (disabled live)
        msg_col, fb_col = st.columns([14, 1], vertical_alignment="center")
        with msg_col:
            placeholder = st.empty()
        with fb_col:
            st.feedback("thumbs", key="live_fb", disabled=True)
        # Hand the placeholder to the streamer via the module-level slot.
        _UI_STREAM["placeholder"] = placeholder
        start = time.perf_counter()
        answer = traced_answer_streaming(context, system_prompt, question)
        # NOTE(review): read after the traced call returns — confirm the run
        # tree is still current here; trace_id may come back None otherwise.
        trace_id = get_trace_id_if_available()
        gen_time = time.perf_counter() - start
        _UI_STREAM["placeholder"] = None  # cleanup
        output_tokens = count_tokens(answer)
        tps = output_tokens / gen_time if gen_time > 0 else 0.0
        # Row 2: inspect + metrics (mirrors the history renderer)
        inspect_col, metrics_col = st.columns([12, 1], vertical_alignment="center")
        with inspect_col:
            st.caption(
                f"🔎 {search_time:.2f}s (search) "
                f"🧠 {gen_time:.2f}s (generation) "
                f"⚡ {tps:.1f} tok/s "
                f"🧾 {input_tokens} input tokens • {output_tokens} output tokens"
            )
        with metrics_col:
            with st.popover("ℹ️", help="Inspect"):
                st.markdown("**Context**")
                st.code(context, language="text")
                st.markdown("**System**")
                st.code(system_prompt, language="text")
                st.markdown("**Question**")
                st.code(question, language="text")
                st.markdown("**Prompt**")
                st.code(prompt, language="text")
    # Persist assistant message with all debug fields for later rendering.
    st.session_state.messages.append(
        {
            "role": "assistant",
            "content": answer,
            "context": context,
            "system_prompt": system_prompt,
            "question": question,
            "prompt": prompt,
            "search_time": search_time,
            "gen_time": gen_time,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "tps": tps,
            "trace_id": trace_id,
            "feedback": None,
        }
    )