File size: 11,569 Bytes
e336c48
 
 
 
401d1b1
e336c48
401d1b1
 
 
 
 
 
 
e336c48
401d1b1
 
 
 
e336c48
401d1b1
 
 
 
 
 
 
 
 
 
 
 
e336c48
401d1b1
 
 
 
aee7b66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401d1b1
e336c48
401d1b1
 
 
 
 
 
 
 
 
 
 
aee7b66
401d1b1
 
 
aee7b66
401d1b1
 
 
aee7b66
 
401d1b1
 
 
 
aee7b66
 
 
401d1b1
 
 
 
aee7b66
 
401d1b1
aee7b66
401d1b1
 
 
 
e336c48
 
401d1b1
e336c48
401d1b1
a039d86
401d1b1
 
 
 
a039d86
401d1b1
 
 
a039d86
401d1b1
aee7b66
 
 
 
 
 
 
 
 
 
 
 
401d1b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a039d86
401d1b1
 
 
aee7b66
 
 
 
 
 
 
 
 
 
 
 
 
 
a039d86
e336c48
bedc82f
e336c48
401d1b1
 
e336c48
401d1b1
 
 
5d6df7a
401d1b1
 
 
e336c48
401d1b1
 
 
 
5d6df7a
401d1b1
 
 
eeee269
401d1b1
 
 
 
 
 
aee7b66
 
5d6df7a
52e6cfc
aee7b66
401d1b1
 
 
 
5d6df7a
401d1b1
5d6df7a
e336c48
5d6df7a
 
401d1b1
5d6df7a
401d1b1
 
 
5d6df7a
 
 
401d1b1
5d6df7a
 
 
 
 
 
e795f9f
401d1b1
 
 
5d6df7a
 
 
aee7b66
bedc82f
aee7b66
bedc82f
 
aee7b66
5d6df7a
bedc82f
5d6df7a
bedc82f
aee7b66
bedc82f
 
aee7b66
bedc82f
 
 
 
 
 
401d1b1
aee7b66
5d6df7a
 
bedc82f
401d1b1
5d6df7a
401d1b1
aee7b66
 
401d1b1
 
 
 
 
aee7b66
e336c48
 
1c6303f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import streamlit as st
import pandas as pd
import os
from wordcloud import WordCloud
import matplotlib.pyplot as plt

# Initialize session state if needed: every selector starts out unset.
for _state_key in ("selected_token", "selected_task", "selected_layer"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

def get_available_tasks():
    """Return the sorted list of task names found under src/codebert.

    A task is any subdirectory of the base path. Returns an empty list
    when the base path does not exist (instead of raising
    FileNotFoundError), so the UI can still render. Sorting makes the
    selectbox order deterministic across filesystems.
    """
    base_path = os.path.join("src", "codebert")
    if not os.path.isdir(base_path):
        return []
    return sorted(
        entry for entry in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, entry))
    )

def get_available_layers(task):
    """Return the sorted layer numbers available for *task*.

    Layers are directory entries named ``layer<N>`` inside the task
    folder; entries whose numeric part does not parse are ignored.
    """
    task_path = os.path.join("src", "codebert", task)
    found = []
    for entry in os.listdir(task_path):
        if not entry.startswith("layer"):
            continue
        try:
            found.append(int(entry.replace("layer", "")))
        except ValueError:
            pass
    found.sort()
    return found

def load_predictions(task, layer):
    """Load the per-token prediction table for a task/layer.

    Reads the tab-separated predictions CSV, normalizes the Token and
    Top-1 cluster columns to strings, and adds a human-readable
    ``display_text`` column used by the token selector. Returns the
    DataFrame, or None when the file is missing or unreadable.
    """
    predictions_path = os.path.join(
        "src", "codebert", task, f"layer{layer}", f"predictions_layer_{layer}.csv"
    )
    if not os.path.exists(predictions_path):
        return None
    try:
        # The predictions file is tab-delimited, not comma-delimited.
        frame = pd.read_csv(predictions_path, delimiter='\t')
        # Numeric tokens would otherwise load as ints/floats.
        frame['Token'] = frame['Token'].astype(str)
        # Primary predicted cluster is the "Top 1" column.
        frame['predicted_cluster'] = frame['Top 1'].astype(str)
        # One display string per token occurrence for the selectbox.
        frame['display_text'] = [
            f"{tok} (line {li}, pos {pi}, cluster {pc})"
            for tok, li, pi, pc in zip(
                frame['Token'],
                frame['line_idx'],
                frame['position_idx'],
                frame['predicted_cluster'],
            )
        ]
        return frame
    except Exception as e:
        st.error(f"Error loading predictions: {str(e)}")
        return None

def load_clusters(task, layer):
    """Parse clusters-350.txt into ``{cluster_id: [token info dicts]}``.

    Each usable line carries five '|||'-separated fields:
    token ||| occurrence ||| line ||| column ||| cluster_id.
    Malformed or empty lines are skipped silently; cluster IDs are
    normalized (leading zeros stripped) and kept as strings. Returns
    None when the file is missing or cannot be read.
    """
    clusters_path = os.path.join(
        "src", "codebert", task, f"layer{layer}", "clusters-350.txt"
    )
    if not os.path.exists(clusters_path):
        return None

    grouped = {}
    try:
        with open(clusters_path, 'r', encoding='utf-8') as handle:
            for raw in handle:
                raw = raw.strip()
                if not raw:  # Skip empty lines
                    continue

                try:
                    fields = [piece.strip() for piece in raw.split('|||')]
                    if len(fields) != 5:
                        continue
                    token, _occurrence, line_num, col_num, cluster_id = fields
                    # Drop any trailing single pipes left over from the format.
                    cluster_id = cluster_id.split('|')[0].strip()
                    if not cluster_id.isdigit():
                        continue
                    # Normalize so "007" and "7" map to the same key.
                    cluster_id = str(int(cluster_id))
                    grouped.setdefault(cluster_id, []).append({
                        'token': token,
                        'line_num': int(line_num),
                        'col_num': int(col_num),
                    })
                except Exception:
                    continue

    except Exception as e:
        st.error(f"Error loading clusters: {str(e)}")
        return None

    return grouped

def load_dev_sentences(task, layer):
    """Return the test-set sentences (dev.in) as a list of raw lines.

    Prefers the per-layer copy of dev.in, falling back to the task-level
    file; returns an empty list when neither can be read.
    """
    candidate = os.path.join("src", "codebert", task, f"layer{layer}", "dev.in")
    if not os.path.exists(candidate):
        candidate = os.path.join("src", "codebert", task, "dev.in")

    try:
        with open(candidate, 'r', encoding='utf-8') as handle:
            return handle.readlines()
    except Exception:
        return []

def load_train_sentences(task, layer):
    """Return the training-set sentences (input.in) as a list of raw lines.

    Prefers the per-layer copy of input.in, falling back to the
    task-level file; returns an empty list when neither can be read.
    """
    candidate = os.path.join("src", "codebert", task, f"layer{layer}", "input.in")
    if not os.path.exists(candidate):
        candidate = os.path.join("src", "codebert", task, "input.in")

    try:
        with open(candidate, 'r', encoding='utf-8') as handle:
            return handle.readlines()
    except Exception:
        return []

def is_cls_token(token):
    """True for '[CLS]' and numbered variants such as '[CLS]0'."""
    return token[:5] == '[CLS]'

def create_wordcloud(tokens_with_freq):
    """Render a word cloud in which every token is drawn at the same size.

    *tokens_with_freq* maps token -> frequency; the frequencies are
    deliberately flattened to 1 so that no word visually dominates.
    Returns the WordCloud object, or None for empty input or on
    rendering failure.
    """
    if not tokens_with_freq:
        return None

    try:
        # Flatten every frequency to 1 so all words render equally.
        flat = dict.fromkeys(tokens_with_freq, 1)
        cloud = WordCloud(
            width=800,
            height=400,
            background_color='#f9f9f9',  # Very light gray, almost white
            prefer_horizontal=1,  # All text horizontal
            relative_scaling=0,  # This ensures uniform sizing
            min_font_size=35,  # Ensure text is readable
            max_font_size=150,  # Same as min to ensure uniform size
        )
        # generate_from_frequencies returns the WordCloud itself.
        return cloud.generate_from_frequencies(flat)
    except Exception as e:
        st.error(f"Error creating wordcloud: {str(e)}")
        return None

def load_explanation_words(task, layer):
    """Map (token, line_idx, position_idx) -> label from the explanation CSV.

    Reads the tab-separated explanation-words file for the layer and
    builds a lookup dictionary keyed by token occurrence; returns an
    empty dict when the file is missing or malformed.
    """
    file_path = os.path.join(
        "src", "codebert", task, f"layer{layer}",
        f"explanation_words_layer{layer}.csv"
    )
    try:
        frame = pd.read_csv(file_path, sep='\t')
        return {
            (row['token'], row['line_idx'], row['position_idx']): row['labels']
            for _, row in frame.iterrows()
        }
    except Exception as e:
        st.error(f"Error loading explanation words: {str(e)}")
        return {}

def main():
    """Drive the Streamlit token-analysis UI.

    Flow: pick a task and layer, load that layer's predictions /
    clusters / sentences, let the user search and select one token
    occurrence, then show its predicted cluster, original context,
    a cluster word cloud, and similar contexts from the training set.
    Statement order matters: Streamlit renders widgets in call order.
    """
    st.title("Token Analysis")
    
    # Task and Layer Selection
    col1, col2 = st.columns(2)
    
    with col1:
        available_tasks = get_available_tasks()
        selected_task = st.selectbox(
            "Select Task",
            available_tasks,
            key='task_selector'
        )
    
    with col2:
        # Layers depend on the chosen task, so this selectbox is gated on it.
        if selected_task:
            available_layers = get_available_layers(selected_task)
            selected_layer = st.selectbox(
                "Select Layer",
                available_layers,
                key='layer_selector'
            )
        else:
            selected_layer = None
    
    # Only proceed if both task and layer are selected
    # (layer 0 is valid, hence the explicit `is not None` check).
    if selected_task and selected_layer is not None:
        predictions_df = load_predictions(selected_task, selected_layer)
        clusters = load_clusters(selected_task, selected_layer)
        dev_sentences = load_dev_sentences(selected_task, selected_layer)  # Test set sentences
        train_sentences = load_train_sentences(selected_task, selected_layer)  # Training set sentences
        token_labels = load_explanation_words(selected_task, selected_layer)  # Load token labels
        
        if predictions_df is not None and clusters is not None:
            # Token selection with search
            search_token = st.text_input("Search tokens", key='token_search')
            
            # Filter display options based on search (case-insensitive substring).
            filtered_df = predictions_df
            if search_token:
                filtered_df = predictions_df[predictions_df['Token'].str.contains(search_token, case=False, na=False)]
            
            # Display token selection
            selected_token_display = st.selectbox(
                "Select a token occurrence",
                filtered_df['display_text'].tolist(),
                key='token_selector'
            )
            
            if selected_token_display:
                # Get the selected row from the dataframe
                # (display_text uniquely identifies one occurrence).
                selected_row = filtered_df[filtered_df['display_text'] == selected_token_display].iloc[0]
                token = selected_row['Token']
                line_idx = selected_row['line_idx']
                position_idx = selected_row['position_idx']
                
                # Get the label for the selected token
                token_key = (token, line_idx, position_idx)
                
                
                # Display token information
                st.header(f"Token: {token}")
                st.write(f"๐Ÿ“ Line: {selected_row['line_idx']}, Position: {selected_row['position_idx']}")
                st.metric("Predicted Cluster", selected_row['predicted_cluster'])
                if token_key in token_labels:
                    st.write(f"**Label:** {token_labels[token_key]}")                
                # Show original context from dev.in (test set)
                if dev_sentences and selected_row['line_idx'] < len(dev_sentences):
                    st.subheader("Original Context (from test set)")
                    st.code(dev_sentences[selected_row['line_idx']].strip())
                
                # Show wordcloud for the cluster (from training set)
                if clusters and selected_row['predicted_cluster'] in clusters:
                    token_frequencies = {}
                    # NOTE(review): `token` here shadows the selected token above;
                    # harmless today since the outer value is not used afterwards.
                    for token_info in clusters[selected_row['predicted_cluster']]:
                        token = token_info['token']
                        token_frequencies[token] = 1  # Set all frequencies to 1 for uniform size
                    
                    if token_frequencies:
                        st.subheader("Cluster Word Cloud (from training set)")
                        wordcloud = create_wordcloud(token_frequencies)
                        if wordcloud:
                            plt.figure(figsize=(10, 5))
                            plt.imshow(wordcloud, interpolation='bilinear')
                            plt.axis('off')
                            st.pyplot(plt)
                
                # Show similar contexts from the cluster (from training set)
                with st.expander(f"๐Ÿ‘€ View Similar Contexts (from training set, Cluster {selected_row['predicted_cluster']})"):
                    if clusters and selected_row['predicted_cluster'] in clusters:
                        # De-duplicate: a cluster may reference the same
                        # training line through several of its tokens.
                        shown_contexts = set()
                        
                        for token_info in clusters[selected_row['predicted_cluster']]:
                            line_num = token_info['line_num']
                            if line_num >= 0 and line_num < len(train_sentences):
                                context = train_sentences[line_num].strip()
                                if context not in shown_contexts:
                                    st.code(context)
                                    shown_contexts.add(context)
                        
                        if not shown_contexts:
                            st.info("No similar contexts found in this cluster from the training set.")

# Entry point when the script is executed directly (e.g. `streamlit run`).
if __name__ == "__main__":
    main()