"""TeapotAI Chat — a Streamlit chat UI around the tinyteapot seq2seq model.

Per user query: Brave web search (optional) -> context truncation to the
model's input budget -> streamed CPU/GPU generation -> per-message metrics,
with optional LangSmith tracing and thumbs feedback.
"""

import os
import threading
import time

import requests
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

# Optional LangSmith (trace + feedback) — degrade gracefully when absent.
try:
    from langsmith import Client as LangSmithClient
    from langsmith import traceable
    from langsmith.run_helpers import get_current_run_tree
except Exception:
    LangSmithClient = None
    traceable = None
    get_current_run_tree = None

# =========================
# CONFIG
# =========================
MODEL_NAME = "teapotai/tinyteapot"
MAX_INPUT_TOKENS = 512   # hard budget for the full prompt (context + system + question)
MAX_NEW_TOKENS = 192
TOP_K_SEARCH = 3
LOGO_URL = "https://teapotai.com/assets/logo.gif"

st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")

# =========================
# LOAD MODEL (CACHED)
# =========================
@st.cache_resource
def load_model():
    """Load tokenizer + seq2seq model once per server process.

    Returns (tokenizer, model, device); the model is moved to CUDA when
    available and put in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    mdl.to(dev).eval()
    return tok, mdl, dev

tokenizer, model, device = load_model()

# =========================
# LANGSMITH (OPTIONAL)
# =========================
@st.cache_resource
def get_langsmith():
    """Return a cached LangSmith client, or None when no key/SDK is present."""
    if (os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")) and LangSmithClient:
        return LangSmithClient()
    return None

ls_client = get_langsmith()

# =========================
# SAMPLE SEED (with full debug fields)
# =========================
SAMPLE_QUESTION = "Who are you?"
DEFAULT_SYSTEM_PROMPT = (
    "You are Teapot, an open-source AI assistant optimized for running on low-end cpu devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)
SAMPLE_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
SAMPLE_CONTEXT = "Teapot is an open-source AI assistant optimized for running on low-end cpu devices."
SAMPLE_ANSWER = "I am Teapot, an open-source AI assistant optimized for running on low-end cpu devices."
SAMPLE_PROMPT = f"{SAMPLE_CONTEXT}\n{SAMPLE_SYSTEM_PROMPT}\n{SAMPLE_QUESTION}\n"

SAMPLE_USER_MSG = {"role": "user", "content": SAMPLE_QUESTION}
SAMPLE_ASSISTANT_MSG = {
    "role": "assistant",
    "content": SAMPLE_ANSWER,
    "context": SAMPLE_CONTEXT,
    "system_prompt": SAMPLE_SYSTEM_PROMPT,
    "question": SAMPLE_QUESTION,
    "prompt": SAMPLE_PROMPT,
    "search_time": 0.37,
    "gen_time": 0.67,
    "input_tokens": 245,
    "output_tokens": 24,
    "tps": 35.9,
    "trace_id": None,
    "feedback": None,
}

# =========================
# SESSION STATE
# =========================
if "messages" not in st.session_state:
    st.session_state.messages = []
if "seeded" not in st.session_state:
    st.session_state.seeded = False

# Seed exactly once on first load (not after "Clear chat", which sets seeded=True).
if (not st.session_state.seeded) and (len(st.session_state.messages) == 0):
    st.session_state.messages = [SAMPLE_USER_MSG, SAMPLE_ASSISTANT_MSG]
    st.session_state.seeded = True

# =========================
# HEADER
# =========================
col1, col2 = st.columns([1, 7], vertical_alignment="center")
with col1:
    st.image(LOGO_URL, width=56)
with col2:
    st.markdown("## TeapotAI Chat")
st.caption(
    "Teapot is a 77M-parameter LLM optimized for fast CPU inference that only generates answers "
    "from the provided context to minimize hallucinations."
)

# =========================
# SIDEBAR
# =========================
with st.sidebar:
    st.markdown("### Settings")
    system_prompt = st.text_area(
        "System prompt",
        value=DEFAULT_SYSTEM_PROMPT,
        height=160,
    )
    local_context = st.text_area(
        "Local context (optional)",
        height=140,
        placeholder="Extra context appended after web snippets…",
    )
    if st.button("Clear chat"):
        st.session_state.messages = []
        st.session_state.seeded = True  # prevent re-seeding the sample exchange
        st.rerun()

# =========================
# WEB SEARCH (ALWAYS ON)
# =========================
def web_search_snippets(query: str):
    """Fetch up to TOP_K_SEARCH Brave result descriptions for *query*.

    Returns (snippets_text, elapsed_seconds). Returns ("", 0.0) when no API
    key is configured or the request/parse fails — search is best-effort.
    """
    api_key = os.getenv("BRAVE_API_KEY")
    if not api_key:
        # st.secrets raises when no secrets.toml exists at all; treat that
        # the same as "no key configured" instead of crashing the app.
        try:
            api_key = st.secrets.get("BRAVE_API_KEY", None)
        except Exception:
            api_key = None
    if not api_key:
        return "", 0.0

    headers = {"X-Subscription-Token": api_key, "Accept": "application/json"}
    params = {"q": query, "count": TOP_K_SEARCH}
    t0 = time.perf_counter()
    try:
        r = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            headers=headers,
            params=params,
            timeout=6,
        )
        data = r.json()
    except Exception:
        return "", 0.0
    t1 = time.perf_counter()

    snippets = []
    for item in data.get("web", {}).get("results", [])[:TOP_K_SEARCH]:
        # Brave highlights query matches with <strong>…</strong> markup.
        # NOTE(review): the original called .replace("", "") twice — a no-op;
        # presumably the tag literals were lost. Restored to strip the markup.
        desc = (
            (item.get("description") or "")
            .replace("<strong>", "")
            .replace("</strong>", "")
            .strip()
        )
        if desc:
            snippets.append(desc)
    return "\n\n".join(snippets), (t1 - t0)

# =========================
# CONTEXT TRUNCATION (TAIL)
# =========================
def truncate_context(web_ctx: str, local_ctx: str, system: str, question: str) -> str:
    """Trim context so the full prompt fits MAX_INPUT_TOKENS.

    Keeps the *tail* of the combined (web + local) context, since the
    freshest/most specific snippets are appended last. Returns "" when the
    fixed parts alone exhaust the budget or there is no context.
    """
    ctx = f"{web_ctx}\n\n{local_ctx}".strip()
    base = f"\n{system}\n{question}\n"
    base_tokens = tokenizer.encode(base)
    budget = MAX_INPUT_TOKENS - len(base_tokens)
    if budget <= 0:
        return ""
    if not ctx:
        return ""
    ctx_tokens = tokenizer.encode(ctx)
    if len(ctx_tokens) <= budget:
        return ctx
    return tokenizer.decode(ctx_tokens[-budget:], skip_special_tokens=True)

def count_tokens(text: str) -> int:
    """Token count for *text* under the model tokenizer (0 for empty/None)."""
    return len(tokenizer.encode(text)) if text else 0

def get_trace_id_if_available() -> str | None:
    """Return the current LangSmith run id, or None outside a traced context."""
    if not get_current_run_tree:
        return None
    try:
        run = get_current_run_tree()
        return str(run.id) if run and getattr(run, "id", None) else None
    except Exception:
        return None

# =========================
# FEEDBACK HANDLER (attached to trace_id)
# =========================
def handle_feedback(idx: int):
    """on_change callback for the thumbs widget of message *idx*.

    Persists the rating on the message and forwards it to LangSmith when a
    client and trace id are available.
    """
    val = st.session_state.get(f"fb_{idx}")
    st.session_state.messages[idx]["feedback"] = val
    msg = st.session_state.messages[idx]
    trace_id = msg.get("trace_id")
    if ls_client and trace_id and val is not None:
        # st.feedback("thumbs") yields an int index — 1 for thumbs-up,
        # 0 for thumbs-down — not an emoji string. The original compared
        # against "👍", which never matched and always logged score 0.
        score = 1 if val == 1 else 0
        try:
            ls_client.create_feedback(
                trace_id=trace_id,
                key="thumb_rating",
                score=score,
                comment="thumbs_up" if score else "thumbs_down",
            )
        except Exception:
            pass  # feedback is best-effort; never break the UI

# =========================
# STREAMING + LANGSMITH FIX
# - We do NOT trace a generator.
# - We stream to UI while returning a SINGLE final string.
# =========================
_UI_STREAM = {"placeholder": None}  # set per-request
_LAST_TRACE = {"id": None}          # run id captured *inside* the traced call

def _generate_with_streamer(prompt: str) -> str:
    """
    Runs model.generate with a TextIteratorStreamer and updates a
    Streamlit placeholder as chunks arrive. Returns the final full text.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,   # greedy: deterministic, fastest on CPU
        num_beams=1,
        streamer=streamer,
    )
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    t.start()

    buf = ""
    ph = _UI_STREAM.get("placeholder")
    if ph is not None:
        ph.markdown("")  # ensure element exists before first chunk
    for chunk in streamer:
        buf += chunk
        if ph is not None:
            ph.markdown(buf)
    t.join()  # generation is finished once the streamer is exhausted
    return buf

if traceable:
    @traceable(name="teapot_answer")
    def traced_answer_streaming(context: str, system_prompt: str, question: str) -> str:
        prompt = f"{context}\n{system_prompt}\n{question}\n"
        answer = _generate_with_streamer(prompt)
        # Capture the run id while the traced context is still active;
        # once this function returns, get_current_run_tree() is empty again,
        # so reading it from the caller always produced None.
        _LAST_TRACE["id"] = get_trace_id_if_available()
        return answer
else:
    def traced_answer_streaming(context: str, system_prompt: str, question: str) -> str:
        prompt = f"{context}\n{system_prompt}\n{question}\n"
        _LAST_TRACE["id"] = None
        return _generate_with_streamer(prompt)

# =========================
# INPUT FIRST (so latest user msg renders immediately)
# =========================
query = st.chat_input("Ask a question...")
if query:
    st.session_state.messages.append({"role": "user", "content": query})

# =========================
# RENDER HISTORY
# Row 1: message + feedback
# Row 2: inspect + debug metrics
# =========================
for i, msg in enumerate(st.session_state.messages):
    with st.chat_message(msg["role"]):
        if msg["role"] == "user":
            st.markdown(msg["content"])
            continue

        # Row 1: answer text + thumbs widget
        msg_col, fb_col = st.columns([12, 1], vertical_alignment="center")
        with msg_col:
            st.markdown(msg.get("content", ""))
        with fb_col:
            key = f"fb_{i}"
            st.session_state.setdefault(key, msg.get("feedback"))
            st.feedback(
                "thumbs",
                key=key,
                disabled=msg.get("feedback") is not None,  # lock after rating
                on_change=handle_feedback,
                args=(i,),
            )

        # Row 2: debug metrics + inspect popover
        stats_col, info_col = st.columns([12, 1], vertical_alignment="center")
        with stats_col:
            st.caption(
                f"🔎 {msg.get('search_time', 0.0):.2f}s (search) "
                f"🧠 {msg.get('gen_time', 0.0):.2f}s (generation) "
                f"⚡ {msg.get('tps', 0.0):.1f} tok/s "
                f"🧾 {msg.get('input_tokens', 0)} input tokens • {msg.get('output_tokens', 0)} output tokens"
            )
        with info_col:
            with st.popover("ℹ️", help="Inspect"):
                st.markdown("**Context**")
                st.code(msg.get("context", ""), language="text")
                st.markdown("**System Prompt**")
                st.code(msg.get("system_prompt", ""), language="text")
                st.markdown("**Question**")
                st.code(msg.get("question", ""), language="text")
                st.markdown("**Prompt**")
                st.code(msg.get("prompt", ""), language="text")

# =========================
# GENERATE ONLY IF THIS RUN RECEIVED A NEW QUERY
# =========================
if query:
    question = query

    # Web search (best-effort)
    web_ctx, search_time = web_search_snippets(question)

    # Context + truncation
    context = truncate_context(web_ctx, local_context, system_prompt, question)
    prompt = f"{context}\n{system_prompt}\n{question}\n"
    input_tokens = count_tokens(prompt)

    # Assistant response (stream to UI, return full string for LangSmith)
    with st.chat_message("assistant"):
        # Row 1: message + feedback (disabled while live)
        msg_col, fb_col = st.columns([12, 1], vertical_alignment="center")
        with msg_col:
            placeholder = st.empty()
        with fb_col:
            st.feedback("thumbs", key="live_fb", disabled=True)

        _UI_STREAM["placeholder"] = placeholder
        start = time.perf_counter()
        answer = traced_answer_streaming(context, system_prompt, question)
        gen_time = time.perf_counter() - start
        trace_id = _LAST_TRACE["id"]  # captured inside the traced context
        _UI_STREAM["placeholder"] = None  # cleanup

        output_tokens = count_tokens(answer)
        tps = output_tokens / gen_time if gen_time > 0 else 0.0

        # Row 2: debug metrics + inspect popover
        stats_col, info_col = st.columns([12, 1], vertical_alignment="center")
        with stats_col:
            st.caption(
                f"🔎 {search_time:.2f}s (search) "
                f"🧠 {gen_time:.2f}s (generation) "
                f"⚡ {tps:.1f} tok/s "
                f"🧾 {input_tokens} input tokens • {output_tokens} output tokens"
            )
        with info_col:
            with st.popover("ℹ️", help="Inspect"):
                st.markdown("**Context**")
                st.code(context, language="text")
                st.markdown("**System Prompt**")
                st.code(system_prompt, language="text")
                st.markdown("**Question**")
                st.code(question, language="text")
                st.markdown("**Prompt**")
                st.code(prompt, language="text")

    # Persist assistant message with full debug fields for later inspection
    st.session_state.messages.append(
        {
            "role": "assistant",
            "content": answer,
            "context": context,
            "system_prompt": system_prompt,
            "question": question,
            "prompt": prompt,
            "search_time": search_time,
            "gen_time": gen_time,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "tps": tps,
            "trace_id": trace_id,
            "feedback": None,
        }
    )