vedant2905 committed on
Commit
eeee269
·
verified ·
1 Parent(s): ef99f5c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +46 -140
src/streamlit_app.py CHANGED
@@ -12,25 +12,11 @@ st.set_page_config(
12
  layout="wide"
13
  )
14
 
15
- # Page title with custom styling
16
- st.markdown('<h1 class="main-header">🔍 Code Token Cluster Visualization</h1>', unsafe_allow_html=True)
17
- st.markdown("""
18
- <p>Explore token clusters from language model representations.
19
- Select a token to view its cluster information and contexts.</p>
20
- """, unsafe_allow_html=True)
21
-
22
- # Create sidebar for input controls
23
- with st.sidebar:
24
- st.markdown("## 🛠️ Controls")
25
- st.markdown("---")
26
-
27
  # Functions to load data
28
  @st.cache_data
29
  def load_predictions(file_path):
30
  """Load the predictions CSV file."""
31
- # Prepend src to the file path
32
  full_path = os.path.join("src", file_path)
33
- # Read the file with all columns as string type initially
34
  df = pd.read_csv(full_path, sep="\t", dtype=str)
35
 
36
  # Convert numeric columns safely
@@ -71,28 +57,26 @@ def create_wordcloud(tokens):
71
  if not tokens:
72
  return None
73
 
74
- # Create a dictionary with equal weights for all tokens
75
  token_weights = {token: 1 for token in set(tokens)}
76
 
77
  wordcloud = WordCloud(
78
- width=1000,
79
- height=500,
80
- background_color='#FFF0DB',
81
- prefer_horizontal=1,
82
- min_font_size=10, # Reduced to allow for more flexible scaling
83
- max_font_size=150, # Increased for better range
84
- relative_scaling=0.5, # Added relative scaling to vary sizes based on frequency
85
- collocations=False,
86
- margin=1,
87
- random_state=42,
88
- scale=2, # Increased scale for better resolution
89
- repeat=False,
90
- max_words=2000,
91
- regexp=r"\w[\w' ]+",
92
- ).generate_from_frequencies(token_weights)
93
 
94
- # Create a new figure with tight layout and adjusted size
95
- fig = plt.figure(figsize=(16, 8)) # Increased figure size
96
  plt.imshow(wordcloud, interpolation='bilinear')
97
  plt.axis('off')
98
  plt.tight_layout(pad=0)
@@ -102,7 +86,6 @@ def create_wordcloud(tokens):
102
  def load_dev_sentences():
103
  """Load sentences from dev.in file."""
104
  try:
105
- # Try different possible locations of dev.in
106
  possible_paths = [
107
  os.path.join("src", "codebert", "compile_error", "dev.in"),
108
  os.path.join("src", "codebert", "language_classification", "layer11", "dev.in"),
@@ -118,73 +101,20 @@ def load_dev_sentences():
118
  st.error(f"Error loading dev.in file: {str(e)}")
119
  return []
120
 
121
- def get_available_models():
122
- # Check in the src directory for the codebert folder
123
- current_dir = os.path.dirname(os.path.abspath(__file__))
124
- model_path = os.path.join("src", "codebert")
125
- if os.path.exists(model_path):
126
- return ["codebert"]
127
- return []
128
-
129
- def get_available_tasks(model):
130
- if not model:
131
- return []
132
-
133
- model_dir = os.path.join("src", model)
134
- if os.path.exists(model_dir):
135
- return [d for d in os.listdir(model_dir)
136
- if os.path.isdir(os.path.join(model_dir, d))]
137
- return []
138
-
139
- def get_available_layers(model, task):
140
- if not model or not task:
141
- return []
142
-
143
- task_dir = os.path.join("src", model, task)
144
- if os.path.exists(task_dir):
145
- layers = [d for d in os.listdir(task_dir)
146
- if os.path.isdir(os.path.join(task_dir, d)) and d.startswith('layer')]
147
- return sorted(layers, key=lambda x: int(x.replace('layer', '')))
148
- return []
149
-
150
  def main():
151
  st.title("Code Token Cluster Visualization")
152
 
153
- # Model selection with null checks
154
- available_models = get_available_models()
155
- if not available_models:
156
- st.error("No models found in the workspace")
157
- return
158
-
159
- model = st.selectbox(
160
- "Select Model",
161
- options=available_models,
162
- index=0
163
- )
164
 
165
- # Task selection with null checks
166
- tasks = get_available_tasks(model)
167
- if not tasks:
168
- st.error(f"No tasks found for model {model}")
169
- return
170
-
171
- task = st.selectbox(
172
- "Select Task",
173
- options=tasks,
174
- index=tasks.index("language_classification") if "language_classification" in tasks else 0
175
- )
176
 
177
- # Layer selection with null checks
178
- layers = get_available_layers(model, task)
179
- if not layers:
180
- st.error(f"No layers found for {model}/{task}")
181
- return
182
-
183
- layer = st.selectbox(
184
- "Select Layer",
185
- options=layers,
186
- index=layers.index("layer6") if "layer6" in layers else 0
187
- )
188
 
189
  # Fix the file paths
190
  layer_dir = os.path.join(model, task, layer)
@@ -198,44 +128,29 @@ def main():
198
  clusters = load_clusters(clusters_file)
199
  sentences = load_input_data(input_file)
200
 
201
- # Get tokens and their predicted clusters from predictions file
202
- token_predictions = dict(zip(predictions_df['Token'], predictions_df['Top 1']))
203
 
204
- # Display dataset statistics in an expandable section
205
- with st.expander("📊 Dataset Statistics", expanded=False):
206
- col1, col2, col3 = st.columns(3)
207
- with col1:
208
- st.metric("Total Tokens", f"{len(predictions_df):,}")
209
- with col2:
210
- st.metric("Total Clusters", f"{len(clusters):,}")
211
- with col3:
212
- avg_tokens = sum(len(tokens) for tokens in clusters.values()) / max(len(clusters), 1)
213
- st.metric("Avg. Tokens per Cluster", f"{avg_tokens:.1f}")
214
 
215
- # Create token selector in sidebar
216
- with st.sidebar:
217
- st.markdown("### 🔤 Token Selection")
218
-
219
- # Convert all tokens to strings before sorting to avoid type comparison issues
220
- all_tokens = sorted([str(token) for token in predictions_df['Token'].unique()],
221
- key=lambda x: (x.lower() if isinstance(x, str) else str(x)))
222
-
223
- # Add a search box to filter tokens
224
- token_search = st.text_input("🔍 Search tokens", "")
225
-
226
- if token_search:
227
- filtered_tokens = [t for t in all_tokens if token_search.lower() in t.lower()]
228
- token_options = filtered_tokens if filtered_tokens else all_tokens
229
- if not filtered_tokens:
230
- st.warning(f"No tokens matching '{token_search}'")
231
- else:
232
- token_options = all_tokens
233
-
234
- selected_token = st.selectbox(
235
- "Select a token:",
236
- token_options,
237
- index=0 if token_options else None
238
- )
239
 
240
  # Main content
241
  if selected_token:
@@ -243,7 +158,6 @@ def main():
243
  token_instances = predictions_df[predictions_df['Token'] == selected_token]
244
 
245
  if not token_instances.empty:
246
- # Simple header and token display
247
  st.title(f"Token: {selected_token}")
248
 
249
  # Get most frequent cluster (Top 1) for this token
@@ -298,14 +212,6 @@ def main():
298
  st.info("No contexts found in this cluster")
299
  else:
300
  st.warning(f"No instances found for token: {selected_token}")
301
- else:
302
- # Show welcome message when no token is selected
303
- st.markdown("""
304
- <div style="text-align: center; margin-top: 50px; color: #757575;">
305
- <h3>👈 Select a token from the sidebar to begin</h3>
306
- <p>The visualization will show cluster information and code contexts.</p>
307
- </div>
308
- """, unsafe_allow_html=True)
309
 
310
  except Exception as e:
311
  st.error("An error occurred while processing the data")
 
12
  layout="wide"
13
  )
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Functions to load data
16
  @st.cache_data
17
  def load_predictions(file_path):
18
  """Load the predictions CSV file."""
 
19
  full_path = os.path.join("src", file_path)
 
20
  df = pd.read_csv(full_path, sep="\t", dtype=str)
21
 
22
  # Convert numeric columns safely
 
57
  if not tokens:
58
  return None
59
 
 
60
  token_weights = {token: 1 for token in set(tokens)}
61
 
62
  wordcloud = WordCloud(
63
+ width=1000,
64
+ height=500,
65
+ background_color='#FFF0DB',
66
+ prefer_horizontal=1,
67
+ min_font_size=10,
68
+ max_font_size=150,
69
+ relative_scaling=0.5,
70
+ collocations=False,
71
+ margin=1,
72
+ random_state=42,
73
+ scale=2,
74
+ repeat=False,
75
+ max_words=2000,
76
+ regexp=r"\w[\w' ]+"
77
+ ).generate_from_frequencies(token_weights)
78
 
79
+ fig = plt.figure(figsize=(16, 8))
 
80
  plt.imshow(wordcloud, interpolation='bilinear')
81
  plt.axis('off')
82
  plt.tight_layout(pad=0)
 
86
  def load_dev_sentences():
87
  """Load sentences from dev.in file."""
88
  try:
 
89
  possible_paths = [
90
  os.path.join("src", "codebert", "compile_error", "dev.in"),
91
  os.path.join("src", "codebert", "language_classification", "layer11", "dev.in"),
 
101
  st.error(f"Error loading dev.in file: {str(e)}")
102
  return []
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def main():
105
  st.title("Code Token Cluster Visualization")
106
 
107
+ # Model selection
108
+ available_models = ["codebert"]
109
+ model = st.selectbox("Select Model", available_models)
 
 
 
 
 
 
 
 
110
 
111
+ # Task selection
112
+ tasks = ["language_classification"]
113
+ task = st.selectbox("Select Task", tasks)
 
 
 
 
 
 
 
 
114
 
115
+ # Layer selection
116
+ layers = ["layer6", "layer11"]
117
+ layer = st.selectbox("Select Layer", layers)
 
 
 
 
 
 
 
 
118
 
119
  # Fix the file paths
120
  layer_dir = os.path.join(model, task, layer)
 
128
  clusters = load_clusters(clusters_file)
129
  sentences = load_input_data(input_file)
130
 
131
+ # Create token selector in sidebar
132
+ st.sidebar.title("Token Selection")
133
 
134
+ # Convert all tokens to strings before sorting
135
+ all_tokens = sorted([str(token) for token in predictions_df['Token'].unique()],
136
+ key=lambda x: (x.lower() if isinstance(x, str) else str(x)))
 
 
 
 
 
 
 
137
 
138
+ # Add a search box to filter tokens
139
+ token_search = st.sidebar.text_input("Search tokens")
140
+
141
+ if token_search:
142
+ filtered_tokens = [t for t in all_tokens if token_search.lower() in t.lower()]
143
+ token_options = filtered_tokens if filtered_tokens else all_tokens
144
+ if not filtered_tokens:
145
+ st.sidebar.warning(f"No tokens matching '{token_search}'")
146
+ else:
147
+ token_options = all_tokens
148
+
149
+ selected_token = st.sidebar.selectbox(
150
+ "Select a token:",
151
+ token_options,
152
+ index=0 if token_options else None
153
+ )
 
 
 
 
 
 
 
 
154
 
155
  # Main content
156
  if selected_token:
 
158
  token_instances = predictions_df[predictions_df['Token'] == selected_token]
159
 
160
  if not token_instances.empty:
 
161
  st.title(f"Token: {selected_token}")
162
 
163
  # Get most frequent cluster (Top 1) for this token
 
212
  st.info("No contexts found in this cluster")
213
  else:
214
  st.warning(f"No instances found for token: {selected_token}")
 
 
 
 
 
 
 
 
215
 
216
  except Exception as e:
217
  st.error("An error occurred while processing the data")