zakerytclarke committed on
Commit
2c84c9e
·
verified ·
1 Parent(s): 0b7fcb4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +167 -99
src/streamlit_app.py CHANGED
@@ -1,49 +1,112 @@
1
  import os
2
  import re
3
- import threading
4
  from typing import List, Dict
5
 
6
  import requests
7
  import streamlit as st
8
  import torch
9
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  # -----------------------
13
  # Config
14
  # -----------------------
15
- MODEL_NAME = "teapotai/tinyteapot"
16
  BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
17
  TOP_K = 3
18
  TIMEOUT_SECS = 15
19
 
 
 
20
 
21
- # -----------------------
22
- # Model load (cached)
23
- # -----------------------
24
- @st.cache_resource
25
- def load_model():
26
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
27
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
28
 
29
- device = "cuda" if torch.cuda.is_available() else "cpu"
30
- model.to(device)
31
- model.eval()
32
- return tokenizer, model, device
 
 
 
 
 
 
 
33
 
34
 
35
  # -----------------------
36
- # Brave Search
37
  # -----------------------
 
 
 
 
 
38
  def brave_search_snippets(query: str, top_k: int = 3) -> List[Dict[str, str]]:
39
- brave_api_key = os.getenv("BRAVE_API_KEY")
40
  if not brave_api_key:
41
- raise RuntimeError("Missing BRAVE_API_KEY env var.")
42
 
43
- headers = {
44
- "Accept": "application/json",
45
- "X-Subscription-Token": brave_api_key,
46
- }
47
  params = {"q": query, "count": top_k}
48
 
49
  resp = requests.get(
@@ -69,8 +132,7 @@ def brave_search_snippets(query: str, top_k: int = 3) -> List[Dict[str, str]]:
69
 
70
  def format_context_from_results(results: List[Dict[str, str]]) -> str:
71
  """
72
- Stable, explicit formatting. If you want it to match some *other* exact template,
73
- change only this function.
74
  """
75
  if not results:
76
  return ""
@@ -81,6 +143,10 @@ def format_context_from_results(results: List[Dict[str, str]]) -> str:
81
  url = re.sub(r"\s+", " ", r.get("url", "")).strip()
82
  snippet = re.sub(r"\s+", " ", r.get("snippet", "")).strip()
83
 
 
 
 
 
84
  blocks.append(
85
  f"[{i}] {title}\n"
86
  f"URL: {url}\n"
@@ -90,98 +156,109 @@ def format_context_from_results(results: List[Dict[str, str]]) -> str:
90
 
91
 
92
  # -----------------------
93
- # TinyTeapot generation (streaming)
94
  # -----------------------
95
- def build_prompt(context: str, system_prompt: str, question: str) -> str:
96
- # EXACTLY your format: context + system_prompt + question
97
- return f"{context}\n{system_prompt}\n{question}\n"
98
-
99
-
100
- def stream_generate(tokenizer, model, device, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
101
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
102
-
103
- do_sample = float(temperature) > 0.0
104
- gen_kwargs = dict(
105
- **inputs,
106
- max_new_tokens=int(max_new_tokens),
107
- do_sample=do_sample,
108
- temperature=float(temperature) if do_sample else None,
109
- top_p=float(top_p) if do_sample else None,
110
- num_beams=1,
111
- )
112
 
113
- # Transformers streamer: yields decoded text pieces as generation proceeds
114
- streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
 
115
 
116
- def _run():
117
- # Remove None args (generate doesn't like None for some models)
118
- clean_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None}
119
- model.generate(**clean_kwargs, streamer=streamer)
 
 
120
 
121
- t = threading.Thread(target=_run, daemon=True)
122
- t.start()
123
 
124
- partial = ""
125
- for piece in streamer:
126
- partial += piece
127
- yield partial
 
 
 
 
 
 
 
 
 
128
 
129
 
130
  # -----------------------
131
- # Streamlit UI
132
  # -----------------------
133
- st.set_page_config(page_title="TinyTeapot + Brave Search", page_icon="🫖", layout="centered")
134
-
135
- st.title("🫖 TinyTeapot + Brave Search (Top 3)")
136
-
137
- default_system_prompt = (
138
- "You are Teapot, an open-source AI assistant optimized for low-end devices, "
139
- "providing short, accurate responses without hallucinating while excelling at "
140
- "information extraction and text summarization. "
141
- "If the context does not answer the question, reply exactly: "
142
- "'I am sorry but I don't have any information on that'."
143
- )
144
 
145
  with st.sidebar:
146
  st.header("Settings")
147
- system_prompt = st.text_area("System prompt", value=default_system_prompt, height=140)
148
- max_new_tokens = st.slider("Max new tokens", 1, 512, 128, 1)
149
- temperature = st.slider("Temperature (0 = greedy)", 0.0, 2.0, 0.0, 0.1)
150
- top_p = st.slider("Top-p", 0.1, 1.0, 0.95, 0.05)
 
 
 
 
 
151
  show_sources = st.checkbox("Show sources/context", value=True)
152
 
153
- # Session state for chat history
 
 
 
 
 
 
 
 
 
 
154
  if "messages" not in st.session_state:
155
- st.session_state.messages = [] # list of {"role": "user"/"assistant", "content": str}
156
 
157
- # Render chat history
158
  for m in st.session_state.messages:
159
  with st.chat_message(m["role"]):
160
  st.markdown(m["content"])
161
 
162
- # Chat input
163
- question = st.chat_input("Ask a question (the app will Brave-search top 3 snippets)…")
164
 
165
  if question:
166
- # Add user message
167
  st.session_state.messages.append({"role": "user", "content": question})
168
  with st.chat_message("user"):
169
  st.markdown(question)
170
 
171
- tokenizer, model, device = load_model()
172
-
173
- # Get Brave context
174
  try:
175
  results = brave_search_snippets(question, top_k=TOP_K)
176
  context = format_context_from_results(results)
177
- except Exception as e:
178
- # If Brave fails, keep context empty so your system prompt triggers the exact refusal.
179
- context = ""
180
  results = []
181
- # You can uncomment this if you want to show the error:
182
- # st.warning(f"Brave Search failed: {e}")
183
 
184
- prompt = build_prompt(context=context, system_prompt=system_prompt, question=question)
 
 
 
 
 
 
185
 
186
  with st.chat_message("assistant"):
187
  if show_sources:
@@ -192,18 +269,9 @@ if question:
192
  st.write("(No search context returned.)")
193
 
194
  placeholder = st.empty()
195
- final = ""
196
-
197
- for partial in stream_generate(
198
- tokenizer=tokenizer,
199
- model=model,
200
- device=device,
201
- prompt=prompt,
202
- max_new_tokens=max_new_tokens,
203
- temperature=temperature,
204
- top_p=top_p,
205
- ):
206
- final = partial
207
- placeholder.markdown(final)
208
-
209
- st.session_state.messages.append({"role": "assistant", "content": final})
 
1
  import os
2
  import re
3
+ import time
4
  from typing import List, Dict
5
 
6
  import requests
7
  import streamlit as st
8
  import torch
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10
+
11
+ from teapotai import TeapotAI
12
+
13
+
14
+ # -----------------------
15
+ # Branding / Theme
16
+ # -----------------------
17
+ TEAPOT_LOGO_GIF = "https://teapotai.com/assets/logo.gif"
18
+
19
+ TEA_BG = "#fbf7ef" # warm off-white
20
+ TEA_PANEL = "#fffaf2" # slightly brighter
21
+ TEA_TEXT = "#1f2937" # slate-ish
22
+ TEA_MUTED = "#6b7280" # gray
23
+ TEA_ACCENT = "#c0841d" # warm amber
24
+ TEA_BORDER = "rgba(31, 41, 55, 0.10)"
25
+
26
+ st.set_page_config(
27
+ page_title="TeapotAI Chat",
28
+ page_icon="🫖",
29
+ layout="centered",
30
+ )
31
+
32
+ CUSTOM_CSS = f"""
33
+ <style>
34
+ /* App background */
35
+ .stApp {{
36
+ background: {TEA_BG};
37
+ color: {TEA_TEXT};
38
+ }}
39
+
40
+ /* Sidebar */
41
+ section[data-testid="stSidebar"] {{
42
+ background: {TEA_PANEL};
43
+ border-right: 1px solid {TEA_BORDER};
44
+ }}
45
+
46
+ /* Chat bubbles */
47
+ div[data-testid="stChatMessage"] {{
48
+ border-radius: 16px;
49
+ padding: 8px 10px;
50
+ }}
51
+
52
+ /* Inputs */
53
+ .stTextInput > div > div, .stTextArea > div > div {{
54
+ border-radius: 12px !important;
55
+ }}
56
+
57
+ /* Buttons */
58
+ .stButton button {{
59
+ border-radius: 12px;
60
+ border: 1px solid {TEA_BORDER};
61
+ }}
62
+
63
+ /* Accent-ish links */
64
+ a {{
65
+ color: {TEA_ACCENT} !important;
66
+ }}
67
+ </style>
68
+ """
69
+ st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
70
 
71
 
72
  # -----------------------
73
  # Config
74
  # -----------------------
 
75
  BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
76
  TOP_K = 3
77
  TIMEOUT_SECS = 15
78
 
79
+ MODEL_TINY = "teapotai/tinyteapot"
80
+ MODEL_LLM = "teapotai/teapotllm"
81
 
 
 
 
 
 
 
 
82
 
83
+ DEFAULT_SYSTEM_PROMPT = (
84
+ "You are Teapot, an open-source AI assistant optimized for low-end devices, "
85
+ "providing short, accurate responses without hallucinating while excelling at "
86
+ "information extraction and text summarization. "
87
+ "If the context does not answer the question, reply exactly: "
88
+ "'I am sorry but I don't have any information on that'."
89
+ )
90
+
91
+ DEFAULT_DOCUMENTS = [
92
+ """Teapot (Tiny Teapot) is an open-source small language model (~77 million parameters) fine-tuned on synthetic data and optimized to run locally on resource-constrained devices such as smartphones and CPUs. Teapot is trained to only answer using context from documents, reducing hallucinations. Teapot can perform a variety of tasks, including hallucination-resistant Question Answering (QnA), Retrieval-Augmented Generation (RAG), and JSON extraction. TeapotLLM is a fine tune of flan-t5-large that was trained on synthetic data generated by Deepseek v3 TeapotLLM can be hosted on low-power devices with as little as 2GB of CPU RAM such as a Raspberry Pi. Teapot is a model built by and for the community."""
93
+ ]
94
 
95
 
96
  # -----------------------
97
+ # Helpers
98
  # -----------------------
99
def get_brave_key() -> str:
    """Return the Brave Search API key, or None when it is not configured.

    Lookup order: Streamlit secrets (Streamlit Cloud deployments) first,
    then the BRAVE_API_KEY environment variable (local runs). Callers
    treat a falsy result as "missing" and raise their own error.
    """
    # NOTE(review): `"BRAVE_API_KEY" in st.secrets` raises FileNotFoundError
    # (StreamlitSecretNotFoundError subclasses it) when no secrets.toml
    # exists, which would crash env-var-only local runs — guard the lookup.
    try:
        if hasattr(st, "secrets") and "BRAVE_API_KEY" in st.secrets:
            return st.secrets.get("BRAVE_API_KEY")
    except FileNotFoundError:
        pass
    return os.getenv("BRAVE_API_KEY")
102
+
103
+
104
  def brave_search_snippets(query: str, top_k: int = 3) -> List[Dict[str, str]]:
105
+ brave_api_key = get_brave_key()
106
  if not brave_api_key:
107
+ raise RuntimeError("Missing BRAVE_API_KEY (set env var or Streamlit secrets).")
108
 
109
+ headers = {"Accept": "application/json", "X-Subscription-Token": brave_api_key}
 
 
 
110
  params = {"q": query, "count": top_k}
111
 
112
  resp = requests.get(
 
132
 
133
  def format_context_from_results(results: List[Dict[str, str]]) -> str:
134
  """
135
+ Stable formatting; plus you asked to strip <strong> tags.
 
136
  """
137
  if not results:
138
  return ""
 
143
  url = re.sub(r"\s+", " ", r.get("url", "")).strip()
144
  snippet = re.sub(r"\s+", " ", r.get("snippet", "")).strip()
145
 
146
+ # strip <strong> tags specifically, as requested
147
+ title = title.replace("<strong>", "").replace("</strong>", "")
148
+ snippet = snippet.replace("<strong>", "").replace("</strong>", "")
149
+
150
  blocks.append(
151
  f"[{i}] {title}\n"
152
  f"URL: {url}\n"
 
156
 
157
 
158
  # -----------------------
159
+ # Model / TeapotAI loader
160
  # -----------------------
161
@st.cache_resource
def load_teapot_ai(model_name: str) -> TeapotAI:
    """Build and cache a TeapotAI instance for *model_name*.

    Cached per model_name via st.cache_resource, so each model is
    downloaded and initialized at most once per server process.
    TinyTeapot is warmed on startup by an eager call at module level.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Prefer GPU when one is present; otherwise stay on CPU.
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    seq2seq.to(target_device)
    seq2seq.eval()

    return TeapotAI(
        tokenizer=tok,
        model=seq2seq,
        documents=DEFAULT_DOCUMENTS,
    )
179
 
 
 
180
 
181
def typewriter_render(text: str, container, speed_chars_per_sec: float = 250.0):
    """Simulate streaming by revealing *text* one character at a time.

    TeapotAI.query returns the full answer at once (no token stream), so
    this re-renders growing prefixes into *container* to mimic streaming.
    An empty text renders once as "" and returns immediately.
    """
    if not text:
        container.markdown("")
        return

    # Clamp the speed so a zero/negative value never divides by zero or
    # sleeps longer than one second per character.
    pause = 1.0 / max(speed_chars_per_sec, 1.0)
    for end in range(1, len(text) + 1):
        container.markdown(text[:end])
        time.sleep(pause)
194
 
195
 
196
  # -----------------------
197
+ # UI
198
  # -----------------------
199
+ # Header with logo
200
+ col1, col2 = st.columns([1, 3], vertical_alignment="center")
201
+ with col1:
202
+ # Streamlit will fetch the gif directly
203
+ st.image(TEAPOT_LOGO_GIF, use_container_width=True)
204
+ with col2:
205
+ st.markdown("## TeapotAI Chat")
206
+ st.caption("Brave Search (top 3 snippets) → context → TeapotAI.query()")
 
 
 
207
 
208
  with st.sidebar:
209
  st.header("Settings")
210
+
211
+ model_choice = st.radio(
212
+ "Model",
213
+ options=[MODEL_TINY, MODEL_LLM],
214
+ index=0,
215
+ help="TinyTeapot loads by default. Switching loads the other model (cached).",
216
+ )
217
+
218
+ system_prompt = st.text_area("System prompt", value=DEFAULT_SYSTEM_PROMPT, height=140)
219
  show_sources = st.checkbox("Show sources/context", value=True)
220
 
221
+ # Optional: “typing” effect
222
+ typing_effect = st.checkbox("Typing effect", value=True)
223
+
224
+
225
+ # Load TinyTeapot on startup, regardless of current selection (your requirement)
226
+ _ = load_teapot_ai(MODEL_TINY)
227
+
228
+ # Load selected model (cached after first load)
229
+ teapot_ai = load_teapot_ai(model_choice)
230
+
231
+ # Chat state
232
  if "messages" not in st.session_state:
233
+ st.session_state.messages = [] # [{"role": "user"/"assistant", "content": str}]
234
 
235
+ # Render history
236
  for m in st.session_state.messages:
237
  with st.chat_message(m["role"]):
238
  st.markdown(m["content"])
239
 
240
+ question = st.chat_input("Ask a question… (@sources are fetched via Brave)")
 
241
 
242
  if question:
 
243
  st.session_state.messages.append({"role": "user", "content": question})
244
  with st.chat_message("user"):
245
  st.markdown(question)
246
 
247
+ # Brave context
 
 
248
  try:
249
  results = brave_search_snippets(question, top_k=TOP_K)
250
  context = format_context_from_results(results)
251
+ except Exception:
 
 
252
  results = []
253
+ context = ""
 
254
 
255
+ # TeapotAI query (context comes from Brave)
256
+ # NOTE: you explicitly want context="" param to hold Brave results after stripping strong tags.
257
+ answer = teapot_ai.query(
258
+ query=question,
259
+ context=context,
260
+ system_prompt=system_prompt,
261
+ )
262
 
263
  with st.chat_message("assistant"):
264
  if show_sources:
 
269
  st.write("(No search context returned.)")
270
 
271
  placeholder = st.empty()
272
+ if typing_effect:
273
+ typewriter_render(answer, placeholder, speed_chars_per_sec=350.0)
274
+ else:
275
+ placeholder.markdown(answer)
276
+
277
+ st.session_state.messages.append({"role": "assistant", "content": answer})