schoginitoys committed on
Commit
3adf2d7
·
verified ·
1 Parent(s): 2971f05

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +35 -41
src/streamlit_app.py CHANGED
@@ -30,28 +30,27 @@ if input_text:
30
  # ---------- Tokenization Info ----------
31
  st.subheader("πŸ”€ Token Information")
32
  st.markdown("This shows how your input text is broken down into tokens. Each token is a subword unit that the model processes individually.")
33
-
34
- if st.button("πŸ” Show Token Details"):
35
- enc = tiktoken.get_encoding(tokenizer_name)
36
- tokens = enc.encode(input_text)
37
- token_strings = [enc.decode([t]) for t in tokens]
38
 
39
- with st.expander("🧾 Token IDs"):
40
- st.write(tokens)
 
 
 
 
41
 
42
- with st.expander("πŸ“– Decoded Tokens"):
43
- st.write(token_strings)
44
 
45
- st.info(f"Token count: {len(tokens)}")
46
 
47
- if st.button("πŸ“Š Show Token ID Chart"):
48
- fig, ax = plt.subplots()
49
- ax.bar(range(len(tokens)), tokens, tick_label=token_strings)
50
- ax.set_xlabel("Token")
51
- ax.set_ylabel("Token ID")
52
- ax.set_title("Token IDs for Input Text")
53
- plt.xticks(rotation=45, ha='right')
54
- st.pyplot(fig)
55
 
56
  # ---------- Embedding Section ----------
57
  st.subheader("πŸ”— Token Embeddings (OpenAI)")
@@ -64,10 +63,6 @@ if input_text:
64
  if st.button("πŸ“‘ Generate Embeddings"):
65
  with st.spinner("Generating embedding for each token..."):
66
  try:
67
- enc = tiktoken.get_encoding(tokenizer_name)
68
- tokens = enc.encode(input_text)
69
- token_strings = [enc.decode([t]) for t in tokens]
70
-
71
  all_embeddings = []
72
 
73
  for i, token_text in enumerate(token_strings):
@@ -78,7 +73,7 @@ if input_text:
78
  embedding = response.data[0].embedding
79
  all_embeddings.append(embedding)
80
 
81
- with st.expander(f"πŸ”Έ Token {i+1}: '{token_text}'"):
82
  st.write(embedding)
83
  st.caption(f"Embedding dimension: {len(embedding)}")
84
 
@@ -91,16 +86,16 @@ if input_text:
91
 
92
  st.success(f"Successfully generated embeddings for {len(token_strings)} tokens.")
93
 
94
- # Optional PCA Visualization
95
- if st.checkbox("🧭 Visualize all embeddings in 2D (PCA)"):
96
- pca = PCA(n_components=2)
97
- reduced = pca.fit_transform(np.array(all_embeddings))
98
- fig, ax = plt.subplots()
99
- ax.scatter(reduced[:, 0], reduced[:, 1])
100
- for i, label in enumerate(token_strings):
101
- ax.text(reduced[i, 0], reduced[i, 1], label, fontsize=9)
102
- ax.set_title("Token Embeddings (PCA 2D)")
103
- st.pyplot(fig)
104
 
105
  except Exception as e:
106
  st.error(f"OpenAI Error: {str(e)}")
@@ -114,8 +109,6 @@ if input_text:
114
  """)
115
 
116
  if st.button("πŸŒ€ Generate Positional Encoding"):
117
- enc = tiktoken.get_encoding(tokenizer_name)
118
- tokens = enc.encode(input_text)
119
  seq_len = len(tokens)
120
  dim = st.slider("Select positional encoding dimension:", 16, 512, 64, step=16)
121
 
@@ -131,12 +124,13 @@ if input_text:
131
 
132
  PE = get_positional_encoding(seq_len, dim)
133
 
134
- with st.expander("πŸ“ Positional Encoding Matrix"):
135
  st.write(PE)
136
  st.caption(f"Shape: {PE.shape}")
137
 
138
- if st.checkbox("πŸ”¬ Show Positional Encoding Heatmap"):
139
- fig, ax = plt.subplots(figsize=(10, seq_len // 2 + 1))
140
- sns.heatmap(PE, cmap="coolwarm", cbar=True, ax=ax)
141
- ax.set_title("Positional Encoding Heatmap")
142
- st.pyplot(fig)
 
 
30
  # ---------- Tokenization Info ----------
31
  st.subheader("πŸ”€ Token Information")
32
  st.markdown("This shows how your input text is broken down into tokens. Each token is a subword unit that the model processes individually.")
 
 
 
 
 
33
 
34
+ enc = tiktoken.get_encoding(tokenizer_name)
35
+ tokens = enc.encode(input_text)
36
+ token_strings = [enc.decode([t]) for t in tokens]
37
+
38
+ with st.expander("🧾 Token IDs", expanded=True):
39
+ st.write(tokens)
40
 
41
+ with st.expander("πŸ“– Decoded Tokens", expanded=True):
42
+ st.write(token_strings)
43
 
44
+ st.info(f"Token count: {len(tokens)}")
45
 
46
+ # βœ… Always show token ID chart
47
+ fig, ax = plt.subplots()
48
+ ax.bar(range(len(tokens)), tokens, tick_label=token_strings)
49
+ ax.set_xlabel("Token")
50
+ ax.set_ylabel("Token ID")
51
+ ax.set_title("Token IDs for Input Text")
52
+ plt.xticks(rotation=45, ha='right')
53
+ st.pyplot(fig)
54
 
55
  # ---------- Embedding Section ----------
56
  st.subheader("πŸ”— Token Embeddings (OpenAI)")
 
63
  if st.button("πŸ“‘ Generate Embeddings"):
64
  with st.spinner("Generating embedding for each token..."):
65
  try:
 
 
 
 
66
  all_embeddings = []
67
 
68
  for i, token_text in enumerate(token_strings):
 
73
  embedding = response.data[0].embedding
74
  all_embeddings.append(embedding)
75
 
76
+ with st.expander(f"πŸ”Έ Token {i+1}: '{token_text}'", expanded=True):
77
  st.write(embedding)
78
  st.caption(f"Embedding dimension: {len(embedding)}")
79
 
 
86
 
87
  st.success(f"Successfully generated embeddings for {len(token_strings)} tokens.")
88
 
89
+ # βœ… PCA Visualization ON by default
90
+ st.subheader("🧭 Token Embeddings in 2D (PCA)")
91
+ pca = PCA(n_components=2)
92
+ reduced = pca.fit_transform(np.array(all_embeddings))
93
+ fig, ax = plt.subplots()
94
+ ax.scatter(reduced[:, 0], reduced[:, 1])
95
+ for i, label in enumerate(token_strings):
96
+ ax.text(reduced[i, 0], reduced[i, 1], label, fontsize=9)
97
+ ax.set_title("Token Embeddings (PCA 2D)")
98
+ st.pyplot(fig)
99
 
100
  except Exception as e:
101
  st.error(f"OpenAI Error: {str(e)}")
 
109
  """)
110
 
111
  if st.button("πŸŒ€ Generate Positional Encoding"):
 
 
112
  seq_len = len(tokens)
113
  dim = st.slider("Select positional encoding dimension:", 16, 512, 64, step=16)
114
 
 
124
 
125
  PE = get_positional_encoding(seq_len, dim)
126
 
127
+ with st.expander("πŸ“ Positional Encoding Matrix", expanded=True):
128
  st.write(PE)
129
  st.caption(f"Shape: {PE.shape}")
130
 
131
+ # βœ… Default show heatmap ON
132
+ st.subheader("πŸ”¬ Positional Encoding Heatmap")
133
+ fig, ax = plt.subplots(figsize=(10, seq_len // 2 + 1))
134
+ sns.heatmap(PE, cmap="coolwarm", cbar=True, ax=ax)
135
+ ax.set_title("Positional Encoding Heatmap")
136
+ st.pyplot(fig)