Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
def load_and_preprocess_data(filepath='user_study/data/user_data.csv'):
    """Read the user-study CSV and turn cognitive load into an ease-of-use score.

    When the ``attr_q_cognitive_load`` column is present it is removed and its
    values are written as ``6 - cognitive_load`` into ``attr_q_ease_of_use``,
    so that a higher number reads as "better". The constant 6 presumably
    inverts a 1-5 rating scale — confirm against the survey definition.

    Args:
        filepath: Path to the CSV export of the study responses.

    Returns:
        The loaded ``pandas.DataFrame`` with the inverted column.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The data file was not found at {filepath}")

    responses = pd.read_csv(filepath)

    load_col = 'attr_q_cognitive_load'
    if load_col not in responses.columns:
        return responses

    ease_col = load_col.replace('cognitive_load', 'ease_of_use')
    responses[ease_col] = 6 - responses[load_col]
    return responses.drop(columns=[load_col])
def _apply_plot_style():
    """Apply the shared seaborn/matplotlib styling used by the figures.

    This mutates the process-wide ``plt.rcParams``, so figures drawn later in
    the same process (by any function) inherit these settings.
    """
    sns.set_theme(style="ticks", palette="viridis")
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': 'Arial',
        'axes.labelweight': 'normal',
        'axes.titleweight': 'bold',
        'figure.titleweight': 'bold',
        'savefig.dpi': 300,
        'figure.facecolor': 'white',
        'axes.facecolor': 'white',
        'grid.alpha': 0.2,
        'axes.spines.top': False,
        'axes.spines.right': False,
    })


def _ordered_categories(values, preferred_order):
    """Return ``preferred_order`` followed by any other observed (non-null) values.

    Observed categories missing from the preferred order are appended, sorted by
    their string form, so no participant is silently left out of a count plot.
    """
    preferred = list(preferred_order)
    extras = {v for v in pd.Series(values).dropna().unique() if v not in preferred}
    return preferred + sorted(extras, key=str)


def _plot_category_counts(df, column, order, xlabel, output_path, *,
                          label_size, xtick_size, ytick_size, xtick_rotation=0):
    """Draw a bar chart of participant counts per category of ``column`` and save it."""
    plt.figure(figsize=(8, 6))
    ax = sns.countplot(data=df, x=column, order=order, palette="colorblind",
                       hue=column, legend=False)
    plt.xlabel(xlabel, fontsize=label_size)
    plt.ylabel('Number of Participants', fontsize=label_size)
    plt.xticks(rotation=xtick_rotation, fontsize=xtick_size)
    # Counts are whole participants, so label every integer up to the tallest
    # bar. An all-missing column gives an empty value_counts() whose max() is
    # NaN, which would make np.arange raise, so fall back to 0.
    counts = df[column].value_counts()
    top = int(counts.max()) if not counts.empty else 0
    ax.set_yticks(np.arange(0, top + 1, 1))
    ax.tick_params(axis='y', labelsize=ytick_size)
    sns.despine()
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_user_demographics(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save bar charts of participant age group and LLM experience level.

    Also applies the global plot style (rcParams), which persists for any
    figures produced afterwards in the same process.

    Args:
        df: Participant table with ``age`` and ``llm_experience`` columns.
        output_dir: Directory for the PNG files; created if missing.

    Writes:
        ``user_demographics_age.png`` and ``user_demographics_experience.png``
        inside ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)
    _apply_plot_style()

    # Full set of age bins, including 45_54 (previously omitted, which dropped
    # those participants from the figure); unexpected labels are appended.
    age_order = _ordered_categories(
        df['age'], ['under_18', '18_24', '25_34', '35_44', '45_54', '55_64'])
    _plot_category_counts(
        df, 'age', age_order, 'Age Group',
        os.path.join(output_dir, 'user_demographics_age.png'),
        label_size=18, xtick_size=14, ytick_size=12, xtick_rotation=45)

    exp_order = _ordered_categories(
        df['llm_experience'], ['novice', 'intermediate', 'expert'])
    _plot_category_counts(
        df, 'llm_experience', exp_order, 'Experience Level',
        os.path.join(output_dir, 'user_demographics_experience.png'),
        label_size=16, xtick_size=14, ytick_size=14)
def plot_ux_ratings_by_page(df, language_filter='all',
                            output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save a box plot of the per-page UX ratings, optionally for one language.

    Each analysis page (Attribution, Function Vectors, Circuit Trace) has three
    rating questions. Every rating column present in ``df`` is reshaped into
    long format (Rating / Page / Metric) and drawn as one box per metric,
    coloured by page.

    Args:
        df: Participant table containing the ``*_q_*`` rating columns. Must
            have a ``language`` column when ``language_filter`` is not 'all'.
        language_filter: 'all' to use every row, otherwise keep only rows whose
            ``language`` equals this value. Also used in the output filename.
        output_dir: Directory for ``ux_ratings_by_page_{language_filter}.png``;
            created if missing. Defaults to the path that used to be hard-coded.

    Returns:
        None. Prints a message and skips plotting when the filter leaves no
        rows or when none of the rating columns exist.
    """
    if language_filter != 'all':
        df = df[df['language'] == language_filter]
        if df.empty:
            print(f"No data for language filter: {language_filter}. Skipping plot.")
            return

    # Page -> {display label: dataframe column}; dict order fixes plot order.
    ux_metrics = {
        'Attribution': {
            'Visual Clarity': 'attr_q_visual_clarity',
            'Ease of Use': 'attr_q_ease_of_use',
            'Influencer Plausibility': 'attr_q_influencer_plausibility',
        },
        'Function Vectors': {
            'Pca Clarity': 'fv_q_pca_clarity',
            'Type Attribution Clarity': 'fv_q_type_attribution_clarity',
            'Layer Evolution Plausibility': 'fv_q_layer_evolution_plausibility',
        },
        'Circuit Trace': {
            'Main Graph Clarity': 'ct_q_main_graph_clarity',
            'Feature Explorer Usefulness': 'ct_q_feature_explorer_usefulness',
            'Subnetwork Clarity': 'ct_q_subnetwork_clarity',
        },
    }

    # Collect one long-format piece per available column and concatenate once,
    # instead of re-copying a growing frame on every iteration.
    pieces = []
    for page, metrics in ux_metrics.items():
        for label, column in metrics.items():
            if column not in df.columns:
                continue
            pieces.append(
                df[[column]].rename(columns={column: 'Rating'}).assign(Page=page, Metric=label)
            )
    if not pieces:
        # Without this guard seaborn would be handed a frame lacking the
        # 'Metric'/'Rating' columns and raise.
        print(f"No UX rating columns found for language filter: {language_filter}. Skipping plot.")
        return
    df_melted = pd.concat(pieces, ignore_index=True)

    plt.figure(figsize=(14, 8))
    sns.boxplot(data=df_melted, x='Metric', y='Rating', hue='Page', palette='colorblind', fliersize=0)
    plt.xlabel('UX Metric', fontsize=14)
    plt.ylabel('Rating (1-5)', fontsize=14)
    plt.xticks(rotation=15, fontsize=12)
    plt.legend(title='Analysis Page', fontsize=12)
    plt.yticks(np.arange(1, 6, 1), fontsize=12)
    sns.despine()
    plt.tight_layout()

    # Save the figure with a language-specific name; make sure the directory
    # exists so this works even when called on its own.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f'ux_ratings_by_page_{language_filter}.png')
    plt.savefig(output_path)
    print(f"Saved UX ratings plot to {output_path}")
    plt.close()
def plot_correctness_by_experience(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save a bar chart of the proportion of correct answers per LLM experience level.

    Every column whose name contains the substring 'correct' is treated as one
    comprehension question, presumably scored as boolean or 0/1. Answers are
    pooled across questions and participants, so each bar is the mean over all
    (participant, question) rows within that experience level. No error bars
    are drawn.

    Args:
        df: Participant table with an ``llm_experience`` column and the
            correctness columns.
        output_dir: Directory for ``correctness_by_experience.png``; created if
            missing.

    Returns:
        None. Prints a message and skips the plot when no correctness column
        exists.
    """
    # str() guards against non-string column labels (e.g. integers), for which
    # the `in` test would raise TypeError.
    # NOTE(review): substring matching would also select names such as
    # 'incorrect_*'; confirm the column naming scheme.
    correct_cols = [col for col in df.columns if 'correct' in str(col)]
    if not correct_cols:
        print("No correctness columns found. Skipping correctness plot.")
        return
    os.makedirs(output_dir, exist_ok=True)

    df_melted = df[['llm_experience'] + correct_cols].melt(
        id_vars=['llm_experience'], var_name='Question', value_name='Is Correct')
    df_melted['Is Correct'] = df_melted['Is Correct'].astype(float)

    plt.figure(figsize=(10, 6))
    ax = sns.barplot(data=df_melted, x='llm_experience', y='Is Correct',
                     order=['novice', 'intermediate', 'expert'], palette='colorblind',
                     hue='llm_experience', legend=False, errorbar=None)
    plt.xlabel('Experience Level', fontsize=16)
    plt.ylabel('Proportion Correct', fontsize=16)
    plt.ylim(0, 1)
    # linspace yields exactly the 11 ticks 0.0..1.0; a float-step arange can
    # gain or lose an endpoint through rounding.
    plt.yticks(np.linspace(0, 1, 11))
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    sns.despine()
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'correctness_by_experience.png'))
    plt.close()
def plot_correlation_heatmap(df, output_dir='writing/Simplifying_Outcomes_of_Language_Model_Component_Analyses/figures/results'):
    """Save an annotated heatmap of pairwise correlations between numeric columns.

    Uses ``DataFrame.corr()`` with its default (Pearson) method over every
    ``np.number`` column except ``participant_id``. Boolean columns are not
    ``np.number`` and are therefore not included.

    Args:
        df: Participant table.
        output_dir: Directory for ``correlation_heatmap.png``; created if
            missing.

    Returns:
        None. Prints a message and skips the plot when fewer than two numeric
        columns remain, since no pairwise correlation can be shown.
    """
    os.makedirs(output_dir, exist_ok=True)
    # The participant ID is an identifier, not a measurement, so exclude it.
    quant_cols = [col for col in df.select_dtypes(include=np.number).columns
                  if col != 'participant_id']
    if len(quant_cols) < 2:
        print("Fewer than two numeric columns found. Skipping correlation heatmap.")
        return

    corr = df[quant_cols].corr()
    plt.figure(figsize=(12, 10))
    # Pin the diverging colormap to the full correlation range [-1, 1] with
    # its neutral midpoint at 0. Otherwise the midpoint follows the data range,
    # and e.g. weak positive correlations can be coloured as if negative.
    sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm",
                vmin=-1, vmax=1, center=0, linewidths=.5)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'correlation_heatmap.png'))
    plt.close()
if __name__ == '__main__':
    import sys
    import traceback

    # NOTE(review): the input path is resolved two directories above the
    # current working directory, while every output_dir default is relative to
    # the current working directory itself. Confirm which directory this
    # script is meant to be run from.
    try:
        data = load_and_preprocess_data('../../user_study/data/user_data.csv')
        plot_user_demographics(data)
        # Generate all three versions of the UX ratings plot.
        for language in ('all', 'en', 'de'):
            plot_ux_ratings_by_page(data.copy(), language_filter=language)
        plot_correctness_by_experience(data)
        plot_correlation_heatmap(data)
        print("All plots generated successfully.")
    except Exception as e:
        # Top-level boundary: report the failure with its traceback and exit
        # non-zero so shells/CI notice, rather than exiting 0 after a one-line
        # message that hides where the error came from.
        print(f"An error occurred: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)