File size: 7,844 Bytes
5b6c556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os

def load_and_preprocess_data(filepath='user_study/data/user_data.csv'):
    """Read the user-study CSV into a DataFrame and derive ease-of-use scores.

    When the ``attr_q_cognitive_load`` column is present, it is replaced by
    ``attr_q_ease_of_use = 6 - attr_q_cognitive_load``, so that a higher value
    means a lighter load. The constant 6 presumably assumes a 1-5 rating scale,
    matching the 'Rating (1-5)' axis of the UX plot; confirm against the survey.
    The derived column is added and the original column is dropped.

    Args:
        filepath: Path to the CSV file.

    Returns:
        The loaded and preprocessed DataFrame.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The data file was not found at {filepath}")

    frame = pd.read_csv(filepath)

    # Reverse-coded rating columns mapped to the name of their inverted version.
    inverted_names = {'attr_q_cognitive_load': 'attr_q_ease_of_use'}
    for source, target in inverted_names.items():
        if source not in frame.columns:
            continue
        frame[target] = 6 - frame[source]
        frame = frame.drop(columns=[source])

    return frame

def _category_order(values, canonical, always_shown):
    """Return the x-axis category order for a countplot of ``values``.

    Categories in ``canonical`` appear in that order. Those in
    ``always_shown`` are kept even when nobody chose them; the others appear
    only when observed. Observed values missing from ``canonical`` are
    appended, sorted by their string form, so no participant is silently left
    out of the chart.
    """
    observed = set(values.dropna().unique())
    order = [cat for cat in canonical if cat in always_shown or cat in observed]
    extras = sorted((cat for cat in observed if cat not in canonical), key=str)
    return order + extras


def _set_integer_yticks(ax, values):
    """Place a y tick at every integer from 0 up to the largest category count."""
    counts = values.value_counts()
    # On an empty or all-NaN column, value_counts() is empty and its max() is
    # NaN, which np.arange rejects. Fall back to a single tick at 0.
    top = int(counts.max()) if not counts.empty else 0
    ax.set_yticks(np.arange(0, top + 1, 1))


def plot_user_demographics(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save bar charts of participant age groups and LLM experience levels.

    Writes ``user_demographics_age.png`` and ``user_demographics_experience.png``
    into ``output_dir``, creating the directory if needed. This also changes the
    global seaborn theme and matplotlib rcParams, which every later figure in
    the same process inherits.

    Args:
        df: Participant data with ``age`` and ``llm_experience`` columns.
        output_dir: Directory the PNG files are written to.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Configure the (global) plot style.
    sns.set_theme(style="ticks", palette="viridis")
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': 'Arial',
        'axes.labelweight': 'normal',
        'axes.titleweight': 'bold',
        'figure.titleweight': 'bold',
        'savefig.dpi': 300,
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'grid.alpha': 0.2,
        'axes.spines.top': False,
        'axes.spines.right': False,
    })

    # Plot the age distribution. The five original brackets are always drawn.
    # '45_54' (absent from the old hard-coded order, which silently dropped
    # those participants) and any other unexpected label are drawn whenever
    # they occur.
    age_order = _category_order(
        df['age'],
        canonical=['under_18', '18_24', '25_34', '35_44', '45_54', '55_64'],
        always_shown={'under_18', '18_24', '25_34', '35_44', '55_64'},
    )
    plt.figure(figsize=(8, 6))
    ax = sns.countplot(data=df, x='age', order=age_order, palette="colorblind", hue='age', legend=False)
    plt.xlabel('Age Group', fontsize=18)
    plt.ylabel('Number of Participants', fontsize=18)
    plt.xticks(rotation=45, fontsize=14)
    _set_integer_yticks(ax, df['age'])
    ax.tick_params(axis='y', labelsize=12)
    sns.despine()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'user_demographics_age.png'))
    plt.close()

    # Plot LLM experience. Unexpected levels are appended rather than dropped.
    exp_levels = ['novice', 'intermediate', 'expert']
    exp_order = _category_order(df['llm_experience'], canonical=exp_levels, always_shown=set(exp_levels))
    plt.figure(figsize=(8, 6))
    ax = sns.countplot(data=df, x='llm_experience', order=exp_order, palette="colorblind", hue='llm_experience', legend=False)
    plt.xlabel('Experience Level', fontsize=16)
    plt.ylabel('Number of Participants', fontsize=16)
    plt.xticks(fontsize=14)
    _set_integer_yticks(ax, df['llm_experience'])
    ax.tick_params(axis='y', labelsize=14)
    sns.despine()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'user_demographics_experience.png'))
    plt.close()

def plot_ux_ratings_by_page(df, language_filter='all', output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save a box plot of the UX ratings, grouped by analysis page.

    Args:
        df: Participant data. It needs a ``language`` column when
            ``language_filter`` is not ``'all'``. Rating columns missing from
            ``df`` are skipped. ``attr_q_ease_of_use`` only exists after
            ``load_and_preprocess_data`` has derived it.
        language_filter: ``'all'`` uses every row. Any other value keeps only
            rows whose ``language`` equals it. The value is also embedded in
            the output file name.
        output_dir: Directory for ``ux_ratings_by_page_<language_filter>.png``,
            created if missing. The default is the path this function used to
            hard-code.

    Returns:
        None. Prints a message and returns early when there is nothing to plot.
    """
    if language_filter != 'all':
        df = df[df['language'] == language_filter]

    if df.empty:
        print(f"No data for language filter: {language_filter}. Skipping plot.")
        return

    ux_metrics = {
        'Attribution': ['Visual Clarity', 'Ease of Use', 'Influencer Plausibility'],
        'Function Vectors': ['Pca Clarity', 'Type Attribution Clarity', 'Layer Evolution Plausibility'],
        'Circuit Trace': ['Main Graph Clarity', 'Feature Explorer Usefulness', 'Subnetwork Clarity']
    }

    # Map the clean metric names to the dataframe column names.
    column_mapping = {
        'Visual Clarity': 'attr_q_visual_clarity',
        'Ease of Use': 'attr_q_ease_of_use',
        'Influencer Plausibility': 'attr_q_influencer_plausibility',
        'Pca Clarity': 'fv_q_pca_clarity',
        'Type Attribution Clarity': 'fv_q_type_attribution_clarity',
        'Layer Evolution Plausibility': 'fv_q_layer_evolution_plausibility',
        'Main Graph Clarity': 'ct_q_main_graph_clarity',
        'Feature Explorer Usefulness': 'ct_q_feature_explorer_usefulness',
        'Subnetwork Clarity': 'ct_q_subnetwork_clarity'
    }

    # Build one long-format piece per metric and concatenate once at the end.
    # Concatenating inside the loop re-copies every earlier row each time.
    pieces = []
    for page, metric_names in ux_metrics.items():
        for metric in metric_names:
            source_col = column_mapping[metric]
            if source_col in df.columns:
                pieces.append(
                    df[[source_col]]
                    .rename(columns={source_col: 'Rating'})
                    .assign(Page=page, Metric=metric)
                )

    # Without any rating column the boxplot would receive an empty frame with
    # no 'Metric'/'Rating' columns and fail, so skip instead.
    if not pieces:
        print(f"No UX rating columns found for language filter: {language_filter}. Skipping plot.")
        return

    df_melted = pd.concat(pieces, ignore_index=True)

    plt.figure(figsize=(14, 8))
    sns.boxplot(data=df_melted, x='Metric', y='Rating', hue='Page', palette='colorblind', fliersize=0)
    plt.xlabel('UX Metric', fontsize=14)
    plt.ylabel('Rating (1-5)', fontsize=14)
    plt.xticks(rotation=15, fontsize=12)
    plt.legend(title='Analysis Page', fontsize=12)
    plt.yticks(np.arange(1, 6, 1), fontsize=12)
    sns.despine()
    plt.tight_layout()

    # Save the figure with a language-specific name. The directory is created
    # here, as the sibling plotting functions do, so savefig cannot fail on a
    # fresh checkout.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'ux_ratings_by_page_{language_filter}.png')
    plt.savefig(output_path)
    print(f"Saved UX ratings plot to {output_path}")
    plt.close()

def plot_correctness_by_experience(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save a bar chart of mean comprehension correctness per LLM experience level.

    Every column whose name contains ``'correct'`` is treated as one question's
    correctness indicator. All such columns are pooled, so each bar is the mean
    over all questions and participants in that ``llm_experience`` group. This
    presumably assumes 0/1 or boolean values, so that the mean is the
    proportion correct.

    Writes ``correctness_by_experience.png`` into ``output_dir``, creating the
    directory if needed. If no correctness columns exist, prints a message and
    returns without plotting.

    Args:
        df: Participant data with an ``llm_experience`` column.
        output_dir: Directory the PNG file is written to.
    """
    # NOTE(review): the substring test would also match a column such as
    # 'incorrect_count'; tighten it if the schema ever contains one.
    correct_cols = [col for col in df.columns if 'correct' in str(col)]
    if not correct_cols:
        print("No correctness columns found. Skipping correctness plot.")
        return

    os.makedirs(output_dir, exist_ok=True)

    df_corr = df[['llm_experience'] + correct_cols].copy()
    df_melted = df_corr.melt(id_vars=['llm_experience'], var_name='Question', value_name='Is Correct')
    # Booleans become 0.0/1.0, so the bar height (the mean) is a proportion.
    df_melted['Is Correct'] = df_melted['Is Correct'].astype(float)

    plt.figure(figsize=(10, 6))
    ax = sns.barplot(data=df_melted, x='llm_experience', y='Is Correct', order=['novice', 'intermediate', 'expert'], palette='colorblind', hue='llm_experience', legend=False, errorbar=None)
    plt.xlabel('Experience Level', fontsize=16)
    plt.ylabel('Proportion Correct', fontsize=16)
    plt.ylim(0, 1)
    # linspace gives exactly 11 ticks 0.0..1.0 without float-step drift.
    plt.yticks(np.linspace(0, 1, 11))
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    sns.despine()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'correctness_by_experience.png'))
    plt.close()

def plot_correlation_heatmap(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save an annotated correlation heatmap of the numeric columns of ``df``.

    Columns are selected with ``select_dtypes(include=np.number)``, and
    ``participant_id`` is excluded. Correlations use ``DataFrame.corr()``
    defaults. The figure is written as ``correlation_heatmap.png`` into
    ``output_dir``, which is created if needed.

    NOTE(review): boolean columns are not matched by ``np.number``, so any
    True/False correctness columns do not appear in the heatmap.
    """
    os.makedirs(output_dir, exist_ok=True)

    # An identifier is not a measurement, so keep it out of the matrix.
    numeric = df.select_dtypes(include=np.number).drop(columns='participant_id', errors='ignore')
    correlations = numeric.corr()

    plt.figure(figsize=(12, 10))
    sns.heatmap(correlations, annot=True, fmt=".2f", cmap="coolwarm", linewidths=0.5)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    target = os.path.join(output_dir, 'correlation_heatmap.png')
    plt.savefig(target)
    plt.close()

if __name__ == '__main__':
    # Script entry point: load the study data and regenerate every figure.
    import sys
    import traceback

    try:
        # NOTE(review): this input path is relative to the working directory,
        # and so are the default output directories. Confirm the intended
        # launch directory.
        data = load_and_preprocess_data('../../user_study/data/user_data.csv')
        plot_user_demographics(data)

        # Generate all three versions of the UX ratings plot.
        for language in ('all', 'en', 'de'):
            plot_ux_ratings_by_page(data.copy(), language_filter=language)

        plot_correctness_by_experience(data)
        plot_correlation_heatmap(data)
        print("All plots generated successfully.")
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and exit
        # non-zero, so callers (shell, CI, make) can tell that it failed.
        print(f"An error occurred: {e}")
        traceback.print_exc()
        sys.exit(1)