| | import json |
| | import numpy as np |
| | import plotly.graph_objects as go |
| | from pathlib import Path |
| | from argparse import ArgumentParser |
| | import os |
| |
|
def load_model_data(base_path: str, model: str) -> dict:
    """Load the perplexity lists for one model.

    Looks for ``harmful_test.json``, ``harmful.json`` and ``normal.json``
    inside ``<base_path>/<model>/perplexity``. Each file may contain either a
    bare JSON list or a JSON object with a ``perplexities`` entry.

    Args:
        base_path: Directory that holds one sub-directory per model.
        model: Name of the model sub-directory to read.

    Returns:
        A dict with the keys ``'harmful_test'``, ``'harmful'`` and
        ``'normal'``. A key maps to an empty list when its file is missing,
        when the ``perplexities`` entry is absent or null, or when the JSON
        payload is neither a list nor an object.
    """
    model_path = Path(base_path) / model / 'perplexity'
    data = {}

    for filename in ['harmful_test.json', 'harmful.json', 'normal.json']:
        filepath = model_path / filename
        # The result key is simply the file name without its extension.
        key = filepath.stem

        if not filepath.exists():
            print(f"❌ the {filepath} doesn't exists")
            data[key] = []
            continue

        print(f"✅ the {filepath} exists")
        with open(filepath, encoding='utf-8') as f:
            file_data = json.load(f)

        if isinstance(file_data, list):
            data[key] = file_data
        elif isinstance(file_data, dict):
            # `or []` also covers an explicit null entry, so callers can
            # always take len() of the result.
            data[key] = file_data.get('perplexities') or []
        else:
            # A scalar/string payload has no usable values; treat it like a
            # missing file instead of crashing on `.get`.
            print(f"⚠️ the {filepath} has an unexpected JSON layout, ignoring it")
            data[key] = []

    return data
| |
|
def compute_mean(values: list) -> float:
    """Return the arithmetic mean of ``values`` as a float (0.0 when empty)."""
    # Guard first: an empty input has no mean, so report 0.0.
    if not values:
        return 0.0
    return float(np.mean(values))
| |
|
# Models plotted when the caller does not supply its own list.
_DEFAULT_MODELS = ('qwen', 'mistral', 'llama2', 'llama3')

# One entry per bar series, in drawing order:
# (data key, legend label, fill colour, outline colour, pattern shape, pattern colour)
_SERIES_STYLES = (
    ('harmful', 'Harmful (Train Data)', '#E1BEE7', '#6A1B9A', '.', '#BA68C8'),
    ('harmful_test', 'Harmful Test (Test Data)', '#B2DFDB', '#00695C', 'x', '#4DB6AC'),
    ('normal', 'Normal (Test Data)', '#64B5F6', '#2874A6', '-', '#3498DB'),
)


def create_comparison_plot(base_path: str, output_path: str = './results', models=None):
    """Create a grouped 3-bar perplexity comparison plot across models.

    For every model the mean perplexity of the ``harmful``, ``harmful_test``
    and ``normal`` lists is computed (0.0 when a list is empty) and drawn as
    one bar per series. The figure is written to ``output_path`` as
    ``perplexity_model_comparison.html`` and ``perplexity_model_comparison.pdf``.

    Args:
        base_path: Directory containing one sub-directory per model.
        output_path: Directory for the output files; created if missing.
        models: Optional iterable of model directory names. Defaults to
            ``qwen``, ``mistral``, ``llama2`` and ``llama3``.

    Returns:
        The ``plotly.graph_objects.Figure`` that was written to disk.

    Note:
        The PDF is exported at 700x500 with ``scale=2``, which overrides the
        600x400 layout size used for the HTML output. PDF export goes through
        ``Figure.write_image``, which presumably needs plotly's static-image
        backend installed — confirm for your environment.
    """
    models = list(_DEFAULT_MODELS) if models is None else list(models)

    # Mean perplexity per model and per series.
    all_data = {}
    for model in models:
        data = load_model_data(base_path, model)
        all_data[model] = {
            key: compute_mean(data.get(key, []))
            for key, *_ in _SERIES_STYLES
        }
        print(f"{model}: {len(data.get('harmful_test', []))} harmful_test, {len(data.get('harmful', []))} harmful, {len(data.get('normal', []))} normal")

    fig = go.Figure()

    # All series share the same bar layout; only label, colours and pattern differ.
    for key, label, fill, outline, shape, pattern_color in _SERIES_STYLES:
        values = [all_data[model][key] for model in models]
        fig.add_trace(go.Bar(
            x=models,
            y=values,
            name=label,
            marker=dict(
                color=fill,
                line=dict(color=outline, width=1.5),
                pattern=dict(shape=shape, fgcolor=pattern_color, size=8),
            ),
            text=[f'{v:.2f}' for v in values],
            textposition='outside',
            textfont=dict(size=12, color='black'),
        ))

    fig.update_layout(
        title={'text': 'Perplexity Comparison Across Models', 'x': 0.5, 'font': {'size': 26, 'color': 'black'}},
        xaxis_title='Model',
        yaxis_title='Perplexity',
        xaxis_title_font_size=20,
        yaxis_title_font_size=20,
        font={'family': 'Times New Roman', 'size': 16, 'color': 'black'},
        plot_bgcolor='#FFFEF7',
        paper_bgcolor='white',
        barmode='group',
        bargap=0.2,
        bargroupgap=0.05,
        width=600,
        height=400,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5, font=dict(size=12)),
        yaxis=dict(tickfont=dict(size=16), gridcolor='lightgray', gridwidth=0.3, zeroline=True, zerolinecolor='gray', zerolinewidth=0.5),
    )

    os.makedirs(output_path, exist_ok=True)
    fig.write_html(f"{output_path}/perplexity_model_comparison.html")
    fig.write_image(f"{output_path}/perplexity_model_comparison.pdf", width=700, height=500, scale=2)

    print(f"✓ Saved to {output_path}/perplexity_model_comparison.html and .pdf")
    return fig
| |
|
def main():
    """Parse the command-line options and build the comparison plot."""
    cli = ArgumentParser(description="Multi-model perplexity comparison")

    # Both options are plain string paths sharing the same default directory.
    path_options = (
        ("--base_path", "Base path with model directories"),
        ("--output_path", "Output directory"),
    )
    for flag, description in path_options:
        cli.add_argument(flag, default="utils/data", help=description)

    options = cli.parse_args()
    create_comparison_plot(options.base_path, options.output_path)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()