vedant2905 commited on
Commit
bedc82f
·
verified ·
1 Parent(s): 0aea2d8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -32
src/streamlit_app.py CHANGED
@@ -144,7 +144,7 @@ def load_explanations(task, layer):
144
  return None
145
 
146
  def main():
147
- st.title("Token Analysis (Lang Class-layer 11 1 file left")
148
 
149
  # Task and Layer Selection
150
  col1, col2 = st.columns(2)
@@ -213,57 +213,51 @@ def main():
213
 
214
  st.metric("Predicted Cluster", selected_row['Top 1'])
215
 
216
- # Load dev sentences once
217
  dev_sentences = load_dev_sentences(selected_task, selected_layer)
218
  predicted_cluster = str(selected_row['Top 1'])
219
 
220
- if is_cls_token(token):
221
- # For CLS tokens, show the original context
222
- if dev_sentences and selected_row['line_idx'] < len(dev_sentences):
223
- st.subheader("Original Context")
224
- st.code(dev_sentences[selected_row['line_idx']].strip())
225
- else:
226
- st.info("No original context found for this token.")
227
  else:
228
- # For non-CLS tokens, show the wordcloud
229
- if clusters and predicted_cluster in clusters:
230
- # Create frequency dict for wordcloud
231
- token_frequencies = {}
232
- for token_info in clusters[predicted_cluster]:
233
- token = token_info['token']
234
- if not is_cls_token(token): # Skip CLS tokens in wordcloud
235
- token_frequencies[token] = token_frequencies.get(token, 0) + token_info['occurrence']
236
-
237
- if token_frequencies:
238
- st.subheader("Cluster Word Cloud")
239
- wordcloud = create_wordcloud(token_frequencies)
240
- if wordcloud:
241
- plt.figure(figsize=(10, 5))
242
- plt.imshow(wordcloud, interpolation='bilinear')
243
- plt.axis('off')
244
- st.pyplot(plt)
 
245
 
246
  # Show cluster contexts in expander
247
  with st.expander(f"👀 View Contexts (Cluster {predicted_cluster})"):
248
  if clusters and predicted_cluster in clusters:
249
- current_line = selected_row['line_idx'] # Get current line number
250
- shown_contexts = set() # Keep track of shown contexts to avoid duplicates
251
 
252
  for token_info in clusters[predicted_cluster]:
253
  line_num = token_info['line_num']
254
- # Skip if it's the same line as current token or if we've shown this context before
255
  if (line_num >= 0 and line_num < len(dev_sentences) and
256
  line_num != current_line):
257
  context = dev_sentences[line_num].strip()
258
- # Only show if we haven't shown this exact context before
259
  if context not in shown_contexts:
260
  st.code(context)
261
  shown_contexts.add(context)
262
 
263
  if not shown_contexts:
264
  st.info("No other similar contexts found in this cluster.")
265
- else:
266
- st.info("No similar contexts found in this cluster.")
267
 
268
  if __name__ == "__main__":
269
  main()
 
144
  return None
145
 
146
  def main():
147
+ st.title("Token Analysis")
148
 
149
  # Task and Layer Selection
150
  col1, col2 = st.columns(2)
 
213
 
214
  st.metric("Predicted Cluster", selected_row['Top 1'])
215
 
216
+ # Load dev sentences once
217
  dev_sentences = load_dev_sentences(selected_task, selected_layer)
218
  predicted_cluster = str(selected_row['Top 1'])
219
 
220
+ # Show original context for all tokens
221
+ if dev_sentences and selected_row['line_idx'] < len(dev_sentences):
222
+ st.subheader("Original Context")
223
+ st.code(dev_sentences[selected_row['line_idx']].strip())
 
 
 
224
  else:
225
+ st.info("No original context found for this token.")
226
+
227
+ # Show wordcloud for all tokens
228
+ if clusters and predicted_cluster in clusters:
229
+ # Create frequency dict for wordcloud
230
+ token_frequencies = {}
231
+ for token_info in clusters[predicted_cluster]:
232
+ token = token_info['token']
233
+ token_frequencies[token] = token_frequencies.get(token, 0) + token_info['occurrence']
234
+
235
+ if token_frequencies:
236
+ st.subheader("Cluster Word Cloud")
237
+ wordcloud = create_wordcloud(token_frequencies)
238
+ if wordcloud:
239
+ plt.figure(figsize=(10, 5))
240
+ plt.imshow(wordcloud, interpolation='bilinear')
241
+ plt.axis('off')
242
+ st.pyplot(plt)
243
 
244
  # Show cluster contexts in expander
245
  with st.expander(f"👀 View Contexts (Cluster {predicted_cluster})"):
246
  if clusters and predicted_cluster in clusters:
247
+ current_line = selected_row['line_idx']
248
+ shown_contexts = set()
249
 
250
  for token_info in clusters[predicted_cluster]:
251
  line_num = token_info['line_num']
 
252
  if (line_num >= 0 and line_num < len(dev_sentences) and
253
  line_num != current_line):
254
  context = dev_sentences[line_num].strip()
 
255
  if context not in shown_contexts:
256
  st.code(context)
257
  shown_contexts.add(context)
258
 
259
  if not shown_contexts:
260
  st.info("No other similar contexts found in this cluster.")
 
 
261
 
262
  if __name__ == "__main__":
263
  main()