# ELIA/user_study/scripts/quantitative_analysis.py
# aaron0eidt's picture
# Deploy static demo
# 5b6c556
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os
def load_and_preprocess_data(filepath='user_study/data/user_data.csv'):
    """Read the user study CSV and turn the cognitive-load rating into ease of use.

    Args:
        filepath: Path of the CSV file holding the study responses.

    Returns:
        The loaded DataFrame. If it has an ``attr_q_cognitive_load`` column,
        that column is removed and ``attr_q_ease_of_use`` is set to
        ``6 - attr_q_cognitive_load``.

    Raises:
        FileNotFoundError: If nothing exists at ``filepath``.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The data file was not found at {filepath}")

    responses = pd.read_csv(filepath)

    # Subtracting from 6 flips the direction of the rating so that a higher
    # value reads as "better", like the other rating columns.
    inverted_columns = ('attr_q_cognitive_load',)
    for load_column in inverted_columns:
        if load_column not in responses.columns:
            continue
        ease_column = load_column.replace('cognitive_load', 'ease_of_use')
        # pop() removes the source column and hands back its values.
        responses[ease_column] = 6 - responses.pop(load_column)

    return responses
def _plot_category_counts(df, column, preferred_order, xlabel, output_path,
                          label_fontsize, xtick_kwargs, ytick_labelsize):
    """Draw a count plot for one categorical column and save it to ``output_path``."""
    counts = df[column].value_counts()

    # Passing `order` to countplot hides every category that is not listed.
    # Append categories that occur in the data but are missing from the
    # preferred order so that no participant is silently left out.
    unlisted = sorted(
        (category for category in counts.index if category not in preferred_order),
        key=str,
    )
    order = list(preferred_order) + unlisted

    plt.figure(figsize=(8, 6))
    ax = sns.countplot(data=df, x=column, order=order, palette="colorblind",
                       hue=column, legend=False)
    plt.xlabel(xlabel, fontsize=label_fontsize)
    plt.ylabel('Number of Participants', fontsize=label_fontsize)
    plt.xticks(**xtick_kwargs)

    # Set y-axis ticks to integers. value_counts() is empty when the column
    # has no values; fall back to 0 so np.arange never receives NaN.
    max_count = int(counts.max()) if not counts.empty else 0
    ax.set_yticks(np.arange(0, max_count + 1, 1))
    ax.tick_params(axis='y', labelsize=ytick_labelsize)

    sns.despine()
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_user_demographics(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Plot participant counts per age group and per LLM experience level.

    Writes ``user_demographics_age.png`` and
    ``user_demographics_experience.png`` into ``output_dir`` (created if
    missing).

    Note: this function changes the global seaborn theme and matplotlib
    ``rcParams``; the settings stay in effect for figures created afterwards.

    Args:
        df: Study data with ``age`` and ``llm_experience`` columns.
        output_dir: Directory the two PNG files are written to.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Configure the plot style. set_theme() comes first because the rcParams
    # below are meant to override parts of the theme.
    sns.set_theme(style="ticks", palette="viridis")
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': 'Arial',
        'axes.labelweight': 'normal',
        'axes.titleweight': 'bold',
        'figure.titleweight': 'bold',
        'savefig.dpi': 300,
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'grid.alpha': 0.2,
        'axes.spines.top': False,
        'axes.spines.right': False,
    })

    # Plot age distribution.
    _plot_category_counts(
        df,
        column='age',
        preferred_order=['under_18', '18_24', '25_34', '35_44', '55_64'],
        xlabel='Age Group',
        output_path=os.path.join(output_dir, 'user_demographics_age.png'),
        label_fontsize=18,
        xtick_kwargs={'rotation': 45, 'fontsize': 14},
        ytick_labelsize=12,
    )

    # Plot LLM experience.
    _plot_category_counts(
        df,
        column='llm_experience',
        preferred_order=['novice', 'intermediate', 'expert'],
        xlabel='Experience Level',
        output_path=os.path.join(output_dir, 'user_demographics_experience.png'),
        label_fontsize=16,
        xtick_kwargs={'fontsize': 14},
        ytick_labelsize=14,
    )
def plot_ux_ratings_by_page(df, language_filter='all', output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Plot box plots of the UX ratings, grouped by analysis page.

    Saves ``ux_ratings_by_page_<language_filter>.png`` into ``output_dir``
    (created if missing). Nothing is plotted when the filtered data is empty
    or contains none of the known rating columns; a message is printed
    instead.

    Args:
        df: Study data. Needs a ``language`` column unless ``language_filter``
            is ``'all'``.
        language_filter: ``'all'`` to use every row, otherwise the value of
            the ``language`` column to keep.
        output_dir: Directory the PNG file is written to.
    """
    if language_filter != 'all':
        df = df[df['language'] == language_filter]

    if df.empty:
        print(f"No data for language filter: {language_filter}. Skipping plot.")
        return

    # Analysis page -> {display name of the metric: dataframe column name}.
    ux_metrics = {
        'Attribution': {
            'Visual Clarity': 'attr_q_visual_clarity',
            'Ease of Use': 'attr_q_ease_of_use',
            'Influencer Plausibility': 'attr_q_influencer_plausibility',
        },
        'Function Vectors': {
            'Pca Clarity': 'fv_q_pca_clarity',
            'Type Attribution Clarity': 'fv_q_type_attribution_clarity',
            'Layer Evolution Plausibility': 'fv_q_layer_evolution_plausibility',
        },
        'Circuit Trace': {
            'Main Graph Clarity': 'ct_q_main_graph_clarity',
            'Feature Explorer Usefulness': 'ct_q_feature_explorer_usefulness',
            'Subnetwork Clarity': 'ct_q_subnetwork_clarity',
        },
    }

    # Collect one long-form frame per available rating column and concatenate
    # once at the end instead of growing a DataFrame inside the loop.
    frames = []
    for page, metrics in ux_metrics.items():
        for metric_label, column in metrics.items():
            if column not in df.columns:
                continue
            frames.append(
                df[[column]]
                .rename(columns={column: 'Rating'})
                .assign(Page=page, Metric=metric_label)
            )

    # Without any rating column there is no 'Metric'/'Rating' data to plot.
    if not frames:
        print(f"No UX rating columns found for language filter: {language_filter}. Skipping plot.")
        return

    df_melted = pd.concat(frames, ignore_index=True)

    plt.figure(figsize=(14, 8))
    # fliersize=0 hides the outlier markers.
    sns.boxplot(data=df_melted, x='Metric', y='Rating', hue='Page', palette='colorblind', fliersize=0)
    plt.xlabel('UX Metric', fontsize=14)
    plt.ylabel('Rating (1-5)', fontsize=14)
    plt.xticks(rotation=15, fontsize=12)
    plt.legend(title='Analysis Page', fontsize=12)
    plt.yticks(np.arange(1, 6, 1), fontsize=12)
    sns.despine()
    plt.tight_layout()

    # Save the figure with a language-specific name.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'ux_ratings_by_page_{language_filter}.png')
    plt.savefig(output_path)
    print(f"Saved UX ratings plot to {output_path}")
    plt.close()
def plot_correctness_by_experience(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Plot the proportion of correct answers per LLM experience level.

    Every column whose name contains ``'correct'`` is treated as one
    question; all of them are pooled per experience level. The bar chart is
    saved as ``correctness_by_experience.png`` in ``output_dir`` (created if
    missing).

    Args:
        df: Study data with an ``llm_experience`` column and the correctness
            columns.
        output_dir: Directory the PNG file is written to.
    """
    os.makedirs(output_dir, exist_ok=True)

    experience_levels = ['novice', 'intermediate', 'expert']
    answer_columns = [name for name in df.columns if 'correct' in name]

    # Long form: one row per (participant, question) pair.
    long_form = df[['llm_experience', *answer_columns]].copy().melt(
        id_vars=['llm_experience'],
        var_name='Question',
        value_name='Is Correct',
    )
    # Cast to float so the bar height is the mean, i.e. the proportion correct.
    long_form['Is Correct'] = long_form['Is Correct'].astype(float)

    plt.figure(figsize=(10, 6))
    ax = sns.barplot(
        data=long_form,
        x='llm_experience',
        y='Is Correct',
        order=experience_levels,
        palette='colorblind',
        hue='llm_experience',
        legend=False,
        errorbar=None,
    )
    ax.set_xlabel('Experience Level', fontsize=16)
    ax.set_ylabel('Proportion Correct', fontsize=16)
    ax.set_ylim(0, 1)
    ax.set_yticks(np.arange(0, 1.1, 0.1))
    ax.tick_params(axis='both', labelsize=14)

    sns.despine()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'correctness_by_experience.png'))
    plt.close()
def plot_correlation_heatmap(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Plot an annotated correlation heatmap of the numeric columns.

    The figure is saved as ``correlation_heatmap.png`` in ``output_dir``
    (created if missing).

    Args:
        df: Study data; only columns with a numeric dtype are used.
        output_dir: Directory the PNG file is written to.
    """
    os.makedirs(output_dir, exist_ok=True)

    # The participant ID is an identifier, not a measurement, so it is left
    # out of the correlation.
    numeric_columns = [
        name
        for name in df.select_dtypes(include=np.number).columns
        if name != 'participant_id'
    ]
    correlations = df[numeric_columns].corr()

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        correlations,
        annot=True,
        fmt=".2f",
        cmap="coolwarm",
        linewidths=.5,
    )
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

    heatmap_path = os.path.join(output_dir, 'correlation_heatmap.png')
    plt.savefig(heatmap_path)
    plt.close()
if __name__ == '__main__':
    # Script entry point: load the study data and generate every figure.
    import sys

    # NOTE(review): the input path climbs two directories from the current
    # working directory, while the plot functions' default output directory
    # does not -- confirm which working directory the script is meant to be
    # run from.
    try:
        data = load_and_preprocess_data('../../user_study/data/user_data.csv')
        plot_user_demographics(data)
        # Generate all three versions of the UX ratings plot.
        for language in ('all', 'en', 'de'):
            plot_ux_ratings_by_page(data.copy(), language_filter=language)
        plot_correctness_by_experience(data)
        plot_correlation_heatmap(data)
    except Exception as e:
        # Report on stderr and exit non-zero so a failed run is not mistaken
        # for a successful one by a calling shell script or CI job.
        print(f"An error occurred: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print("All plots generated successfully.")