berndf commited on
Commit
6ec32ce
·
verified ·
1 Parent(s): 0837a0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -257
app.py CHANGED
@@ -1,270 +1,150 @@
1
- # app.py
2
- import time
3
- import random
4
- import numpy as np
5
  import streamlit as st
6
- import plotly.graph_objects as go
7
- from sklearn.decomposition import PCA
8
- import torch
9
  from transformers import AutoTokenizer, AutoModel
 
 
10
 
 
 
 
 
11
 
12
- st.set_page_config(page_title="Embedding Demo", layout="wide")
13
-
14
- # ----------------------------
15
- # BASE DATASETS (lowercase)
16
- # ----------------------------
17
- DATASETS = {
18
- "countries": [
19
- "germany","france","italy","spain","portugal","poland","netherlands","belgium",
20
- "austria","switzerland","greece","norway","sweden","finland","denmark","ireland",
21
- "hungary","czechia","slovakia","slovenia","iceland","estonia","latvia","lithuania","romania"
22
- ],
23
- "animals": [
24
- "cat","dog","lion","tiger","bear","wolf","fox","eagle","shark","whale",
25
- "zebra","giraffe","elephant","hippopotamus","rhinoceros","kangaroo","panda","otter","seal","dolphin",
26
- "chimpanzee","gorilla","leopard","cheetah","lynx"
27
- ],
28
- "furniture": [
29
- "armchair","sofa","dining table","coffee table","bookshelf","bed","wardrobe","desk","office chair","dresser",
30
- "nightstand","side table","tv stand","loveseat","chaise lounge","bench","hutch","kitchen island","futon","recliner",
31
- "ottoman","console table","vanity","buffet","sectional sofa"
32
- ],
33
- "actors": [
34
- "brad pitt","angelina jolie","meryl streep","leonardo dicaprio","tom hanks","scarlett johansson","robert de niro",
35
- "natalie portman","matt damon","cate blanchett","johnny depp","keanu reeves","hugh jackman","emma stone","ryan gosling",
36
- "jennifer lawrence","christian bale","charlize theron","will smith","anne hathaway","denzel washington","morgan freeman",
37
- "julia roberts","george clooney","kate winslet"
38
- ],
39
- "rock group": [
40
- "the beatles","rolling stones","pink floyd","queen","led zeppelin","u2","ac/dc","nirvana","radiohead","metallica",
41
- "guns n' roses","red hot chili peppers","coldplay","pearl jam","the police","aerosmith","green day","foo fighters",
42
- "the doors","bon jovi","deep purple","the who","the kinks","fleetwood mac","the beach boys"
43
- ],
44
- "sports": [
45
- "soccer","basketball","tennis","baseball","golf","swimming","cycling","running","volleyball","rugby",
46
- "boxing","skiing","snowboarding","surfing","skateboarding","karate","judo","fencing","rowing","badminton",
47
- "cricket","table tennis","gymnastics","hockey","climbing"
48
- ]
49
- }
50
-
51
- # ----------------------------
52
- # RANDOM MIXED SETS (once per session)
53
- # ----------------------------
54
- def make_random_mixed_sets(base: dict, n_sets: int = 3) -> dict:
55
- keys = list(base.keys())
56
- mixed = {}
57
- for _ in range(n_sets):
58
- sources = random.sample(keys, 3)
59
- items = []
60
- for s in sources:
61
- take = min(7, len(base[s]))
62
- items.extend(random.sample(base[s], take))
63
- mixed_name = "/".join(sources).lower()
64
- mixed[mixed_name] = items[:21]
65
- return mixed
66
-
67
- if "mixed_added" not in st.session_state:
68
- DATASETS.update(make_random_mixed_sets(DATASETS, 3))
69
- st.session_state.mixed_added = True
70
-
71
- # ----------------------------
72
- # MODELS (transformers)
73
- # ----------------------------
74
  EMBED_MODELS = {
75
- "all-minilm-l6-v2 (384d)": "sentence-transformers/all-MiniLM-L6-v2",
76
- "all-mpnet-base-v2 (768d)": "sentence-transformers/all-mpnet-base-v2",
77
- "all-roberta-large-v1 (1024d)": "sentence-transformers/all-roberta-large-v1",
78
  }
79
 
80
- @st.cache_resource(show_spinner=False)
81
- def load_hf_model(model_name: str):
82
- tok = AutoTokenizer.from_pretrained(model_name)
83
- mdl = AutoModel.from_pretrained(model_name)
84
- mdl.eval()
85
- return tok, mdl
86
 
87
- @st.cache_data(show_spinner=False)
88
- def embed_texts(model_name: str, texts_tuple: tuple):
89
- tokenizer, model = load_hf_model(model_name)
90
- texts = list(texts_tuple)
91
  with torch.no_grad():
92
- inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
93
- outputs = model(**inputs)
94
- token_embeddings = outputs.last_hidden_state # (B,T,H)
95
- mask = inputs["attention_mask"].unsqueeze(-1).type_as(token_embeddings)
96
- summed = (token_embeddings * mask).sum(dim=1)
97
- counts = mask.sum(dim=1).clamp(min=1e-9)
98
- embeddings = summed / counts # mean pooling
99
  return embeddings.cpu().numpy()
100
 
101
- # ----------------------------
102
- # STATE: camera + rotation
103
- # ----------------------------
104
- if "camera_eye" not in st.session_state:
105
- st.session_state.camera_eye = {"x": 1.6, "y": 1.6, "z": 1.2}
106
- if "spinning" not in st.session_state:
107
- st.session_state.spinning = False
108
- if "angle_rad" not in st.session_state:
109
- # derive from eye.x, eye.y
110
- e = st.session_state.camera_eye
111
- st.session_state.angle_rad = float(np.arctan2(e["y"], e["x"]))
112
-
113
- def update_eye_from_angle(angle_rad: float, radius: float, z: float):
114
- return {"x": radius * np.cos(angle_rad), "y": radius * np.sin(angle_rad), "z": z}
115
-
116
- # ----------------------------
117
- # NAVIGATION via st.query_params
118
- # ----------------------------
119
- def goto(page: str):
120
- st.query_params["page"] = page
121
- st.rerun()
122
-
123
- page = st.query_params.get("page", ["demo"])[0]
124
-
125
- # ----------------------------
126
- # INFO PAGE
127
- # ----------------------------
128
- def info_page():
129
- st.title("ℹ about this demo")
130
- st.write("""
131
- **embeddings** turn words (or longer text) into numerical vectors.
132
- in this vector space, **semantically related** items end up **near** each other.
133
-
134
- why this is useful:
135
- - semantic search and retrieval
136
- - clustering and topic discovery
137
- - recommendations and deduplication
138
- - measuring similarity and analogies
139
-
140
- this demo embeds single words with a selectable model, reduces to 2d/3d with pca,
141
- and shows how related words cluster in the projected space.
142
- """.strip())
143
- if st.button("⬅ back to demo"):
144
- goto("demo")
145
-
146
- # ----------------------------
147
- # DEMO PAGE
148
- # ----------------------------
149
- def demo_page():
150
- # top row: dataset, model + 2d/3d, info button
151
- c1, c2, c3 = st.columns([2, 2, 1])
152
- with c1:
153
- ds_names = list(DATASETS.keys())
154
- dataset_name = st.selectbox("dataset", ds_names, index=ds_names.index("furniture") if "furniture" in ds_names else 0)
155
- with c2:
156
- cc1, cc2 = st.columns([2, 1])
157
- with cc1:
158
- model_label = st.selectbox("embedding model", list(EMBED_MODELS.keys()))
159
- model_name = EMBED_MODELS[model_label]
160
- with cc2:
161
- proj_mode = st.radio("projection", ["2d", "3d"], horizontal=True)
162
- with c3:
163
- if st.button("ℹ info"):
164
- goto("info")
165
-
166
- words = DATASETS[dataset_name]
167
- st.text_area("dataset words", "\n".join(words), height=160)
168
-
169
- # Embed + PCA
170
- embs = embed_texts(model_name, tuple(words))
171
- if proj_mode == "2d":
172
- coords = PCA(n_components=2).fit_transform(embs)
173
- else:
174
- coords = PCA(n_components=3).fit_transform(embs)
175
-
176
- title_html = f"<b style='color:#1f77b4; font-size:2.0rem;'>{dataset_name}</b>"
177
-
178
- if proj_mode == "3d":
179
- # compute radius from current eye, keep same z
180
- eye = st.session_state.camera_eye
181
- radius = float(np.sqrt(eye["x"]**2 + eye["y"]**2)) or 1.6
182
- z_eye = float(eye["z"])
183
-
184
- fig = go.Figure(
185
- data=[go.Scatter3d(
186
- x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
187
- mode="markers+text", text=words, textposition="top center",
188
- marker=dict(size=6),
189
- )],
190
- layout=go.Layout(
191
- title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
192
- font=dict(size=30, color="#1f77b4")),
193
- scene=dict(
194
- camera=dict(eye=eye, projection=dict(type="perspective")),
195
- xaxis=dict(showbackground=True, backgroundcolor="rgba(255, 230, 230, 1)"),
196
- yaxis=dict(showbackground=True, backgroundcolor="rgba(230, 255, 230, 1)"),
197
- zaxis=dict(showbackground=True, backgroundcolor="rgba(230, 230, 255, 1)"),
198
- ),
199
- margin=dict(l=0, r=0, b=0, t=60),
200
- uirevision="keep",
201
- )
202
- )
203
-
204
- # Controls under the plot: Start/Stop rotation
205
- b1, b2 = st.columns([1, 1])
206
- with b1:
207
- start_clicked = st.button("▶ start rotation", disabled=st.session_state.spinning)
208
- with b2:
209
- stop_clicked = st.button("⏹ stop rotation", disabled=not st.session_state.spinning)
210
-
211
- # If start pressed: turn on spinner and initialize angle from current stored eye
212
- if start_clicked:
213
- st.session_state.spinning = True
214
- # start from stored angle (not capturing manual camera — simple approach)
215
- st.session_state.angle_rad = float(np.arctan2(eye["y"], eye["x"]))
216
- # fall through to loop below (this turn will render once, then continue)
217
-
218
- # If stop pressed: turn off spinner (and keep stop disabled after)
219
- if stop_clicked:
220
- st.session_state.spinning = False
221
-
222
- # Live render placeholder
223
- placeholder = st.empty()
224
-
225
- # First draw (static) before any loop
226
- placeholder.plotly_chart(fig, use_container_width=True)
227
-
228
- # Continuous rotation loop while spinning
229
- if st.session_state.spinning:
230
- # one "batch" of frames, then rerun to keep UI responsive
231
- steps_per_batch = 120
232
- step = np.deg2rad(3) # 3 degrees per frame ~ smooth
233
- for _ in range(steps_per_batch):
234
- if not st.session_state.spinning:
235
- break
236
- st.session_state.angle_rad += step
237
- new_eye = update_eye_from_angle(st.session_state.angle_rad, radius, z_eye)
238
- st.session_state.camera_eye = new_eye
239
- fig.update_layout(scene_camera=dict(eye=new_eye, projection=dict(type="perspective")))
240
- placeholder.plotly_chart(fig, use_container_width=True)
241
- time.sleep(0.033) # ~30 FPS
242
-
243
- # If still spinning after this batch, rerun to keep going
244
- if st.session_state.spinning:
245
- st.rerun()
246
-
247
- else:
248
- fig = go.Figure(
249
- data=[go.Scatter(
250
- x=coords[:, 0], y=coords[:, 1],
251
- mode="markers+text", text=words, textposition="top center",
252
- marker=dict(size=9),
253
- )],
254
- layout=go.Layout(
255
- title=dict(text=title_html, x=0.5, xanchor="center", yanchor="top",
256
- font=dict(size=30, color="#1f77b4")),
257
- xaxis=dict(title="PC1"),
258
- yaxis=dict(title="PC2", scaleanchor="x", scaleratio=1),
259
- margin=dict(l=0, r=0, b=0, t=60),
260
- )
261
- )
262
- st.plotly_chart(fig, use_container_width=True)
263
 
264
- # ----------------------------
265
- # ROUTER
266
- # ----------------------------
267
- if page == "info":
268
- info_page()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  else:
270
- demo_page()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import plotly.graph_objs as go
3
+ import numpy as np
4
+ import random
5
  from transformers import AutoTokenizer, AutoModel
6
+ import torch
7
+ from sklearn.decomposition import PCA
8
 
9
+ # -------------------
10
+ # CONFIG
11
+ # -------------------
12
+ st.set_page_config(layout="wide", page_title="Embedding Visualizer")
13
 
14
+ # -------------------
15
+ # EMBEDDING MODELS
16
+ # -------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  EMBED_MODELS = {
18
+ "all-MiniLM-L6-v2 (384 dims)": "sentence-transformers/all-MiniLM-L6-v2",
19
+ "all-mpnet-base-v2 (768 dims)": "sentence-transformers/all-mpnet-base-v2",
20
+ "multi-qa-MiniLM-L6-cos-v1 (384 dims)": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
21
  }
22
 
23
+ @st.cache_resource
24
+ def load_model(model_name):
25
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
26
+ model = AutoModel.from_pretrained(model_name)
27
+ return tokenizer, model
 
28
 
29
+ def embed_texts(texts, tokenizer, model):
30
+ tokens = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
 
 
31
  with torch.no_grad():
32
+ embeddings = model(**tokens).last_hidden_state.mean(dim=1)
 
 
 
 
 
 
33
  return embeddings.cpu().numpy()
34
 
35
+ # -------------------
36
+ # DATASETS
37
+ # -------------------
38
+ base_sets = {
39
+ "countries": ["Germany", "France", "Italy", "Spain", "Portugal", "Norway", "Sweden", "Denmark", "Poland", "Austria"],
40
+ "animals": ["Dog", "Cat", "Horse", "Elephant", "Tiger", "Lion", "Monkey", "Giraffe", "Zebra", "Bear"],
41
+ "furniture": [
42
+ "Armchair", "Sofa", "Dining table", "Coffee table", "Bookshelf", "Bed", "Wardrobe",
43
+ "Desk", "Office chair", "Dresser", "Nightstand", "Side table", "TV stand",
44
+ "Loveseat", "Chaise lounge", "Bench", "Hutch", "Kitchen island", "Futon", "Recliner",
45
+ "Ottoman", "Console table", "Vanity", "Buffet", "Sectional sofa"
46
+ ],
47
+ "actor": ["Tom Hanks", "Brad Pitt", "Leonardo DiCaprio", "Meryl Streep", "Natalie Portman",
48
+ "Morgan Freeman", "Emma Stone", "Denzel Washington", "Cate Blanchett", "Robert De Niro"],
49
+ "rock group": ["The Beatles", "The Rolling Stones", "Queen", "Pink Floyd", "Led Zeppelin",
50
+ "U2", "The Who", "Metallica", "Nirvana", "Radiohead"]
51
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # -------------------
54
+ # CREATE RANDOM MIXED SETS
55
+ # -------------------
56
+ def create_random_mixed_sets(num_sets=3):
57
+ mixed_sets = {}
58
+ keys = list(base_sets.keys())
59
+ for _ in range(num_sets):
60
+ chosen = random.sample(keys, 3)
61
+ words = []
62
+ for k in chosen:
63
+ words.extend(random.sample(base_sets[k], min(7, len(base_sets[k]))))
64
+ mixed_name = "/".join(chosen)
65
+ mixed_sets[mixed_name] = words
66
+ return mixed_sets
67
+
68
+ mixed_sets = create_random_mixed_sets()
69
+ datasets = {**base_sets, **mixed_sets}
70
+
71
+ # -------------------
72
+ # UI LAYOUT
73
+ # -------------------
74
+ col_top1, col_top2, col_top3 = st.columns([2, 2, 1])
75
+ with col_top1:
76
+ dataset_name = st.selectbox("Dataset", list(datasets.keys()), index=list(datasets.keys()).index("furniture"))
77
+ with col_top2:
78
+ embed_model_name = st.selectbox("Embedding model", list(EMBED_MODELS.keys()))
79
+ with col_top3:
80
+ st.markdown("[ℹ Info](?page=info)")
81
+
82
+ if st.query_params.get("page") == "info":
83
+ st.markdown("""
84
+ ## embedding demo info
85
+ embeddings are numerical vector representations of text.
86
+ they capture meaning so that similar words or phrases are located near each other in the vector space.
87
+ this makes them useful for search, clustering, recommendation, and semantic analysis.
88
+ """)
89
+ st.stop()
90
+
91
+ # -------------------
92
+ # MAIN TWO-COLUMN LAYOUT
93
+ # -------------------
94
+ col1, col2 = st.columns([1, 2])
95
+
96
+ with col1:
97
+ dataset_words = st.text_area("Dataset words", "\n".join(datasets[dataset_name]), height=400)
98
+ words = [w.strip() for w in dataset_words.split("\n") if w.strip()]
99
+
100
+ with col2:
101
+ dim_mode = st.radio("Projection", ["2D", "3D"], horizontal=True)
102
+
103
+ # -------------------
104
+ # EMBEDDING & PROJECTION
105
+ # -------------------
106
+ tokenizer, model = load_model(EMBED_MODELS[embed_model_name])
107
+ vectors = embed_texts(words, tokenizer, model)
108
+
109
+ if dim_mode == "2D":
110
+ proj = PCA(n_components=2).fit_transform(vectors)
111
+ else:
112
+ proj = PCA(n_components=3).fit_transform(vectors)
113
+
114
+ # -------------------
115
+ # PLOT
116
+ # -------------------
117
+ rotate = st.session_state.get("rotate", False)
118
+ scene_camera = dict(eye=dict(x=1.25, y=1.25, z=1.25))
119
+
120
+ if dim_mode == "3D":
121
+ trace = go.Scatter3d(
122
+ x=proj[:, 0], y=proj[:, 1], z=proj[:, 2],
123
+ mode='markers+text',
124
+ text=words,
125
+ marker=dict(size=6, color='blue', opacity=0.8),
126
+ textposition='top center'
127
+ )
128
+ fig = go.Figure(data=[trace])
129
+ fig.update_layout(scene_camera=scene_camera, margin=dict(l=0, r=0, t=0, b=0))
130
  else:
131
+ trace = go.Scatter(
132
+ x=proj[:, 0], y=proj[:, 1],
133
+ mode='markers+text',
134
+ text=words,
135
+ marker=dict(size=8, color='blue', opacity=0.8),
136
+ textposition='top center'
137
+ )
138
+ fig = go.Figure(data=[trace])
139
+ fig.update_layout(margin=dict(l=0, r=0, t=0, b=0))
140
+
141
+ # -------------------
142
+ # ROTATION BUTTON
143
+ # -------------------
144
+ if st.button("🔄 Toggle Rotation"):
145
+ st.session_state.rotate = not st.session_state.get("rotate", False)
146
+
147
+ if rotate and dim_mode == "3D":
148
+ fig.update_layout(scene_camera=dict(eye=dict(x=1.25, y=1.25, z=1.25), up=dict(x=0, y=0, z=1)))
149
+
150
+ st.plotly_chart(fig, use_container_width=True)