schoginitoys committed on
Commit
500c16d
·
verified ·
1 Parent(s): f4958c1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +130 -1
src/streamlit_app.py CHANGED
@@ -7,6 +7,8 @@ import types
7
  import torch # now safe to import
8
  import streamlit as st
9
  import numpy as np
 
 
10
 
11
  # Prevent Streamlit from trying to walk torch.classes' non-standard __path__
12
  if isinstance(getattr(sys.modules.get("torch"), "classes", None), types.ModuleType):
@@ -28,7 +30,7 @@ embedding_dim = st.slider("Embedding Dimension (even only)", min_value=4, max_va
28
  # --- Load tokenizer ---
29
 
30
  # Set custom cache directory within your app's working directory (which is writable on Spaces)
31
- os.environ['TRANSFORMERS_CACHE'] = './hf_cache'
32
 
33
  # Load the tokenizer using the custom cache path
34
  # tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="./hf_cache")
@@ -294,3 +296,130 @@ We then compare this with reference positional encodings to estimate token posit
294
  | **PE Recovery** | Recover position using similarity |
295
 
296
  """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torch # now safe to import
8
  import streamlit as st
9
  import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
 
13
  # Prevent Streamlit from trying to walk torch.classes' non-standard __path__
14
  if isinstance(getattr(sys.modules.get("torch"), "classes", None), types.ModuleType):
 
30
  # --- Load tokenizer ---
31
 
32
  # Set custom cache directory within your app's working directory (which is writable on Spaces)
33
+ # os.environ['TRANSFORMERS_CACHE'] = './hf_cache'
34
 
35
  # Load the tokenizer using the custom cache path
36
  # tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="./hf_cache")
 
296
  | **PE Recovery** | Recover position using similarity |
297
 
298
  """, unsafe_allow_html=True)
299
+
300
+
301
# --- Educational section: one collapsible expander per transformer concept ---
st.markdown("### 🤖 Transformer Internals: Key Concepts")

# (title, markdown body, allow_html) triples; bodies are raw strings so the
# LaTeX backslashes pass through to the renderer untouched.
_CONCEPT_SECTIONS = [
    (
        "🔁 Multi-Head Attention: Q, K, V Projections",
        r"""
Each token embedding $\mathbf{x}_i$ is linearly projected into:
- Query vector: $Q_i = \mathbf{x}_i W^Q$
- Key vector: $K_i = \mathbf{x}_i W^K$
- Value vector: $V_i = \mathbf{x}_i W^V$

All of shape: $\mathbb{R}^{d_{model} \times d_{head}}$

Multiple such projections (heads) run in parallel:

$$
\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O
$$

Each head does:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_k}} \right) V
$$
""",
        True,
    ),
    (
        "🧠 Contextualized Representations",
        r"""
The attention mechanism lets each token **attend to others**, allowing the output for each token to contain **context**.

For example:
- Token "fun" gets influenced by "is" and "learning"
- The output is no longer static but dynamic, depending on sentence context

This is what makes Transformers powerful for understanding relationships between tokens.
""",
        False,
    ),
    (
        "🛠 Feed-Forward Neural Network (FFN)",
        r"""
After attention, each token’s vector goes through a two-layer feed-forward network applied independently:

$$
\text{FFN}(x) = \max(0, x W_1 + b_1) W_2 + b_2
$$

This allows deeper transformations on each token representation.
""",
        False,
    ),
    (
        "📊 Softmax Over Vocabulary",
        r"""
The final output layer transforms each token representation to **logits** for the full vocabulary.

Then, softmax is applied to convert them into probabilities:

$$
P(w_i \mid \text{context}) = \frac{\exp(\text{logit}_i)}{\sum_j \exp(\text{logit}_j)}
$$

The token with the highest probability is typically selected as the **predicted next word**.
""",
        False,
    ),
    (
        "🔮 Predicted Next Token",
        r"""
By chaining all steps (embedding → attention → FFN → softmax), the model predicts the **next token**:

E.g.,
Input: `"Learning is"`
Predicted next token: `"fun"`

This is how autoregressive models like GPT-2 **generate text** one token at a time.
""",
        False,
    ),
]

# Render the sections in order; unsafe_allow_html defaults to False, so passing
# it explicitly keeps each call equivalent to the hand-written originals.
for _title, _body, _allow_html in _CONCEPT_SECTIONS:
    with st.expander(_title):
        st.markdown(_body, unsafe_allow_html=_allow_html)
+
370
# --- Interactive visualizations of the attention / softmax math shown above ---
st.markdown("### 🎨 Visualizations: Transformer Internals")

# ---- 1. Attention Heatmap ----
with st.expander("🔁 Multi-Head Attention Score Heatmap (QKᵀ / √d)"):
    # Raw string (matches the r""" style used elsewhere in this file) so the
    # LaTeX backslashes need no doubling; the rendered text is unchanged.
    st.markdown(r"""
This heatmap shows how the attention mechanism scores each query against all keys.
Brighter color = higher attention weight.

$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V
$$
""", unsafe_allow_html=True)

    # Toy 2-D query/key vectors for a three-token sentence.
    tokens = ["Learning", "is", "fun"]
    Q = np.array([[1, 0], [0.5, 0.5], [0, 1]])
    K = np.array([[1, 0], [0.5, 0.5], [0, 1]])
    # Scaled dot-product scores: QKᵀ / √d_k with d_k = 2 here.
    scores = np.dot(Q, K.T) / np.sqrt(2)
    # Row-wise softmax: each query's weights over all keys sum to 1.
    softmax_scores = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True)

    fig1, ax1 = plt.subplots()
    cax = ax1.matshow(softmax_scores, cmap="Blues")
    fig1.colorbar(cax)
    ax1.set_xticks(np.arange(len(tokens)))
    ax1.set_xticklabels(tokens)
    ax1.set_yticks(np.arange(len(tokens)))
    ax1.set_yticklabels(tokens)
    ax1.set_xlabel("Key Tokens (K)")
    ax1.set_ylabel("Query Tokens (Q)")
    ax1.set_title("Attention Score Heatmap")
    st.pyplot(fig1)
    # Close the figure once rendered: Streamlit re-executes this script on
    # every interaction, and unclosed matplotlib figures accumulate in memory.
    plt.close(fig1)

# ---- 2. Softmax Curve ----
with st.expander("📊 Softmax Curve for Vocabulary Logits"):
    st.markdown(r"""
This curve shows how softmax converts logits into probabilities.
Higher logits result in higher predicted probabilities:

$$
\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$
""", unsafe_allow_html=True)

    # Three competing "tokens" whose logits differ by a constant offset.
    x = np.linspace(-4, 4, 100)
    logits = np.vstack([x, x + 1, x - 1])
    exps = np.exp(logits)
    # Column-wise softmax: at each x the three probabilities sum to 1.
    softmax = exps / np.sum(exps, axis=0)

    fig2, ax2 = plt.subplots()
    ax2.plot(x, softmax[0], label='Token A')
    ax2.plot(x, softmax[1], label='Token B')
    ax2.plot(x, softmax[2], label='Token C')
    ax2.set_title("Softmax Output vs Logit Value")
    ax2.set_xlabel("Logit")
    ax2.set_ylabel("Probability")
    ax2.legend()
    st.pyplot(fig2)
    plt.close(fig2)  # same rerun-leak guard as the heatmap figure above