Spaces:
Build error
Build error
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +46 -140
src/streamlit_app.py
CHANGED
|
@@ -12,25 +12,11 @@ st.set_page_config(
|
|
| 12 |
layout="wide"
|
| 13 |
)
|
| 14 |
|
| 15 |
-
# Page title with custom styling
|
| 16 |
-
st.markdown('<h1 class="main-header">🔍 Code Token Cluster Visualization</h1>', unsafe_allow_html=True)
|
| 17 |
-
st.markdown("""
|
| 18 |
-
<p>Explore token clusters from language model representations.
|
| 19 |
-
Select a token to view its cluster information and contexts.</p>
|
| 20 |
-
""", unsafe_allow_html=True)
|
| 21 |
-
|
| 22 |
-
# Create sidebar for input controls
|
| 23 |
-
with st.sidebar:
|
| 24 |
-
st.markdown("## 🛠️ Controls")
|
| 25 |
-
st.markdown("---")
|
| 26 |
-
|
| 27 |
# Functions to load data
|
| 28 |
@st.cache_data
|
| 29 |
def load_predictions(file_path):
|
| 30 |
"""Load the predictions CSV file."""
|
| 31 |
-
# Prepend src to the file path
|
| 32 |
full_path = os.path.join("src", file_path)
|
| 33 |
-
# Read the file with all columns as string type initially
|
| 34 |
df = pd.read_csv(full_path, sep="\t", dtype=str)
|
| 35 |
|
| 36 |
# Convert numeric columns safely
|
|
@@ -71,28 +57,26 @@ def create_wordcloud(tokens):
|
|
| 71 |
if not tokens:
|
| 72 |
return None
|
| 73 |
|
| 74 |
-
# Create a dictionary with equal weights for all tokens
|
| 75 |
token_weights = {token: 1 for token in set(tokens)}
|
| 76 |
|
| 77 |
wordcloud = WordCloud(
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
).generate_from_frequencies(token_weights)
|
| 93 |
|
| 94 |
-
|
| 95 |
-
fig = plt.figure(figsize=(16, 8)) # Increased figure size
|
| 96 |
plt.imshow(wordcloud, interpolation='bilinear')
|
| 97 |
plt.axis('off')
|
| 98 |
plt.tight_layout(pad=0)
|
|
@@ -102,7 +86,6 @@ def create_wordcloud(tokens):
|
|
| 102 |
def load_dev_sentences():
|
| 103 |
"""Load sentences from dev.in file."""
|
| 104 |
try:
|
| 105 |
-
# Try different possible locations of dev.in
|
| 106 |
possible_paths = [
|
| 107 |
os.path.join("src", "codebert", "compile_error", "dev.in"),
|
| 108 |
os.path.join("src", "codebert", "language_classification", "layer11", "dev.in"),
|
|
@@ -118,73 +101,20 @@ def load_dev_sentences():
|
|
| 118 |
st.error(f"Error loading dev.in file: {str(e)}")
|
| 119 |
return []
|
| 120 |
|
| 121 |
-
def get_available_models():
|
| 122 |
-
# Check in the src directory for the codebert folder
|
| 123 |
-
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 124 |
-
model_path = os.path.join("src", "codebert")
|
| 125 |
-
if os.path.exists(model_path):
|
| 126 |
-
return ["codebert"]
|
| 127 |
-
return []
|
| 128 |
-
|
| 129 |
-
def get_available_tasks(model):
|
| 130 |
-
if not model:
|
| 131 |
-
return []
|
| 132 |
-
|
| 133 |
-
model_dir = os.path.join("src", model)
|
| 134 |
-
if os.path.exists(model_dir):
|
| 135 |
-
return [d for d in os.listdir(model_dir)
|
| 136 |
-
if os.path.isdir(os.path.join(model_dir, d))]
|
| 137 |
-
return []
|
| 138 |
-
|
| 139 |
-
def get_available_layers(model, task):
|
| 140 |
-
if not model or not task:
|
| 141 |
-
return []
|
| 142 |
-
|
| 143 |
-
task_dir = os.path.join("src", model, task)
|
| 144 |
-
if os.path.exists(task_dir):
|
| 145 |
-
layers = [d for d in os.listdir(task_dir)
|
| 146 |
-
if os.path.isdir(os.path.join(task_dir, d)) and d.startswith('layer')]
|
| 147 |
-
return sorted(layers, key=lambda x: int(x.replace('layer', '')))
|
| 148 |
-
return []
|
| 149 |
-
|
| 150 |
def main():
|
| 151 |
st.title("Code Token Cluster Visualization")
|
| 152 |
|
| 153 |
-
# Model selection
|
| 154 |
-
available_models =
|
| 155 |
-
|
| 156 |
-
st.error("No models found in the workspace")
|
| 157 |
-
return
|
| 158 |
-
|
| 159 |
-
model = st.selectbox(
|
| 160 |
-
"Select Model",
|
| 161 |
-
options=available_models,
|
| 162 |
-
index=0
|
| 163 |
-
)
|
| 164 |
|
| 165 |
-
# Task selection
|
| 166 |
-
tasks =
|
| 167 |
-
|
| 168 |
-
st.error(f"No tasks found for model {model}")
|
| 169 |
-
return
|
| 170 |
-
|
| 171 |
-
task = st.selectbox(
|
| 172 |
-
"Select Task",
|
| 173 |
-
options=tasks,
|
| 174 |
-
index=tasks.index("language_classification") if "language_classification" in tasks else 0
|
| 175 |
-
)
|
| 176 |
|
| 177 |
-
# Layer selection
|
| 178 |
-
layers =
|
| 179 |
-
|
| 180 |
-
st.error(f"No layers found for {model}/{task}")
|
| 181 |
-
return
|
| 182 |
-
|
| 183 |
-
layer = st.selectbox(
|
| 184 |
-
"Select Layer",
|
| 185 |
-
options=layers,
|
| 186 |
-
index=layers.index("layer6") if "layer6" in layers else 0
|
| 187 |
-
)
|
| 188 |
|
| 189 |
# Fix the file paths
|
| 190 |
layer_dir = os.path.join(model, task, layer)
|
|
@@ -198,44 +128,29 @@ def main():
|
|
| 198 |
clusters = load_clusters(clusters_file)
|
| 199 |
sentences = load_input_data(input_file)
|
| 200 |
|
| 201 |
-
#
|
| 202 |
-
|
| 203 |
|
| 204 |
-
#
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
with col1:
|
| 208 |
-
st.metric("Total Tokens", f"{len(predictions_df):,}")
|
| 209 |
-
with col2:
|
| 210 |
-
st.metric("Total Clusters", f"{len(clusters):,}")
|
| 211 |
-
with col3:
|
| 212 |
-
avg_tokens = sum(len(tokens) for tokens in clusters.values()) / max(len(clusters), 1)
|
| 213 |
-
st.metric("Avg. Tokens per Cluster", f"{avg_tokens:.1f}")
|
| 214 |
|
| 215 |
-
#
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
else:
|
| 232 |
-
token_options = all_tokens
|
| 233 |
-
|
| 234 |
-
selected_token = st.selectbox(
|
| 235 |
-
"Select a token:",
|
| 236 |
-
token_options,
|
| 237 |
-
index=0 if token_options else None
|
| 238 |
-
)
|
| 239 |
|
| 240 |
# Main content
|
| 241 |
if selected_token:
|
|
@@ -243,7 +158,6 @@ def main():
|
|
| 243 |
token_instances = predictions_df[predictions_df['Token'] == selected_token]
|
| 244 |
|
| 245 |
if not token_instances.empty:
|
| 246 |
-
# Simple header and token display
|
| 247 |
st.title(f"Token: {selected_token}")
|
| 248 |
|
| 249 |
# Get most frequent cluster (Top 1) for this token
|
|
@@ -298,14 +212,6 @@ def main():
|
|
| 298 |
st.info("No contexts found in this cluster")
|
| 299 |
else:
|
| 300 |
st.warning(f"No instances found for token: {selected_token}")
|
| 301 |
-
else:
|
| 302 |
-
# Show welcome message when no token is selected
|
| 303 |
-
st.markdown("""
|
| 304 |
-
<div style="text-align: center; margin-top: 50px; color: #757575;">
|
| 305 |
-
<h3>👈 Select a token from the sidebar to begin</h3>
|
| 306 |
-
<p>The visualization will show cluster information and code contexts.</p>
|
| 307 |
-
</div>
|
| 308 |
-
""", unsafe_allow_html=True)
|
| 309 |
|
| 310 |
except Exception as e:
|
| 311 |
st.error("An error occurred while processing the data")
|
|
|
|
| 12 |
layout="wide"
|
| 13 |
)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# Functions to load data
|
| 16 |
@st.cache_data
|
| 17 |
def load_predictions(file_path):
|
| 18 |
"""Load the predictions CSV file."""
|
|
|
|
| 19 |
full_path = os.path.join("src", file_path)
|
|
|
|
| 20 |
df = pd.read_csv(full_path, sep="\t", dtype=str)
|
| 21 |
|
| 22 |
# Convert numeric columns safely
|
|
|
|
| 57 |
if not tokens:
|
| 58 |
return None
|
| 59 |
|
|
|
|
| 60 |
token_weights = {token: 1 for token in set(tokens)}
|
| 61 |
|
| 62 |
wordcloud = WordCloud(
|
| 63 |
+
width=1000,
|
| 64 |
+
height=500,
|
| 65 |
+
background_color='#FFF0DB',
|
| 66 |
+
prefer_horizontal=1,
|
| 67 |
+
min_font_size=10,
|
| 68 |
+
max_font_size=150,
|
| 69 |
+
relative_scaling=0.5,
|
| 70 |
+
collocations=False,
|
| 71 |
+
margin=1,
|
| 72 |
+
random_state=42,
|
| 73 |
+
scale=2,
|
| 74 |
+
repeat=False,
|
| 75 |
+
max_words=2000,
|
| 76 |
+
regexp=r"\w[\w' ]+"
|
| 77 |
+
).generate_from_frequencies(token_weights)
|
| 78 |
|
| 79 |
+
fig = plt.figure(figsize=(16, 8))
|
|
|
|
| 80 |
plt.imshow(wordcloud, interpolation='bilinear')
|
| 81 |
plt.axis('off')
|
| 82 |
plt.tight_layout(pad=0)
|
|
|
|
| 86 |
def load_dev_sentences():
|
| 87 |
"""Load sentences from dev.in file."""
|
| 88 |
try:
|
|
|
|
| 89 |
possible_paths = [
|
| 90 |
os.path.join("src", "codebert", "compile_error", "dev.in"),
|
| 91 |
os.path.join("src", "codebert", "language_classification", "layer11", "dev.in"),
|
|
|
|
| 101 |
st.error(f"Error loading dev.in file: {str(e)}")
|
| 102 |
return []
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
def main():
|
| 105 |
st.title("Code Token Cluster Visualization")
|
| 106 |
|
| 107 |
+
# Model selection
|
| 108 |
+
available_models = ["codebert"]
|
| 109 |
+
model = st.selectbox("Select Model", available_models)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
+
# Task selection
|
| 112 |
+
tasks = ["language_classification"]
|
| 113 |
+
task = st.selectbox("Select Task", tasks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
+
# Layer selection
|
| 116 |
+
layers = ["layer6", "layer11"]
|
| 117 |
+
layer = st.selectbox("Select Layer", layers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# Fix the file paths
|
| 120 |
layer_dir = os.path.join(model, task, layer)
|
|
|
|
| 128 |
clusters = load_clusters(clusters_file)
|
| 129 |
sentences = load_input_data(input_file)
|
| 130 |
|
| 131 |
+
# Create token selector in sidebar
|
| 132 |
+
st.sidebar.title("Token Selection")
|
| 133 |
|
| 134 |
+
# Convert all tokens to strings before sorting
|
| 135 |
+
all_tokens = sorted([str(token) for token in predictions_df['Token'].unique()],
|
| 136 |
+
key=lambda x: (x.lower() if isinstance(x, str) else str(x)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
# Add a search box to filter tokens
|
| 139 |
+
token_search = st.sidebar.text_input("Search tokens")
|
| 140 |
+
|
| 141 |
+
if token_search:
|
| 142 |
+
filtered_tokens = [t for t in all_tokens if token_search.lower() in t.lower()]
|
| 143 |
+
token_options = filtered_tokens if filtered_tokens else all_tokens
|
| 144 |
+
if not filtered_tokens:
|
| 145 |
+
st.sidebar.warning(f"No tokens matching '{token_search}'")
|
| 146 |
+
else:
|
| 147 |
+
token_options = all_tokens
|
| 148 |
+
|
| 149 |
+
selected_token = st.sidebar.selectbox(
|
| 150 |
+
"Select a token:",
|
| 151 |
+
token_options,
|
| 152 |
+
index=0 if token_options else None
|
| 153 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
|
| 155 |
# Main content
|
| 156 |
if selected_token:
|
|
|
|
| 158 |
token_instances = predictions_df[predictions_df['Token'] == selected_token]
|
| 159 |
|
| 160 |
if not token_instances.empty:
|
|
|
|
| 161 |
st.title(f"Token: {selected_token}")
|
| 162 |
|
| 163 |
# Get most frequent cluster (Top 1) for this token
|
|
|
|
| 212 |
st.info("No contexts found in this cluster")
|
| 213 |
else:
|
| 214 |
st.warning(f"No instances found for token: {selected_token}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
except Exception as e:
|
| 217 |
st.error("An error occurred while processing the data")
|