# TeapotAI Chat — Streamlit app (Hugging Face Space)
| import os | |
| import time | |
| import threading | |
| import requests | |
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer | |
# Optional LangSmith (trace + feedback): best-effort imports so the app still
# runs when the `langsmith` package is not installed.
try:
    from langsmith import Client as LangSmithClient
    from langsmith import traceable
    from langsmith.run_helpers import get_current_run_tree
except Exception:
    # None sentinels are checked later (get_langsmith, traced_answer_streaming,
    # get_trace_id_if_available) to disable tracing gracefully.
    LangSmithClient = None
    traceable = None
    get_current_run_tree = None
# =========================
# CONFIG
# =========================
MODEL_NAME = "teapotai/tinyteapot"  # Hugging Face hub id of the seq2seq model
MAX_INPUT_TOKENS = 512              # token budget for the full prompt (context + system + question)
MAX_NEW_TOKENS = 192                # generation length cap
TOP_K_SEARCH = 3                    # number of web-search snippets to keep
LOGO_URL = "https://teapotai.com/assets/logo.gif"
st.set_page_config(page_title="TeapotAI Chat", page_icon="🫖", layout="centered")
# =========================
# LOAD MODEL (CACHED)
# =========================
@st.cache_resource  # without this, Streamlit reloads the model on EVERY rerun
def load_model():
    """Load tokenizer and seq2seq model once per process.

    Returns:
        (tokenizer, model, device) where device is "cuda" when available,
        else "cpu"; the model is moved there and set to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    mdl.to(dev).eval()
    return tok, mdl, dev

tokenizer, model, device = load_model()
# =========================
# LANGSMITH (OPTIONAL)
# =========================
def get_langsmith():
    """Return a LangSmith client when an API key and the SDK are present, else None."""
    has_key = os.getenv("LANGCHAIN_API_KEY") or os.getenv("LANGSMITH_API_KEY")
    if has_key and LangSmithClient is not None:
        return LangSmithClient()
    return None

ls_client = get_langsmith()
# =========================
# SAMPLE SEED (with full debug fields)
# =========================
SAMPLE_QUESTION = "Who are you?"
DEFAULT_SYSTEM_PROMPT = (
    "You are Teapot, an open-source AI assistant optimized for running on low-end cpu devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)
SAMPLE_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT
SAMPLE_CONTEXT = "Teapot is an open-source AI assistant optimized for running on low-end cpu devices."
SAMPLE_ANSWER = "I am Teapot, an open-source AI assistant optimized for running on low-end cpu devices."
# Same "<context>\n<system>\n<question>\n" shape used by traced_answer_streaming.
SAMPLE_PROMPT = f"{SAMPLE_CONTEXT}\n{SAMPLE_SYSTEM_PROMPT}\n{SAMPLE_QUESTION}\n"
SAMPLE_USER_MSG = {"role": "user", "content": SAMPLE_QUESTION}
# Mirrors the dict persisted after a real generation so the debug UI renders
# identically for the seeded turn (metric values here are illustrative).
SAMPLE_ASSISTANT_MSG = {
    "role": "assistant",
    "content": SAMPLE_ANSWER,
    "context": SAMPLE_CONTEXT,
    "system_prompt": SAMPLE_SYSTEM_PROMPT,
    "question": SAMPLE_QUESTION,
    "prompt": SAMPLE_PROMPT,
    "search_time": 0.37,
    "gen_time": 0.67,
    "input_tokens": 245,
    "output_tokens": 24,
    "tps": 35.9,
    "trace_id": None,   # seeded turn has no LangSmith run
    "feedback": None,
}
# =========================
# SESSION STATE
# =========================
if "messages" not in st.session_state:
    st.session_state.messages = []  # chat transcript: list of role/content dicts
if "seeded" not in st.session_state:
    st.session_state.seeded = False
# Seed exactly once on first load; "Clear chat" keeps seeded=True so the
# sample conversation never reappears after clearing.
if (not st.session_state.seeded) and (len(st.session_state.messages) == 0):
    st.session_state.messages = [SAMPLE_USER_MSG, SAMPLE_ASSISTANT_MSG]
    st.session_state.seeded = True
# =========================
# HEADER
# =========================
# Logo on the left, title on the right; caption spans the full width below.
logo_col, title_col = st.columns([1, 7], vertical_alignment="center")
with logo_col:
    st.image(LOGO_URL, width=56)
with title_col:
    st.markdown("## TeapotAI Chat")
st.caption(
    "Teapot is a 77M-parameter LLM optimized for fast CPU inference that only generates answers "
    "from the provided context to minimize hallucinations."
)
# =========================
# SIDEBAR
# =========================
with st.sidebar:
    st.markdown("### Settings")
    # Editable system prompt; feeds directly into the generation prompt.
    system_prompt = st.text_area(
        "System prompt",
        value=DEFAULT_SYSTEM_PROMPT,
        height=160,
    )
    # Optional user-supplied context, appended after the web-search snippets.
    local_context = st.text_area(
        "Local context (optional)",
        height=140,
        placeholder="Extra context appended after web snippets…",
    )
    if st.button("Clear chat"):
        st.session_state.messages = []
        # Keep seeded=True so the sample conversation is not re-inserted.
        st.session_state.seeded = True
        st.rerun()
# =========================
# WEB SEARCH (ALWAYS ON)
# =========================
def web_search_snippets(query: str):
    """Fetch up to TOP_K_SEARCH Brave web-search snippets for `query`.

    Returns:
        (snippets_text, elapsed_seconds). Returns ("", 0.0) when no API key
        is configured or the request fails — callers treat that as
        "no web context".
    """
    api_key = os.getenv("BRAVE_API_KEY")
    if not api_key:
        # st.secrets raises StreamlitSecretNotFoundError when no secrets.toml
        # exists at all — treat that the same as "no key configured" instead
        # of crashing the app.
        try:
            api_key = st.secrets.get("BRAVE_API_KEY", None)
        except Exception:
            api_key = None
    if not api_key:
        return "", 0.0
    headers = {"X-Subscription-Token": api_key, "Accept": "application/json"}
    params = {"q": query, "count": TOP_K_SEARCH}
    t0 = time.perf_counter()
    try:
        r = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            headers=headers,
            params=params,
            timeout=6,
        )
        # Non-2xx responses carry an error body, not results — fail best-effort.
        r.raise_for_status()
        data = r.json()
    except Exception:
        return "", 0.0
    t1 = time.perf_counter()
    snippets = []
    for item in data.get("web", {}).get("results", [])[:TOP_K_SEARCH]:
        # Brave highlights matches with <strong> tags; strip them for the prompt.
        desc = (item.get("description") or "").replace("<strong>", "").replace("</strong>", "").strip()
        if desc:
            snippets.append(desc)
    return "\n\n".join(snippets), (t1 - t0)
# =========================
# CONTEXT TRUNCATION (TAIL)
# =========================
def truncate_context(web_ctx: str, local_ctx: str, system: str, question: str) -> str:
    """Combine web + local context and tail-truncate it so the final prompt
    fits within MAX_INPUT_TOKENS after reserving room for system + question."""
    combined = f"{web_ctx}\n\n{local_ctx}".strip()
    fixed_part = f"\n{system}\n{question}\n"
    budget = MAX_INPUT_TOKENS - len(tokenizer.encode(fixed_part))
    if budget <= 0 or not combined:
        # No room left (or nothing to truncate): drop the context entirely.
        return ""
    ids = tokenizer.encode(combined)
    if len(ids) <= budget:
        return combined
    # Keep the tail: the user's local context sits at the end of `combined`.
    return tokenizer.decode(ids[-budget:], skip_special_tokens=True)
def count_tokens(text: str) -> int:
    """Return the number of tokenizer tokens in `text`; 0 for empty/None."""
    if not text:
        return 0
    return len(tokenizer.encode(text))
def get_trace_id_if_available() -> str | None:
    """Return the current LangSmith run id as a string, or None when the SDK
    is absent, no run is active, or anything goes wrong."""
    if get_current_run_tree is None:
        return None
    try:
        run = get_current_run_tree()
        run_id = getattr(run, "id", None) if run else None
        return str(run_id) if run_id else None
    except Exception:
        return None
# =========================
# FEEDBACK HANDLER (attached to trace_id)
# =========================
def handle_feedback(idx: int):
    """Persist the thumbs rating for message `idx` and mirror it to LangSmith.

    `st.feedback("thumbs", ...)` stores an int option index in session_state:
    1 = thumbs up, 0 = thumbs down, None = cleared. (The original comparison
    `val == "👍"` could never match, so every rating was logged as
    thumbs_down.)
    """
    val = st.session_state.get(f"fb_{idx}")
    st.session_state.messages[idx]["feedback"] = val
    msg = st.session_state.messages[idx]
    trace_id = msg.get("trace_id")
    # Only forward an actual rating (not a cleared widget) when tracing is on.
    if ls_client and trace_id and val is not None:
        score = 1 if val == 1 else 0
        try:
            ls_client.create_feedback(
                run_id=trace_id,  # LangSmith keys feedback by run_id, not a trace_id kwarg
                key="thumb_rating",
                score=score,
                comment="thumbs_up" if score else "thumbs_down",
            )
        except Exception:
            # Best-effort telemetry: never break the UI over a feedback call.
            pass
# =========================
# STREAMING + LANGSMITH FIX
# - We do NOT trace a generator.
# - We stream to UI while returning a SINGLE final string.
# =========================
# Per-request hand-off: the request handler stores its st.empty() placeholder
# here so the generator below can stream chunks into it.
_UI_STREAM = {"placeholder": None}  # set per-request

def _generate_with_streamer(prompt: str) -> str:
    """
    Runs model.generate with a TextIteratorStreamer and updates a Streamlit placeholder
    as chunks arrive. Returns the final full text.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,          # do not echo the prompt tokens back
        skip_special_tokens=True,
    )
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,           # greedy decoding: deterministic output
        num_beams=1,
        streamer=streamer,
    )
    # generate() blocks until done, so run it on a daemon worker thread while
    # this (Streamlit script) thread consumes the streamer.
    t = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    t.start()
    buf = ""
    ph = _UI_STREAM.get("placeholder")
    if ph is not None:
        ph.markdown("")  # ensure element exists before first chunk
    # The streamer iterator ends when generation finishes, so no explicit
    # join() is needed before returning the accumulated text.
    for chunk in streamer:
        buf += chunk
        if ph is not None:
            ph.markdown(buf)
    return buf
# When LangSmith is installed, decorate the answer function so each call is
# recorded as a run; otherwise use the identical untraced implementation.
# (The original if/else defined the exact same function in both branches —
# the @traceable decorator was evidently missing, making tracing dead code.)
if traceable:
    @traceable(run_type="chain", name="teapot_answer")
    def traced_answer_streaming(context: str, system_prompt: str, question: str) -> str:
        """Build the prompt, stream-generate, and return the full answer (traced)."""
        prompt = f"{context}\n{system_prompt}\n{question}\n"
        return _generate_with_streamer(prompt)
else:
    def traced_answer_streaming(context: str, system_prompt: str, question: str) -> str:
        """Build the prompt, stream-generate, and return the full answer (untraced)."""
        prompt = f"{context}\n{system_prompt}\n{question}\n"
        return _generate_with_streamer(prompt)
# =========================
# INPUT FIRST (so latest user msg renders immediately)
# =========================
query = st.chat_input("Ask a question...")
if query:
    # Append before the history render below so this run already shows the
    # new user turn while the assistant reply streams in.
    st.session_state.messages.append({"role": "user", "content": query})
# =========================
# RENDER HISTORY
# Row 1: message + feedback
# Row 2: inspect + debug metrics
# =========================
for i, msg in enumerate(st.session_state.messages):
    with st.chat_message(msg["role"]):
        if msg["role"] == "user":
            # User turns are plain text with no feedback/debug rows.
            st.markdown(msg["content"])
            continue
        # Row 1: assistant text with a thumbs widget pinned to the right.
        msg_col, fb_col = st.columns([12, 1], vertical_alignment="center")
        with msg_col:
            st.markdown(msg.get("content", ""))
        with fb_col:
            key = f"fb_{i}"
            # Pre-seed the widget with any stored rating; lock it once rated.
            st.session_state.setdefault(key, msg.get("feedback"))
            st.feedback(
                "thumbs",
                key=key,
                disabled=msg.get("feedback") is not None,
                on_change=handle_feedback,
                args=(i,),
            )
        # Row 2: timing/token metrics plus a popover exposing the raw inputs.
        inspect_col, metrics_col = st.columns([12, 1], vertical_alignment="center")
        with inspect_col:
            st.caption(
                f"🔎 {msg.get('search_time', 0.0):.2f}s (search) "
                f"🧠 {msg.get('gen_time', 0.0):.2f}s (generation) "
                f"⚡ {msg.get('tps', 0.0):.1f} tok/s "
                f"🧾 {msg.get('input_tokens', 0)} input tokens • {msg.get('output_tokens', 0)} output tokens"
            )
        with metrics_col:
            with st.popover("ℹ️", help="Inspect"):
                st.markdown("**Context**")
                st.code(msg.get("context", ""), language="text")
                st.markdown("**System Prompt**")
                st.code(msg.get("system_prompt", ""), language="text")
                st.markdown("**Question**")
                st.code(msg.get("question", ""), language="text")
# =========================
# GENERATE ONLY IF THIS RUN RECEIVED A NEW QUERY
# =========================
if query:
    question = query
    # Web search (best-effort; returns ("", 0.0) without a key or on failure).
    web_ctx, search_time = web_search_snippets(question)
    # Context + truncation to fit the model's input budget.
    context = truncate_context(web_ctx, local_context, system_prompt, question)
    prompt = f"{context}\n{system_prompt}\n{question}\n"
    input_tokens = count_tokens(prompt)
    # Assistant response (stream to UI, return full string for LangSmith)
    with st.chat_message("assistant"):
        # Row 1: message + feedback (disabled live)
        msg_col, fb_col = st.columns([14, 1], vertical_alignment="center")
        with msg_col:
            placeholder = st.empty()
        with fb_col:
            # Placeholder widget only; real feedback is wired up on the next
            # rerun by the history renderer.
            st.feedback("thumbs", key="live_fb", disabled=True)
        # Hand the placeholder to _generate_with_streamer via module state.
        _UI_STREAM["placeholder"] = placeholder
        start = time.perf_counter()
        answer = traced_answer_streaming(context, system_prompt, question)
        # NOTE(review): this runs AFTER the traced call returns; outside the
        # traced context get_current_run_tree() likely yields None, so
        # trace_id may never be captured — confirm against LangSmith docs.
        trace_id = get_trace_id_if_available()
        gen_time = time.perf_counter() - start
        _UI_STREAM["placeholder"] = None  # cleanup
        output_tokens = count_tokens(answer)
        tps = output_tokens / gen_time if gen_time > 0 else 0.0
        # Row 2: inspect + metrics
        inspect_col, metrics_col = st.columns([12, 1], vertical_alignment="center")
        with inspect_col:
            st.caption(
                f"🔎 {search_time:.2f}s (search) "
                f"🧠 {gen_time:.2f}s (generation) "
                f"⚡ {tps:.1f} tok/s "
                f"🧾 {input_tokens} input tokens • {output_tokens} output tokens"
            )
        with metrics_col:
            with st.popover("ℹ️", help="Inspect"):
                st.markdown("**Context**")
                st.code(context, language="text")
                st.markdown("**System**")
                st.code(system_prompt, language="text")
                st.markdown("**Question**")
                st.code(question, language="text")
                st.markdown("**Prompt**")
                st.code(prompt, language="text")
    # Persist assistant message (same shape as SAMPLE_ASSISTANT_MSG) so the
    # history renderer can rebuild this turn on subsequent reruns.
    st.session_state.messages.append(
        {
            "role": "assistant",
            "content": answer,
            "context": context,
            "system_prompt": system_prompt,
            "question": question,
            "prompt": prompt,
            "search_time": search_time,
            "gen_time": gen_time,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "tps": tps,
            "trace_id": trace_id,
            "feedback": None,
        }
    )