Spaces:
Build error
Build error
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +26 -32
src/streamlit_app.py
CHANGED
|
@@ -144,7 +144,7 @@ def load_explanations(task, layer):
|
|
| 144 |
return None
|
| 145 |
|
| 146 |
def main():
|
| 147 |
-
st.title("Token Analysis
|
| 148 |
|
| 149 |
# Task and Layer Selection
|
| 150 |
col1, col2 = st.columns(2)
|
|
@@ -213,57 +213,51 @@ def main():
|
|
| 213 |
|
| 214 |
st.metric("Predicted Cluster", selected_row['Top 1'])
|
| 215 |
|
| 216 |
-
|
| 217 |
dev_sentences = load_dev_sentences(selected_task, selected_layer)
|
| 218 |
predicted_cluster = str(selected_row['Top 1'])
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
st.code(dev_sentences[selected_row['line_idx']].strip())
|
| 225 |
-
else:
|
| 226 |
-
st.info("No original context found for this token.")
|
| 227 |
else:
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
|
|
|
| 245 |
|
| 246 |
# Show cluster contexts in expander
|
| 247 |
with st.expander(f"👀 View Contexts (Cluster {predicted_cluster})"):
|
| 248 |
if clusters and predicted_cluster in clusters:
|
| 249 |
-
current_line = selected_row['line_idx']
|
| 250 |
-
shown_contexts = set()
|
| 251 |
|
| 252 |
for token_info in clusters[predicted_cluster]:
|
| 253 |
line_num = token_info['line_num']
|
| 254 |
-
# Skip if it's the same line as current token or if we've shown this context before
|
| 255 |
if (line_num >= 0 and line_num < len(dev_sentences) and
|
| 256 |
line_num != current_line):
|
| 257 |
context = dev_sentences[line_num].strip()
|
| 258 |
-
# Only show if we haven't shown this exact context before
|
| 259 |
if context not in shown_contexts:
|
| 260 |
st.code(context)
|
| 261 |
shown_contexts.add(context)
|
| 262 |
|
| 263 |
if not shown_contexts:
|
| 264 |
st.info("No other similar contexts found in this cluster.")
|
| 265 |
-
else:
|
| 266 |
-
st.info("No similar contexts found in this cluster.")
|
| 267 |
|
| 268 |
if __name__ == "__main__":
|
| 269 |
main()
|
|
|
|
| 144 |
return None
|
| 145 |
|
| 146 |
def main():
|
| 147 |
+
st.title("Token Analysis")
|
| 148 |
|
| 149 |
# Task and Layer Selection
|
| 150 |
col1, col2 = st.columns(2)
|
|
|
|
| 213 |
|
| 214 |
st.metric("Predicted Cluster", selected_row['Top 1'])
|
| 215 |
|
| 216 |
+
# Load dev sentences once
|
| 217 |
dev_sentences = load_dev_sentences(selected_task, selected_layer)
|
| 218 |
predicted_cluster = str(selected_row['Top 1'])
|
| 219 |
|
| 220 |
+
# Show original context for all tokens
|
| 221 |
+
if dev_sentences and selected_row['line_idx'] < len(dev_sentences):
|
| 222 |
+
st.subheader("Original Context")
|
| 223 |
+
st.code(dev_sentences[selected_row['line_idx']].strip())
|
|
|
|
|
|
|
|
|
|
| 224 |
else:
|
| 225 |
+
st.info("No original context found for this token.")
|
| 226 |
+
|
| 227 |
+
# Show wordcloud for all tokens
|
| 228 |
+
if clusters and predicted_cluster in clusters:
|
| 229 |
+
# Create frequency dict for wordcloud
|
| 230 |
+
token_frequencies = {}
|
| 231 |
+
for token_info in clusters[predicted_cluster]:
|
| 232 |
+
token = token_info['token']
|
| 233 |
+
token_frequencies[token] = token_frequencies.get(token, 0) + token_info['occurrence']
|
| 234 |
+
|
| 235 |
+
if token_frequencies:
|
| 236 |
+
st.subheader("Cluster Word Cloud")
|
| 237 |
+
wordcloud = create_wordcloud(token_frequencies)
|
| 238 |
+
if wordcloud:
|
| 239 |
+
plt.figure(figsize=(10, 5))
|
| 240 |
+
plt.imshow(wordcloud, interpolation='bilinear')
|
| 241 |
+
plt.axis('off')
|
| 242 |
+
st.pyplot(plt)
|
| 243 |
|
| 244 |
# Show cluster contexts in expander
|
| 245 |
with st.expander(f"👀 View Contexts (Cluster {predicted_cluster})"):
|
| 246 |
if clusters and predicted_cluster in clusters:
|
| 247 |
+
current_line = selected_row['line_idx']
|
| 248 |
+
shown_contexts = set()
|
| 249 |
|
| 250 |
for token_info in clusters[predicted_cluster]:
|
| 251 |
line_num = token_info['line_num']
|
|
|
|
| 252 |
if (line_num >= 0 and line_num < len(dev_sentences) and
|
| 253 |
line_num != current_line):
|
| 254 |
context = dev_sentences[line_num].strip()
|
|
|
|
| 255 |
if context not in shown_contexts:
|
| 256 |
st.code(context)
|
| 257 |
shown_contexts.add(context)
|
| 258 |
|
| 259 |
if not shown_contexts:
|
| 260 |
st.info("No other similar contexts found in this cluster.")
|
|
|
|
|
|
|
| 261 |
|
| 262 |
if __name__ == "__main__":
|
| 263 |
main()
|