Spaces:
Sleeping
Sleeping
File size: 7,844 Bytes
5b6c556 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os
def load_and_preprocess_data(filepath='user_study/data/user_data.csv'):
# Loads and preprocesses the user study data.
if not os.path.exists(filepath):
raise FileNotFoundError(f"The data file was not found at {filepath}")
df = pd.read_csv(filepath)
# Invert Cognitive Load to make it more intuitive.
for col in ['attr_q_cognitive_load']:
if col in df.columns:
df[col.replace('cognitive_load', 'ease_of_use')] = 6 - df[col]
df.drop(columns=[col], inplace=True)
return df
def plot_user_demographics(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
# Plots user demographic data.
os.makedirs(output_dir, exist_ok=True)
# Configure the plot style.
sns.set_theme(style="ticks", palette="viridis")
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'
plt.rcParams['axes.labelweight'] = 'normal'
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['figure.titleweight'] = 'bold'
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['grid.alpha'] = 0.2
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
# Plot age distribution.
plt.figure(figsize=(8, 6))
age_order = ['under_18', '18_24', '25_34', '35_44', '55_64']
ax = sns.countplot(data=df, x='age', order=age_order, palette="colorblind", hue='age', legend=False)
plt.xlabel('Age Group', fontsize=18)
plt.ylabel('Number of Participants', fontsize=18)
plt.xticks(rotation=45, fontsize=14)
# Set y-axis ticks to integers.
ax.set_yticks(np.arange(0, df['age'].value_counts().max() + 1, 1))
ax.tick_params(axis='y', labelsize=12)
sns.despine()
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'user_demographics_age.png'))
plt.close()
# Plot LLM experience.
plt.figure(figsize=(8, 6))
exp_order = ['novice', 'intermediate', 'expert']
ax = sns.countplot(data=df, x='llm_experience', order=exp_order, palette="colorblind", hue='llm_experience', legend=False)
plt.xlabel('Experience Level', fontsize=16)
plt.ylabel('Number of Participants', fontsize=16)
plt.xticks(fontsize=14)
# Set y-axis ticks to integers.
ax.set_yticks(np.arange(0, df['llm_experience'].value_counts().max() + 1, 1))
ax.tick_params(axis='y', labelsize=14)
sns.despine()
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'user_demographics_experience.png'))
plt.close()
def plot_ux_ratings_by_page(df, language_filter='all'):
# Plots UX ratings, with an option to filter by language.
if language_filter != 'all':
df = df[df['language'] == language_filter]
if df.empty:
print(f"No data for language filter: {language_filter}. Skipping plot.")
return
ux_metrics = {
'Attribution': ['Visual Clarity', 'Ease of Use', 'Influencer Plausibility'],
'Function Vectors': ['Pca Clarity', 'Type Attribution Clarity', 'Layer Evolution Plausibility'],
'Circuit Trace': ['Main Graph Clarity', 'Feature Explorer Usefulness', 'Subnetwork Clarity']
}
# Map the clean metric names to the dataframe column names.
column_mapping = {
'Visual Clarity': 'attr_q_visual_clarity',
'Ease of Use': 'attr_q_ease_of_use',
'Influencer Plausibility': 'attr_q_influencer_plausibility',
'Pca Clarity': 'fv_q_pca_clarity',
'Type Attribution Clarity': 'fv_q_type_attribution_clarity',
'Layer Evolution Plausibility': 'fv_q_layer_evolution_plausibility',
'Main Graph Clarity': 'ct_q_main_graph_clarity',
'Feature Explorer Usefulness': 'ct_q_feature_explorer_usefulness',
'Subnetwork Clarity': 'ct_q_subnetwork_clarity'
}
df_melted = pd.DataFrame()
for page, cols in ux_metrics.items():
for col_clean_name in cols:
col_original_name = column_mapping[col_clean_name]
if col_original_name in df.columns:
temp_df = df[[col_original_name]].copy()
temp_df.rename(columns={col_original_name: 'Rating'}, inplace=True)
temp_df['Page'] = page
temp_df['Metric'] = col_clean_name
df_melted = pd.concat([df_melted, temp_df], ignore_index=True)
plt.figure(figsize=(14, 8))
sns.boxplot(data=df_melted, x='Metric', y='Rating', hue='Page', palette='colorblind', fliersize=0)
plt.xlabel('UX Metric', fontsize=14)
plt.ylabel('Rating (1-5)', fontsize=14)
plt.xticks(rotation=15, fontsize=12)
plt.legend(title='Analysis Page', fontsize=12)
plt.yticks(np.arange(1, 6, 1), fontsize=12)
sns.despine()
plt.tight_layout()
# Save the figure with a language-specific name.
output_path = os.path.join('writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results', f'ux_ratings_by_page_{language_filter}.png')
plt.savefig(output_path)
print(f"Saved UX ratings plot to {output_path}")
plt.close()
def plot_correctness_by_experience(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
# Plots comprehension correctness by LLM experience.
os.makedirs(output_dir, exist_ok=True)
correct_cols = [col for col in df.columns if 'correct' in col]
df_corr = df[['llm_experience'] + correct_cols].copy()
df_melted = df_corr.melt(id_vars=['llm_experience'], var_name='Question', value_name='Is Correct')
df_melted['Is Correct'] = df_melted['Is Correct'].astype(float)
plt.figure(figsize=(10, 6))
ax = sns.barplot(data=df_melted, x='llm_experience', y='Is Correct', order=['novice', 'intermediate', 'expert'], palette='colorblind', hue='llm_experience', legend=False, errorbar=None)
plt.xlabel('Experience Level', fontsize=16)
plt.ylabel('Proportion Correct', fontsize=16)
plt.ylim(0, 1)
plt.yticks(np.arange(0, 1.1, 0.1))
ax.tick_params(axis='x', labelsize=14)
ax.tick_params(axis='y', labelsize=14)
sns.despine()
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'correctness_by_experience.png'))
plt.close()
def plot_correlation_heatmap(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
# Plots a correlation heatmap of all numerical data.
os.makedirs(output_dir, exist_ok=True)
quant_cols = df.select_dtypes(include=np.number).columns.tolist()
# Remove participant ID from the correlation.
if 'participant_id' in quant_cols:
quant_cols.remove('participant_id')
corr = df[quant_cols].corr()
plt.figure(figsize=(12, 10))
sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", linewidths=.5)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'correlation_heatmap.png'))
plt.close()
if __name__ == '__main__':
try:
data = load_and_preprocess_data('../../user_study/data/user_data.csv')
plot_user_demographics(data)
# Generate all three versions of the UX ratings plot.
plot_ux_ratings_by_page(data.copy(), language_filter='all')
plot_ux_ratings_by_page(data.copy(), language_filter='en')
plot_ux_ratings_by_page(data.copy(), language_filter='de')
plot_correctness_by_experience(data)
plot_correlation_heatmap(data)
print("All plots generated successfully.")
except Exception as e:
print(f"An error occurred: {e}") |