Spaces:
Runtime error
Runtime error
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +267 -125
src/streamlit_app.py
CHANGED
|
@@ -12,132 +12,274 @@ import numpy as np
|
|
| 12 |
if isinstance(getattr(sys.modules.get("torch"), "classes", None), types.ModuleType):
|
| 13 |
torch.classes.__path__ = []
|
| 14 |
|
| 15 |
-
|
| 16 |
-
import
|
|
|
|
| 17 |
from transformers import GPT2TokenizerFast
|
| 18 |
|
| 19 |
-
|
| 20 |
-
st.
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
""
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
""
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
token_strs = [enc.decode([tid]) for tid in token_ids]
|
| 74 |
-
else:
|
| 75 |
-
hf_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
| 76 |
-
token_ids = hf_tokenizer.encode(input_text)
|
| 77 |
-
token_strs = hf_tokenizer.convert_ids_to_tokens(token_ids)
|
| 78 |
-
|
| 79 |
-
st.subheader("🪶 Tokens and IDs")
|
| 80 |
-
for i, (tok, tid) in enumerate(zip(token_strs, token_ids), start=1):
|
| 81 |
-
st.write(f"**{i}.** `{tok}` → ID **{tid}**")
|
| 82 |
-
|
| 83 |
-
st.write("---")
|
| 84 |
-
st.subheader("📊 Embedding + Positional Encoding per Token")
|
| 85 |
-
st.write(f"Input: `{input_text}` | Tokenizer: **{tokenizer_choice}** | Dims per token: **{dim}**")
|
| 86 |
-
if dim > 20:
|
| 87 |
-
st.warning("Showing >20 sliders per block may be unwieldy; consider smaller dims for teaching.")
|
| 88 |
-
|
| 89 |
-
# helper for sinusoidal positional encoding
|
| 90 |
-
def get_positional_encoding(position: int, d_model: int) -> np.ndarray:
|
| 91 |
-
pe = np.zeros(d_model, dtype=float)
|
| 92 |
-
for i in range(d_model):
|
| 93 |
-
angle = position / np.power(10000, (2 * (i // 2)) / d_model)
|
| 94 |
-
pe[i] = np.sin(angle) if (i % 2 == 0) else np.cos(angle)
|
| 95 |
return pe
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Workaround for a Streamlit/torch interaction: `torch.classes` masquerades as
# a module whose __path__ confuses Streamlit's module watcher, so empty it.
# NOTE(review): this guard dereferences the name `torch` two lines before the
# `import torch` below — presumably torch is already imported in the unseen
# lines above; confirm, otherwise this raises NameError when torch is loaded.
if isinstance(getattr(sys.modules.get("torch"), "classes", None), types.ModuleType):
    torch.classes.__path__ = []

import torch
import numpy as np
import streamlit as st
from transformers import GPT2TokenizerFast
# --- Setup ---
# Wide layout gives tensor dumps more horizontal room.
PAGE_TITLE = "Text to Embedding Visualizer"
st.set_page_config(page_title=PAGE_TITLE, layout="wide")
st.title("🔍 Token Embedding & Positional Encoding Coding Demo")

# --- Input UI ---
# Slider steps by 2 so the sinusoidal encoding always gets an even width.
sentence = st.text_input("Enter your sentence", "Learning is fun")
embedding_dim = st.slider(
    "Embedding Dimension (even only)",
    min_value=4,
    max_value=64,
    value=8,
    step=2,
)
# --- Load tokenizer ---
# GPT-2's BPE tokenizer. `encode` with return_tensors="pt" yields a 1×N
# tensor, so row 0 is the flat id sequence for this sentence.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
encoded = tokenizer.encode(sentence, return_tensors="pt")
input_ids = encoded[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# --- 1. Tokenization display ---
# (An earlier commented-out draft of this expander was removed as dead code.)
st.markdown("### 1️⃣ Tokenization")
with st.expander("Token IDs and Subwords"):
    st.write("**Tokens:**", tokens)
    st.write("**Token IDs:**", input_ids.tolist())

# Teaching mirror of the code that produced the tokens above.
with st.expander("📜 Show Code: Tokenization"):
    st.code("""
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
input_ids = tokenizer.encode(sentence, return_tensors="pt")[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
""", language="python")
# --- Embedding Matrix ---
# Fixed seed so the randomly initialized embedding table is identical on
# every Streamlit rerun.
torch.manual_seed(0)  # Reproducibility
embedding_matrix = torch.nn.Embedding(tokenizer.vocab_size, embedding_dim)
# The demo only displays values, so skip autograd graph construction —
# otherwise every tensor derived from `embedded` below drags a grad graph.
with torch.no_grad():
    embedded = embedding_matrix(input_ids)

st.markdown("### 2️⃣ Embedding")
with st.expander("Show Token Embeddings"):
    st.write("Shape:", embedded.shape)
    st.write(embedded)

# Teaching mirror; the f-string bakes the chosen dimension into the snippet.
with st.expander("📜 Show Code: Embedding"):
    st.code(f"""
embedding_matrix = torch.nn.Embedding(tokenizer.vocab_size, {embedding_dim})
embedded = embedding_matrix(input_ids)
""", language="python")
# --- Positional Encoding ---
def get_positional_encoding(seq_len, dim):
    """Return the sinusoidal positional-encoding table, shape (seq_len, dim).

    Even columns hold sin terms and odd columns cos terms, with wavelengths
    growing geometrically (the Transformer scheme of Vaswani et al., 2017).

    Args:
        seq_len: number of positions (rows).
        dim: embedding dimension (columns); odd values are now supported.
    """
    pe = torch.zeros(seq_len, dim)
    position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-np.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    # For odd `dim` there is one fewer cos column (dim // 2) than sin column
    # (ceil(dim / 2)); trim the cos matrix so the assignment works for any
    # dim. The UI slider only offers even dims, but the function no longer
    # relies on that.
    pe[:, 1::2] = torch.cos(position * div_term)[:, : dim // 2]
    return pe
# --- 3. Positional encoding for this sentence ---
seq_len = len(input_ids)
pos_enc = get_positional_encoding(seq_len, embedding_dim)

st.markdown("### 3️⃣ Positional Encoding")
with st.expander("Show Positional Encoding"):
    st.write("Shape:", pos_enc.shape)
    st.write(pos_enc)

# Teaching mirror; the f-string bakes the chosen dimension into the snippet.
with st.expander("📜 Show Code: Positional Encoding"):
    st.code(f'''
def get_positional_encoding(seq_len, dim):
    pe = torch.zeros(seq_len, dim)
    position = torch.arange(0, seq_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-np.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pos_enc = get_positional_encoding(len(input_ids), {embedding_dim})
''', language="python")
+
# --- Combined Embedding + Position ---
|
| 97 |
+
embedded_with_pos = embedded + pos_enc
|
| 98 |
+
|
| 99 |
+
st.markdown("### 4️⃣ Embedding + Positional Encoding")
|
| 100 |
+
with st.expander("Show Combined Embedding"):
|
| 101 |
+
st.write(embedded_with_pos)
|
| 102 |
+
|
| 103 |
+
with st.expander("📜 Show Code: Add Positional Encoding"):
|
| 104 |
+
st.code("""
|
| 105 |
+
embedded_with_pos = embedded + pos_enc
|
| 106 |
+
""", language="python")
|
| 107 |
+
|
| 108 |
+
# --- Approximate Reverse to Token IDs ---
|
| 109 |
+
def find_closest_token(vec, emb_matrix):
|
| 110 |
+
sims = torch.nn.functional.cosine_similarity(vec.unsqueeze(0), emb_matrix.weight, dim=1)
|
| 111 |
+
return torch.argmax(sims).item()
|
| 112 |
+
|
# --- 5. Approximate reverse lookup ---
# (An earlier commented-out draft of this section was removed as dead code.)
recovered_ids = [find_closest_token(vec, embedding_matrix) for vec in embedded]
recovered_tokens = tokenizer.convert_ids_to_tokens(recovered_ids)  # ← Subwords
recovered_text = tokenizer.decode(recovered_ids)  # ← Final string

st.markdown("### 5️⃣ Approximate Reverse")
with st.expander("Recovered Tokens and Text"):
    st.write("**Recovered Token IDs:**", recovered_ids)
    st.write("**Recovered Subword Tokens (BPE):**", recovered_tokens)
    st.write("**Recovered Sentence:**", recovered_text)

with st.expander("📜 Show Code: Recover Token IDs and Text"):
    st.code("""
def find_closest_token(vec, emb_matrix):
    sims = torch.nn.functional.cosine_similarity(vec.unsqueeze(0), emb_matrix.weight, dim=1)
    return torch.argmax(sims).item()

recovered_ids = [find_closest_token(vec, embedding_matrix) for vec in embedded]
recovered_tokens = tokenizer.convert_ids_to_tokens(recovered_ids)
recovered_text = tokenizer.decode(recovered_ids)
""", language="python")
+
# --- Recover Position (Approx) ---
|
| 142 |
+
recovered_pos = embedded_with_pos - embedded
|
| 143 |
+
position_error = pos_enc - recovered_pos
|
| 144 |
+
|
| 145 |
+
st.markdown("### 6️⃣ Recovered Positional Encoding")
|
| 146 |
+
with st.expander("Compare Recovered vs Original"):
|
| 147 |
+
st.write("**Recovered Positional Encoding:**")
|
| 148 |
+
st.write(recovered_pos)
|
| 149 |
+
st.write("**Difference from Original (should be ~0):**")
|
| 150 |
+
st.write(position_error)
|
| 151 |
+
|
| 152 |
+
with st.expander("📜 Show Code: Recovered Positional Encoding"):
|
| 153 |
+
st.code("""
|
| 154 |
+
recovered_pos = embedded_with_pos - embedded
|
| 155 |
+
position_error = pos_enc - recovered_pos
|
| 156 |
+
""", language="python")
|
| 157 |
+
|
# Estimate position from positional encoding using cosine similarity
def estimate_position_from_encoding(pe_row, full_table):
    """Return the row index of full_table most cosine-similar to pe_row."""
    query = pe_row.unsqueeze(0)
    similarities = torch.nn.functional.cosine_similarity(query, full_table, dim=1)
    return similarities.argmax().item()
# Reference table: the known encodings for every position in this sentence.
reference_pos_table = get_positional_encoding(seq_len=len(input_ids), dim=embedding_dim)

# Match each recovered row against the table to estimate its position.
estimated_positions = [
    estimate_position_from_encoding(row, reference_pos_table)
    for row in recovered_pos
]

st.markdown("### 7️⃣ Estimate Position from Positional Encoding")
with st.expander("Recovered Positions"):
    st.write("**Estimated Token Positions:**", estimated_positions)
    st.write("**Original True Positions:**", list(range(len(input_ids))))

with st.expander("📜 Show Code: Estimate Positions"):
    st.code("""
def estimate_position_from_encoding(pe_row, full_table):
    sims = torch.nn.functional.cosine_similarity(pe_row.unsqueeze(0), full_table, dim=1)
    return torch.argmax(sims).item()

reference_pos_table = get_positional_encoding(seq_len=len(input_ids), dim=embedding_dim)
estimated_positions = [estimate_position_from_encoding(row, reference_pos_table) for row in recovered_pos]
""", language="python")
+
|
| 185 |
+
st.markdown("### 📘 Final Notes: Theory & Formulas")
|
| 186 |
+
|
| 187 |
+
with st.expander("🧠 Theory and Formulas"):
|
| 188 |
+
st.markdown(r"""
|
| 189 |
+
### 1️⃣ Tokenization (BPE)
|
| 190 |
+
|
| 191 |
+
We use **Byte Pair Encoding (BPE)** to break text into subword units.
|
| 192 |
+
For example:
|
| 193 |
+
|
| 194 |
+
"Learning is fun" → ["Learning", "Ġis", "Ġfun"]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
Note: The "Ġ" indicates a **space** before the token.
|
| 198 |
+
|
| 199 |
+
---
|
| 200 |
+
|
| 201 |
+
### 2️⃣ Embedding
|
| 202 |
+
|
| 203 |
+
Each token ID $t_i \in \mathbb{Z}$ is mapped to a dense vector:
|
| 204 |
+
|
| 205 |
+
$$
|
| 206 |
+
\text{Embedding}(t_i) = \mathbf{e}_i \in \mathbb{R}^d
|
| 207 |
+
$$
|
| 208 |
+
|
| 209 |
+
Where:
|
| 210 |
+
|
| 211 |
+
- $t_i$: token ID
|
| 212 |
+
- $\mathbf{e}_i$: embedding vector of dimension $d$
|
| 213 |
+
|
| 214 |
+
---
|
| 215 |
+
|
| 216 |
+
### 3️⃣ Sinusoidal Positional Encoding
|
| 217 |
+
|
| 218 |
+
Used to encode the **position $p$** of a token without learnable parameters:
|
| 219 |
+
|
| 220 |
+
$$
|
| 221 |
+
\text{PE}(p, 2i) = \sin\left(\frac{p}{10000^{\frac{2i}{d}}}\right)
|
| 222 |
+
$$
|
| 223 |
+
|
| 224 |
+
$$
|
| 225 |
+
\text{PE}(p, 2i+1) = \cos\left(\frac{p}{10000^{\frac{2i}{d}}}\right)
|
| 226 |
+
$$
|
| 227 |
+
|
| 228 |
+
Where:
|
| 229 |
+
|
| 230 |
+
- $p$: position index (0, 1, 2, …)
|
| 231 |
+
- $i$: dimension index
|
| 232 |
+
- $d$: total embedding dimension
|
| 233 |
+
|
| 234 |
+
This gives a positional vector $\text{PE}(p) \in \mathbb{R}^d$
|
| 235 |
+
|
| 236 |
+
---
|
| 237 |
+
|
| 238 |
+
### 4️⃣ Add Embedding and Positional Encoding
|
| 239 |
+
|
| 240 |
+
We add the embedding and positional encoding element-wise:
|
| 241 |
+
|
| 242 |
+
$$
|
| 243 |
+
\mathbf{z}_i = \mathbf{e}_i + \text{PE}(p_i)
|
| 244 |
+
$$
|
| 245 |
+
|
| 246 |
+
Where:
|
| 247 |
+
|
| 248 |
+
- $\mathbf{z}_i$: final input to the transformer
|
| 249 |
+
|
| 250 |
+
---
|
| 251 |
+
|
| 252 |
+
### 5️⃣ Reverse Lookup (Approximate)
|
| 253 |
+
|
| 254 |
+
We find the nearest embedding using cosine similarity:
|
| 255 |
+
|
| 256 |
+
$$
|
| 257 |
+
\hat{t}_i = \underset{j}{\arg\max} \left( \frac{ \mathbf{z}_i \cdot \mathbf{e}_j }{ \| \mathbf{z}_i \| \, \| \mathbf{e}_j \| } \right)
|
| 258 |
+
$$
|
| 259 |
+
|
| 260 |
+
---
|
| 261 |
+
|
| 262 |
+
### 6️⃣ Recover Position from Embedding + PE
|
| 263 |
+
|
| 264 |
+
To isolate positional encoding:
|
| 265 |
+
|
| 266 |
+
$$
|
| 267 |
+
\text{Recovered PE}_i = \mathbf{z}_i - \mathbf{e}_i
|
| 268 |
+
$$
|
| 269 |
+
|
| 270 |
+
We then compare this with reference positional encodings to estimate token position.
|
| 271 |
+
|
| 272 |
+
---
|
| 273 |
+
|
| 274 |
+
### 🌟 Summary Table
|
| 275 |
+
|
| 276 |
+
| Step | What Happens |
|
| 277 |
+
|------|--------------|
|
| 278 |
+
| **Tokenization** | Sentence → Subwords → Token IDs |
|
| 279 |
+
| **Embedding** | Token IDs → Vectors |
|
| 280 |
+
| **Pos Encoding** | Position Index → Sin/Cos Vector |
|
| 281 |
+
| **Sum** | Embedding + PE = Input to Transformer |
|
| 282 |
+
| **Reverse** | Approximate token ID from vector |
|
| 283 |
+
| **PE Recovery** | Recover position using similarity |
|
| 284 |
+
|
| 285 |
+
""", unsafe_allow_html=True)
|