schoginitoys committed on
Commit
ca69551
·
verified ·
1 Parent(s): 32fc4b5

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +267 -125
src/streamlit_app.py CHANGED
@@ -12,132 +12,274 @@ import numpy as np
12
  if isinstance(getattr(sys.modules.get("torch"), "classes", None), types.ModuleType):
13
  torch.classes.__path__ = []
14
 
15
- # pip install tiktoken transformers
16
- import tiktoken
 
17
  from transformers import GPT2TokenizerFast
18
 
19
- st.set_page_config(page_title="Embedding Dimension Visualizer", layout="wide")
20
- st.title("🔍 Embedding Dimension Visualizer")
21
-
22
- # ---- THEORY EXPANDER ----
23
- with st.expander("📖 Theory: Tokenization, BPE & Positional Encoding"):
24
- st.markdown("""
25
- **1️⃣ Tokenization**
26
- Splits raw text into atomic units (“tokens”).
27
-
28
- **2️⃣ Byte-Pair Encoding (BPE)**
29
- Iteratively merges the most frequent pair of symbols to build a subword vocabulary.
30
- E.g. "embedding" ["em", "bed", "ding"]
31
-
32
- **3️⃣ Positional Encoding**
33
- We add a deterministic sinusoidal vector to each token embedding so the model knows position.
34
- """)
35
- st.markdown("For embedding dimension \(d\), position \(pos\) and channel index \(i\):")
36
- st.latex(r"""\mathrm{PE}_{(pos,\,2i)} = \sin\!\Bigl(\frac{pos}{10000^{2i/d}}\Bigr)""")
37
- st.latex(r"""\mathrm{PE}_{(pos,\,2i+1)} = \cos\!\Bigl(\frac{pos}{10000^{2i/d}}\Bigr)""")
38
- st.markdown("""
39
- - \(pos\) starts at 0 for the first token
40
- - Even channels use \(\sin\), odd channels use \(\cos\)
41
- - This injects unique, smoothly varying positional signals into each embedding
42
- """)
43
-
44
-
45
- # ---- Sidebar ----
46
- with st.sidebar:
47
- st.header("Settings")
48
- input_text = st.text_input("Enter text to embed", value="Hello world!")
49
- dim = st.number_input(
50
- "Embedding dimensions",
51
- min_value=2,
52
- max_value=1536,
53
- value=3,
54
- step=1,
55
- help="Choose 2, 3, 512, 768, 1536, etc."
56
- )
57
- tokenizer_choice = st.selectbox(
58
- "Choose tokenizer",
59
- ["tiktoken", "openai", "huggingface"],
60
- help="Which tokenization scheme to demo."
61
- )
62
- generate = st.button("Generate / Reset Embedding")
63
-
64
- if not generate:
65
- st.info("Adjust the settings in the sidebar and click **Generate / Reset Embedding** to see the tokens and sliders.")
66
- st.stop()
67
-
68
- # ---- Tokenize ----
69
- if tokenizer_choice in ("tiktoken", "openai"):
70
- model_name = "gpt2" if tokenizer_choice=="tiktoken" else "gpt-3.5-turbo"
71
- enc = tiktoken.encoding_for_model(model_name)
72
- token_ids = enc.encode(input_text)
73
- token_strs = [enc.decode([tid]) for tid in token_ids]
74
- else:
75
- hf_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
76
- token_ids = hf_tokenizer.encode(input_text)
77
- token_strs = hf_tokenizer.convert_ids_to_tokens(token_ids)
78
-
79
- st.subheader("🪶 Tokens and IDs")
80
- for i, (tok, tid) in enumerate(zip(token_strs, token_ids), start=1):
81
- st.write(f"**{i}.** `{tok}` → ID **{tid}**")
82
-
83
- st.write("---")
84
- st.subheader("📊 Embedding + Positional Encoding per Token")
85
- st.write(f"Input: `{input_text}` | Tokenizer: **{tokenizer_choice}** | Dims per token: **{dim}**")
86
- if dim > 20:
87
- st.warning("Showing >20 sliders per block may be unwieldy; consider smaller dims for teaching.")
88
-
89
- # helper for sinusoidal positional encoding
90
- def get_positional_encoding(position: int, d_model: int) -> np.ndarray:
91
- pe = np.zeros(d_model, dtype=float)
92
- for i in range(d_model):
93
- angle = position / np.power(10000, (2 * (i // 2)) / d_model)
94
- pe[i] = np.sin(angle) if (i % 2 == 0) else np.cos(angle)
95
  return pe
96
 
97
- # ---- For each token, three slider‐blocks ----
98
- for t_idx, tok in enumerate(token_strs, start=1):
99
- emb = np.random.uniform(-1.0, 1.0, size=dim)
100
- pe = get_positional_encoding(t_idx - 1, dim)
101
- combined = emb + pe
102
-
103
- with st.expander(f"Token {t_idx}: `{tok}`"):
104
- st.markdown("**1️⃣ Embedding**")
105
- for d in range(dim):
106
- st.slider(
107
- label=f"Emb Dim {d+1}",
108
- min_value=-1.0, max_value=1.0,
109
- value=float(emb[d]),
110
- key=f"t{t_idx}_emb{d+1}",
111
- disabled=True
112
- )
113
-
114
- st.markdown("**2️⃣ Positional Encoding (sin / cos)**")
115
- for d in range(dim):
116
- st.slider(
117
- label=f"PE Dim {d+1}",
118
- min_value=-1.0, max_value=1.0,
119
- value=float(pe[d]),
120
- key=f"t{t_idx}_pe{d+1}",
121
- disabled=True
122
- )
123
-
124
- st.markdown("**3️⃣ Embedding + Positional Encoding**")
125
- for d in range(dim):
126
- st.slider(
127
- label=f"Sum Dim {d+1}",
128
- min_value=-2.0, max_value=2.0,
129
- value=float(combined[d]),
130
- key=f"t{t_idx}_sum{d+1}",
131
- disabled=True
132
- )
133
-
134
- # ---- NEW FINAL SECTION ----
135
- st.write("---")
136
- st.subheader("Final Input Embedding Plus Positional Encoding Ready to Send to ATtention Heads")
137
-
138
- for t_idx, tid in enumerate(token_ids, start=1):
139
- with st.expander(f"Token ID {tid}"):
140
- for d in range(1, dim+1):
141
- # pull the “sum” value out of session state
142
- val = st.session_state.get(f"t{t_idx}_sum{d}", None)
143
- st.write(f"Dim {d}: {val:.4f}" if val is not None else f"Dim {d}: N/A")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  if isinstance(getattr(sys.modules.get("torch"), "classes", None), types.ModuleType):
13
  torch.classes.__path__ = []
14
 
15
+ import torch
16
+ import numpy as np
17
+ import streamlit as st
18
  from transformers import GPT2TokenizerFast
19
 
20
+ # --- Setup ---
21
+ st.set_page_config(page_title="Text to Embedding Visualizer", layout="wide")
22
+ st.title("🔍 Token Embedding & Positional Encoding Coding Demo")
23
+
24
+ # --- Input UI ---
25
+ sentence = st.text_input("Enter your sentence", "Learning is fun")
26
+ embedding_dim = st.slider("Embedding Dimension (even only)", min_value=4, max_value=64, value=8, step=2)
27
+
28
+ # --- Load tokenizer ---
29
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
30
+ input_ids = tokenizer.encode(sentence, return_tensors="pt")[0]
31
+ tokens = tokenizer.convert_ids_to_tokens(input_ids)
32
+
33
+ # st.markdown("### 1️⃣ Tokenization")
34
+ # with st.expander("Show Token IDs"):
35
+ # st.write("**Tokens:**", tokens)
36
+ # st.write("**Token IDs:**", input_ids.tolist())
37
+
38
+ st.markdown("### 1️⃣ Tokenization")
39
+ with st.expander("Token IDs and Subwords"):
40
+ st.write("**Tokens:**", tokens)
41
+ st.write("**Token IDs:**", input_ids.tolist())
42
+
43
+ with st.expander("📜 Show Code: Tokenization"):
44
+ st.code("""
45
+ tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
46
+ input_ids = tokenizer.encode(sentence, return_tensors="pt")[0]
47
+ tokens = tokenizer.convert_ids_to_tokens(input_ids)
48
+ """, language="python")
49
+
50
+
51
+ # --- Embedding Matrix ---
52
+ torch.manual_seed(0) # Reproducibility
53
+ embedding_matrix = torch.nn.Embedding(tokenizer.vocab_size, embedding_dim)
54
+ embedded = embedding_matrix(input_ids)
55
+
56
+ st.markdown("### 2️⃣ Embedding")
57
+ with st.expander("Show Token Embeddings"):
58
+ st.write("Shape:", embedded.shape)
59
+ st.write(embedded)
60
+
61
+ with st.expander("📜 Show Code: Embedding"):
62
+ st.code(f"""
63
+ embedding_matrix = torch.nn.Embedding(tokenizer.vocab_size, {embedding_dim})
64
+ embedded = embedding_matrix(input_ids)
65
+ """, language="python")
66
+
67
+ # --- Positional Encoding ---
68
def get_positional_encoding(seq_len, dim):
    """Build the sinusoidal positional-encoding table from "Attention Is All You Need".

    Args:
        seq_len: number of positions (rows) to generate.
        dim: embedding dimension (columns). Both even and odd dims are supported.

    Returns:
        A ``(seq_len, dim)`` float tensor where even channels hold
        ``sin(pos / 10000^(2i/dim))`` and odd channels the matching ``cos``.
    """
    pe = torch.zeros(seq_len, dim)
    position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
    # One frequency per sin/cos channel pair: 10000^(-2i/dim), computed in log space.
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-np.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    # Slice to dim // 2 frequencies so odd `dim` works too: there are only
    # dim // 2 odd (cos) channels, while div_term holds ceil(dim / 2) entries.
    # The original assignment crashed with a shape mismatch for odd dims;
    # for even dims the slice is a no-op, so behavior there is unchanged.
    pe[:, 1::2] = torch.cos(position * div_term[: dim // 2])
    return pe
75
 
76
+ pos_enc = get_positional_encoding(len(input_ids), embedding_dim)
77
+
78
+ st.markdown("### 3️⃣ Positional Encoding")
79
+ with st.expander("Show Positional Encoding"):
80
+ st.write("Shape:", pos_enc.shape)
81
+ st.write(pos_enc)
82
+
83
+ with st.expander("📜 Show Code: Positional Encoding"):
84
+ st.code(f'''
85
+ def get_positional_encoding(seq_len, dim):
86
+ pe = torch.zeros(seq_len, dim)
87
+ position = torch.arange(0, seq_len).unsqueeze(1).float()
88
+ div_term = torch.exp(torch.arange(0, dim, 2).float() * (-np.log(10000.0) / dim))
89
+ pe[:, 0::2] = torch.sin(position * div_term)
90
+ pe[:, 1::2] = torch.cos(position * div_term)
91
+ return pe
92
+
93
+ pos_enc = get_positional_encoding(len(input_ids), {embedding_dim})
94
+ ''', language="python")
95
+
96
+ # --- Combined Embedding + Position ---
97
+ embedded_with_pos = embedded + pos_enc
98
+
99
+ st.markdown("### 4️⃣ Embedding + Positional Encoding")
100
+ with st.expander("Show Combined Embedding"):
101
+ st.write(embedded_with_pos)
102
+
103
+ with st.expander("📜 Show Code: Add Positional Encoding"):
104
+ st.code("""
105
+ embedded_with_pos = embedded + pos_enc
106
+ """, language="python")
107
+
108
+ # --- Approximate Reverse to Token IDs ---
109
def find_closest_token(vec, emb_matrix):
    """Map an embedding-space vector back to the nearest vocabulary id.

    Scores `vec` against every row of `emb_matrix.weight` using cosine
    similarity and returns the index of the best-matching row as a plain int.
    """
    vocab_table = emb_matrix.weight
    query = vec.unsqueeze(0)  # (1, dim) so it broadcasts against (vocab, dim)
    similarity = torch.nn.functional.cosine_similarity(query, vocab_table, dim=1)
    return int(similarity.argmax())
112
+
113
+ recovered_ids = [find_closest_token(vec, embedding_matrix) for vec in embedded]
114
+ #recovered_text = tokenizer.decode(recovered_ids)
115
+
116
+ #st.markdown("### 5️⃣ Approximate Reverse")
117
+ #with st.expander("Recovered Tokens"):
118
+ # st.write("**Recovered IDs:**", recovered_ids)
119
+ # st.write("**Recovered Text:**", recovered_text)
120
+
121
+ recovered_tokens = tokenizer.convert_ids_to_tokens(recovered_ids) # ← Subwords
122
+ recovered_text = tokenizer.decode(recovered_ids) # Final string
123
+
124
+ st.markdown("### 5️⃣ Approximate Reverse")
125
+ with st.expander("Recovered Tokens and Text"):
126
+ st.write("**Recovered Token IDs:**", recovered_ids)
127
+ st.write("**Recovered Subword Tokens (BPE):**", recovered_tokens)
128
+ st.write("**Recovered Sentence:**", recovered_text)
129
+
130
+ with st.expander("📜 Show Code: Recover Token IDs and Text"):
131
+ st.code("""
132
+ def find_closest_token(vec, emb_matrix):
133
+ sims = torch.nn.functional.cosine_similarity(vec.unsqueeze(0), emb_matrix.weight, dim=1)
134
+ return torch.argmax(sims).item()
135
+
136
+ recovered_ids = [find_closest_token(vec, embedding_matrix) for vec in embedded]
137
+ recovered_tokens = tokenizer.convert_ids_to_tokens(recovered_ids)
138
+ recovered_text = tokenizer.decode(recovered_ids)
139
+ """, language="python")
140
+
141
+ # --- Recover Position (Approx) ---
142
+ recovered_pos = embedded_with_pos - embedded
143
+ position_error = pos_enc - recovered_pos
144
+
145
+ st.markdown("### 6️⃣ Recovered Positional Encoding")
146
+ with st.expander("Compare Recovered vs Original"):
147
+ st.write("**Recovered Positional Encoding:**")
148
+ st.write(recovered_pos)
149
+ st.write("**Difference from Original (should be ~0):**")
150
+ st.write(position_error)
151
+
152
+ with st.expander("📜 Show Code: Recovered Positional Encoding"):
153
+ st.code("""
154
+ recovered_pos = embedded_with_pos - embedded
155
+ position_error = pos_enc - recovered_pos
156
+ """, language="python")
157
+
158
+ # Estimate position from positional encoding using cosine similarity
159
def estimate_position_from_encoding(pe_row, full_table):
    """Estimate which sequence position produced `pe_row`.

    Compares `pe_row` against every row of the reference positional-encoding
    `full_table` via cosine similarity; the index of the highest-scoring row
    is returned as the estimated position.
    """
    scores = torch.nn.functional.cosine_similarity(
        pe_row.unsqueeze(0), full_table, dim=1
    )
    return int(scores.argmax())
162
+
163
+ # Build reference table of known encodings for positions 0 to N
164
+ reference_pos_table = get_positional_encoding(seq_len=len(input_ids), dim=embedding_dim)
165
+
166
+ # Now estimate each token's position
167
+ estimated_positions = [estimate_position_from_encoding(row, reference_pos_table) for row in recovered_pos]
168
+
169
+ st.markdown("### 7️⃣ Estimate Position from Positional Encoding")
170
+ with st.expander("Recovered Positions"):
171
+ st.write("**Estimated Token Positions:**", estimated_positions)
172
+ st.write("**Original True Positions:**", list(range(len(input_ids))))
173
+
174
+ with st.expander("📜 Show Code: Estimate Positions"):
175
+ st.code("""
176
+ def estimate_position_from_encoding(pe_row, full_table):
177
+ sims = torch.nn.functional.cosine_similarity(pe_row.unsqueeze(0), full_table, dim=1)
178
+ return torch.argmax(sims).item()
179
+
180
+ reference_pos_table = get_positional_encoding(seq_len=len(input_ids), dim=embedding_dim)
181
+ estimated_positions = [estimate_position_from_encoding(row, reference_pos_table) for row in recovered_pos]
182
+ """, language="python")
183
+
184
+
185
+ st.markdown("### 📘 Final Notes: Theory & Formulas")
186
+
187
+ with st.expander("🧠 Theory and Formulas"):
188
+ st.markdown(r"""
189
+ ### 1️⃣ Tokenization (BPE)
190
+
191
+ We use **Byte Pair Encoding (BPE)** to break text into subword units.
192
+ For example:
193
+
194
+ "Learning is fun" → ["Learning", "Ġis", "Ġfun"]
195
+
196
+
197
+ Note: The "Ġ" indicates a **space** before the token.
198
+
199
+ ---
200
+
201
+ ### 2️⃣ Embedding
202
+
203
+ Each token ID $t_i \in \mathbb{Z}$ is mapped to a dense vector:
204
+
205
+ $$
206
+ \text{Embedding}(t_i) = \mathbf{e}_i \in \mathbb{R}^d
207
+ $$
208
+
209
+ Where:
210
+
211
+ - $t_i$: token ID
212
+ - $\mathbf{e}_i$: embedding vector of dimension $d$
213
+
214
+ ---
215
+
216
+ ### 3️⃣ Sinusoidal Positional Encoding
217
+
218
+ Used to encode the **position $p$** of a token without learnable parameters:
219
+
220
+ $$
221
+ \text{PE}(p, 2i) = \sin\left(\frac{p}{10000^{\frac{2i}{d}}}\right)
222
+ $$
223
+
224
+ $$
225
+ \text{PE}(p, 2i+1) = \cos\left(\frac{p}{10000^{\frac{2i}{d}}}\right)
226
+ $$
227
+
228
+ Where:
229
+
230
+ - $p$: position index (0, 1, 2, …)
231
+ - $i$: dimension index
232
+ - $d$: total embedding dimension
233
+
234
+ This gives a positional vector $\text{PE}(p) \in \mathbb{R}^d$
235
+
236
+ ---
237
+
238
+ ### 4️⃣ Add Embedding and Positional Encoding
239
+
240
+ We add the embedding and positional encoding element-wise:
241
+
242
+ $$
243
+ \mathbf{z}_i = \mathbf{e}_i + \text{PE}(p_i)
244
+ $$
245
+
246
+ Where:
247
+
248
+ - $\mathbf{z}_i$: final input to the transformer
249
+
250
+ ---
251
+
252
+ ### 5️⃣ Reverse Lookup (Approximate)
253
+
254
+ We find the nearest embedding using cosine similarity:
255
+
256
+ $$
257
+ \hat{t}_i = \underset{j}{\arg\max} \left( \frac{ \mathbf{z}_i \cdot \mathbf{e}_j }{ \| \mathbf{z}_i \| \, \| \mathbf{e}_j \| } \right)
258
+ $$
259
+
260
+ ---
261
+
262
+ ### 6️⃣ Recover Position from Embedding + PE
263
+
264
+ To isolate positional encoding:
265
+
266
+ $$
267
+ \text{Recovered PE}_i = \mathbf{z}_i - \mathbf{e}_i
268
+ $$
269
+
270
+ We then compare this with reference positional encodings to estimate token position.
271
+
272
+ ---
273
+
274
+ ### 🌟 Summary Table
275
+
276
+ | Step | What Happens |
277
+ |------|--------------|
278
+ | **Tokenization** | Sentence → Subwords → Token IDs |
279
+ | **Embedding** | Token IDs → Vectors |
280
+ | **Pos Encoding** | Position Index → Sin/Cos Vector |
281
+ | **Sum** | Embedding + PE = Input to Transformer |
282
+ | **Reverse** | Approximate token ID from vector |
283
+ | **PE Recovery** | Recover position using similarity |
284
+
285
+ """, unsafe_allow_html=True)