import streamlit as st
import numpy as np
from transformers import GPT2TokenizerFast, GPT2Model
# 1. Load tokenizer and model
@st.cache_resource
def load_resources():
    # tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer = GPT2TokenizerFast.from_pretrained("./assets/tokenizer", local_files_only=True)
    # model = GPT2Model.from_pretrained("gpt2")
    model = GPT2Model.from_pretrained("./assets/model", local_files_only=True)
    return tokenizer, model
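
# The bundled assets are assumed to mirror the public "gpt2" checkpoint
# (50257-token byte-level BPE vocab, 768-dim embeddings); the commented-out
# lines above load the same weights from the Hugging Face Hub instead.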
# Initialize resources
tokenizer, model = load_resources()
# 2. Helper to get the full embedding matrix
@st.cache_resource
def get_embedding_matrix():
    return model.get_input_embeddings().weight.detach().cpu().numpy()
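
# For the standard gpt2 checkpoint this matrix has shape (50257, 768):
# one 768-dimensional row per vocabulary entry.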
# 3. Initialize session state
for key in ["tokens", "token_ids", "embeddings", "current_id"]:
    if key not in st.session_state:
        if key in ["tokens", "token_ids"]:
            st.session_state[key] = []
        else:
            st.session_state[key] = {} if key == "embeddings" else None
st.title("🔍 Embedding & Positional Encoding Explorer")
# 4. Sentence input & BPE tokenize
sentence = st.text_input("Enter a sentence to tokenize:")
if st.button("BPE Tokenize"):
    ids = tokenizer.encode(sentence, add_special_tokens=False)
    toks = tokenizer.convert_ids_to_tokens(ids)
    st.session_state.tokens = toks
    st.session_state.token_ids = ids
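
# Example (assuming the standard gpt2 BPE vocab): "Hello world" tokenizes to
# ['Hello', 'Ġworld'] with IDs [15496, 995]; 'Ġ' marks a leading space.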
# 5. Display tokens + IDs with embedding buttons
if st.session_state.tokens:
st.subheader("Tokens and IDs")
cols = st.columns([4, 1])
for i, (tok, tid) in enumerate(zip(st.session_state.tokens, st.session_state.token_ids)):
cols[0].write(f"{i+1}. **{tok}** → ID {tid}")
if cols[1].button(f"Create Embedding for {tid}", key=f"embed_{tid}"):
vec = model.get_input_embeddings().weight[tid].detach().cpu().numpy()
st.session_state.embeddings[tid] = vec.copy()
st.session_state.current_id = tid
# 6. Show & edit embedding sliders for selected token
if st.session_state.current_id is not None:
    tok_id = st.session_state.current_id
    emb_vec = st.session_state.embeddings[tok_id]
    st.subheader(f"Embedding for token ID {tok_id}")
    for dim in range(len(emb_vec)):
        emb_vec[dim] = st.slider(
            f"Emb Dim {dim}", -5.0, 5.0, float(emb_vec[dim]), step=0.01,
            key=f"slider_{tok_id}_{dim}"
        )
    st.session_state.embeddings[tok_id] = emb_vec
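
    # One slider is rendered per dimension (768 for the standard gpt2 width);
    # edited values are written back into session state so they persist
    # across Streamlit reruns.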

    # 7. Positional Encoding inputs
    st.subheader("Positional Encoding")
    # Show the formula in LaTeX
    st.markdown(r"""
**Positional Encoding Formula**

For position $p$ and dimension $d$ (where $D$ is the embedding size):

$$
PE(p,d) = \begin{cases}
\sin\bigl(\frac{p}{10000^{d / D}}\bigr), & \text{if } d \text{ is even} \\
\cos\bigl(\frac{p}{10000^{(d-1) / D}}\bigr), & \text{if } d \text{ is odd}
\end{cases}
$$
""")
    pos = st.number_input("Position (p)", min_value=0, format="%d")
    dim = st.number_input(
        "Dimension index (0-based)", min_value=0, max_value=len(emb_vec) - 1, format="%d"
    )
    emb_dim = st.number_input(
        "Embedding Dimension (vector length)", value=len(emb_vec), format="%d"
    )

    # 8. Add Pos Encoding
    if st.button("Compute and Add Pos Encoding to the Embedding"):
        p, d, D = int(pos), int(dim), int(emb_dim)
        if 0 <= d < D:
            if d % 2 == 0:
                pe = np.sin(p / (10000 ** (d / D)))
            else:
                pe = np.cos(p / (10000 ** ((d - 1) / D)))
            emb_vec[d] += pe
            st.session_state.embeddings[tok_id] = emb_vec
        else:
            st.error("Dimension index out of range.")
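
    # The control above adds the encoding to one dimension at a time; a full
    # sinusoidal encoding touches every dimension at once. A minimal vectorized
    # sketch of the same formula (illustrative only, not wired into the UI):
    #
    #   def positional_encoding(p: int, D: int) -> np.ndarray:
    #       d = np.arange(D)
    #       angle = p / 10000 ** ((d - d % 2) / D)
    #       return np.where(d % 2 == 0, np.sin(angle), np.cos(angle))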

    # 9. Similarity search with positional encoding
    if st.button("Similarity Search (Using the Embedding)", key="sim_search_pos"):
        matrix = get_embedding_matrix()
        query = st.session_state.embeddings[tok_id]
        # Cosine similarity of the edited vector against every embedding row;
        # the epsilon guards against division by zero.
        dot = matrix.dot(query)
        mat_norm = np.linalg.norm(matrix, axis=1)
        q_norm = np.linalg.norm(query)
        sims = dot / (mat_norm * q_norm + 1e-12)
        # [1:21] skips rank 0, which is usually the query token itself.
        topk = (-sims).argsort()[1:21]
        st.write("**Top 20 similar tokens after PosEnc:**")
        for idx in topk:
            token_str = tokenizer.convert_ids_to_tokens([idx])[0]
            st.write(f"ID {idx} ({token_str}): {sims[idx]:.4f}")