zakerytclarke committed on
Commit
4f6c67a
·
verified ·
1 Parent(s): 3980160

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +207 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,209 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ import re
3
+ import threading
4
+ from typing import List, Dict
5
+
6
+ import requests
7
  import streamlit as st
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
10
+
11
+
12
# -----------------------
# Config
# -----------------------
MODEL_NAME = "teapotai/tinyteapot"  # HF Hub id of the seq2seq chat model
BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"  # Brave web-search REST endpoint
TOP_K = 3  # number of search snippets fed to the model as context
TIMEOUT_SECS = 15  # HTTP timeout (seconds) for Brave requests
19
+
20
+
21
# -----------------------
# Model load (cached)
# -----------------------
@st.cache_resource
def load_model():
    """Load the tokenizer and seq2seq model once per process.

    Returns:
        (tokenizer, model, device): device is "cuda" when available,
        otherwise "cpu"; the model is moved there and set to eval mode.
        Cached via st.cache_resource so Streamlit reruns reuse the objects.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    net = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    target = "cuda" if torch.cuda.is_available() else "cpu"
    net.to(target)
    net.eval()  # inference only — disables dropout etc.
    return tok, net, target
33
+
34
+
35
# -----------------------
# Brave Search
# -----------------------
def brave_search_snippets(query: str, top_k: int = 3) -> List[Dict[str, str]]:
    """Query the Brave web-search API and return up to `top_k` result snippets.

    Args:
        query: free-text search query.
        top_k: maximum number of results requested and returned.

    Returns:
        A list of {"title", "url", "snippet"} dicts; entries whose three
        fields are all empty are dropped.

    Raises:
        RuntimeError: when the BRAVE_API_KEY environment variable is unset.
        requests.HTTPError: when the API responds with an error status.
    """
    api_key = os.getenv("BRAVE_API_KEY")
    if not api_key:
        raise RuntimeError("Missing BRAVE_API_KEY env var.")

    response = requests.get(
        BRAVE_ENDPOINT,
        headers={
            "Accept": "application/json",
            "X-Subscription-Token": api_key,
        },
        params={"q": query, "count": top_k},
        timeout=TIMEOUT_SECS,
    )
    response.raise_for_status()
    payload = response.json()

    # Brave nests web hits under payload["web"]["results"]; either level may be absent.
    hits = (payload.get("web") or {}).get("results") or []

    snippets = []
    for hit in hits[:top_k]:
        entry = {
            "title": (hit.get("title") or "").strip(),
            "url": (hit.get("url") or "").strip(),
            "snippet": (hit.get("description") or "").strip(),
        }
        if any(entry.values()):
            snippets.append(entry)
    return snippets
68
+
69
+
70
def format_context_from_results(results: List[Dict[str, str]]) -> str:
    """Render search results as a numbered, blank-line-separated context block.

    Stable, explicit formatting. If you want it to match some *other* exact
    template, change only this function.
    """
    if not results:
        return ""

    def _one_line(text: str) -> str:
        # Collapse every whitespace run so each field stays on a single line.
        return re.sub(r"\s+", " ", text).strip()

    rendered = [
        f"[{idx}] {_one_line(entry.get('title', ''))}\n"
        f"URL: {_one_line(entry.get('url', ''))}\n"
        f"Snippet: {_one_line(entry.get('snippet', ''))}"
        for idx, entry in enumerate(results, start=1)
    ]
    return "\n\n".join(rendered)
90
+
91
+
92
# -----------------------
# TinyTeapot generation (streaming)
# -----------------------
def build_prompt(context: str, system_prompt: str, question: str) -> str:
    """Assemble the model prompt: context, then system prompt, then question.

    Kept as its own function so the exact template lives in one place.
    """
    # EXACTLY this format: context + system_prompt + question, newline-joined
    # with a trailing newline.
    parts = (context, system_prompt, question)
    return "\n".join(parts) + "\n"
98
+
99
+
100
def stream_generate(tokenizer, model, device, prompt: str, max_new_tokens: int, temperature: float, top_p: float):
    """Stream generated text for `prompt`, yielding the growing partial string.

    Generation runs on a background thread while a TextIteratorStreamer feeds
    decoded fragments back to this generator. A temperature of 0 selects
    greedy decoding (do_sample=False); in that case temperature/top_p are
    simply not passed, instead of being passed as None and filtered later.

    Args:
        tokenizer/model/device: objects returned by load_model().
        prompt: full text prompt to condition on.
        max_new_tokens: generation length cap.
        temperature: 0.0 means greedy; > 0 enables sampling.
        top_p: nucleus-sampling cutoff (only used when sampling).

    Yields:
        The accumulated decoded text after each new fragment.

    Raises:
        Whatever model.generate() raises on the worker thread — re-raised
        here after the stream ends, instead of being silently swallowed
        (the original would just yield nothing on failure).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        num_beams=1,
    )
    if float(temperature) > 0.0:
        # Sampling knobs are only valid when do_sample=True; some models
        # reject temperature/top_p combined with greedy search.
        gen_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
        )
    else:
        gen_kwargs["do_sample"] = False

    # Transformers streamer: yields decoded text pieces as generation proceeds
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    worker_error = []  # single-slot mailbox for an exception from the thread

    def _run():
        try:
            model.generate(**gen_kwargs, streamer=streamer)
        except Exception as exc:  # surface worker failures to the caller
            worker_error.append(exc)
            streamer.end()  # unblock the consuming for-loop below

    t = threading.Thread(target=_run, daemon=True)
    t.start()

    partial = ""
    for piece in streamer:
        partial += piece
        yield partial

    if worker_error:
        raise worker_error[0]
128
+
129
+
130
# -----------------------
# Streamlit UI
# -----------------------
# Top-level script: page config, sidebar settings, chat history rendering,
# and the per-question search -> prompt -> streamed-generation pipeline.
st.set_page_config(page_title="TinyTeapot + Brave Search", page_icon="🫖", layout="centered")

st.title("🫖 TinyTeapot + Brave Search (Top 3)")

default_system_prompt = (
    "You are Teapot, an open-source AI assistant optimized for low-end devices, "
    "providing short, accurate responses without hallucinating while excelling at "
    "information extraction and text summarization. "
    "If the context does not answer the question, reply exactly: "
    "'I am sorry but I don't have any information on that'."
)

# Sidebar: generation and display settings.
with st.sidebar:
    st.header("Settings")
    system_prompt = st.text_area("System prompt", value=default_system_prompt, height=140)
    max_new_tokens = st.slider("Max new tokens", 1, 512, 128, 1)
    temperature = st.slider("Temperature (0 = greedy)", 0.0, 2.0, 0.0, 0.1)
    top_p = st.slider("Top-p", 0.1, 1.0, 0.95, 0.05)
    show_sources = st.checkbox("Show sources/context", value=True)

# Session state for chat history
if "messages" not in st.session_state:
    st.session_state.messages = []  # list of {"role": "user"/"assistant", "content": str}

# Render chat history
for m in st.session_state.messages:
    with st.chat_message(m["role"]):
        st.markdown(m["content"])

# Chat input
question = st.chat_input("Ask a question (the app will Brave-search top 3 snippets)…")

if question:
    # Add user message
    st.session_state.messages.append({"role": "user", "content": question})
    with st.chat_message("user"):
        st.markdown(question)

    tokenizer, model, device = load_model()

    # Get Brave context. On failure we keep context empty so the system
    # prompt triggers the exact refusal — but surface the error to the user
    # instead of silently swallowing it (the warning was commented out before).
    try:
        results = brave_search_snippets(question, top_k=TOP_K)
        context = format_context_from_results(results)
    except Exception as e:
        context = ""
        results = []
        st.warning(f"Brave Search failed: {e}")

    prompt = build_prompt(context=context, system_prompt=system_prompt, question=question)

    with st.chat_message("assistant"):
        if show_sources:
            with st.expander("Sources / Context used", expanded=False):
                if context.strip():
                    st.code(context)
                else:
                    st.write("(No search context returned.)")

        placeholder = st.empty()
        final = ""

        # Live-update the placeholder as fragments stream in.
        for partial in stream_generate(
            tokenizer=tokenizer,
            model=model,
            device=device,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            final = partial
            placeholder.markdown(final)

    st.session_state.messages.append({"role": "assistant", "content": final})