File size: 5,483 Bytes
713f666
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import json
import numpy as np
import plotly.graph_objects as go
from pathlib import Path
from argparse import ArgumentParser
import os

def load_model_data(base_path: str, model: str) -> dict:
    """Load the perplexity lists for one model.

    Looks in ``<base_path>/<model>/perplexity`` for ``harmful_test.json``,
    ``harmful.json`` and ``normal.json``. Each file may hold either a bare
    JSON list or a JSON object whose ``perplexities`` entry is the list.

    Args:
        base_path: Directory containing one sub-directory per model.
        model: Name of the model sub-directory.

    Returns:
        Dict with the keys ``harmful_test``, ``harmful`` and ``normal``.
        A key maps to an empty list when its file is missing or when the
        file does not provide a list.
    """
    model_path = Path(base_path) / model / 'perplexity'
    data = {}

    for filename in ['harmful_test.json', 'harmful.json', 'normal.json']:
        filepath = model_path / filename
        # The result key is the file name without its ".json" suffix.
        key = filepath.stem

        if not filepath.exists():
            print(f"❌ the {model_path/filename} doesn't exists")
            data[key] = []
            continue

        print(f"✅ the {model_path/filename} exists")
        with open(filepath, encoding='utf-8') as f:
            file_data = json.load(f)

        # Unwrap the {"perplexities": [...]} layout when present.
        if isinstance(file_data, dict):
            file_data = file_data.get('perplexities', [])
        # Any other JSON payload (null, number, string, ...) is treated as
        # "no data" rather than crashing on a missing ``.get`` attribute.
        data[key] = file_data if isinstance(file_data, list) else []

    return data

def compute_mean(values: list) -> float:
    """Return the arithmetic mean of ``values``, or 0.0 when it is empty."""
    if not values:
        return 0.0
    average = np.mean(values)
    return float(average)

def create_comparison_plot(base_path: str, output_path: str = './results'):
    """Create a grouped 3-bar perplexity comparison plot for all models.

    For each model in a fixed list the perplexity lists are loaded with
    ``load_model_data`` and reduced to their mean with ``compute_mean``.
    One bar trace is drawn per data split (harmful, harmful_test, normal),
    and the figure is written as HTML and PDF into ``output_path``.

    Args:
        base_path: Directory containing one sub-directory per model.
        output_path: Directory the plot files are written to; created if
            it does not exist.

    Returns:
        The ``plotly.graph_objects.Figure`` that was saved.
    """
    models = ['qwen', 'mistral', 'llama2', "llama3"]

    # Load all data and reduce every split to its mean perplexity.
    all_data = {}
    for model in models:
        data = load_model_data(base_path, model)
        all_data[model] = {
            'harmful': compute_mean(data.get('harmful', [])),
            'harmful_test': compute_mean(data.get('harmful_test', [])),
            'normal': compute_mean(data.get('normal', [])),
        }
        print(f"{model}: {len(data.get('harmful_test', []))} harmful_test, {len(data.get('harmful', []))} harmful, {len(data.get('normal', []))} normal")

    # One entry per bar trace, in drawing order:
    # (data key, legend name, fill color, outline color, pattern shape, pattern color)
    trace_styles = [
        # Light purple fill with a dot pattern.
        ('harmful', 'Harmful (Train Data)', '#E1BEE7', '#6A1B9A', ".", '#BA68C8'),
        # Light teal fill with a cross pattern.
        ('harmful_test', 'Harmful Test (Test Data)', '#B2DFDB', '#00695C', "x", '#4DB6AC'),
        # Light blue fill with a horizontal-line pattern.
        ('normal', 'Normal (Test Data)', '#64B5F6', '#2874A6', "-", '#3498DB'),
    ]

    # Create plot
    fig = go.Figure()
    for data_type, legend_name, fill, outline, shape, pattern_color in trace_styles:
        values = [all_data[model][data_type] for model in models]
        fig.add_trace(go.Bar(
            x=models,
            y=values,
            name=legend_name,
            marker=dict(
                color=fill,
                line=dict(color=outline, width=1.5),
                pattern=dict(shape=shape, fgcolor=pattern_color, size=8),
            ),
            # Print the mean above each bar, rounded to two decimals.
            text=[f'{v:.2f}' for v in values],
            textposition='outside',
            textfont=dict(size=12, color='black'),
        ))

    # Layout: grouped bars, horizontal legend centered above the plot area.
    fig.update_layout(
        title={'text': 'Perplexity Comparison Across Models', 'x': 0.5, 'font': {'size': 26, 'color': 'black'}},
        xaxis_title='Model',
        yaxis_title='Perplexity',
        xaxis_title_font_size=20,
        yaxis_title_font_size=20,
        font={'family': 'Times New Roman', 'size': 16, 'color': 'black'},
        plot_bgcolor='#FFFEF7',
        paper_bgcolor='white',
        barmode='group',
        bargap=0.2,
        bargroupgap=0.05,
        width=600,
        height=400,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5, font=dict(size=12)),
        yaxis=dict(tickfont=dict(size=16), gridcolor='lightgray', gridwidth=0.3, zeroline=True, zerolinecolor='gray', zerolinewidth=0.5)
    )

    # Save
    os.makedirs(output_path, exist_ok=True)
    target = f"{output_path}/perplexity_model_comparison"
    fig.write_html(f"{target}.html")
    # The PDF is exported at 700x500, overriding the 600x400 layout size.
    # NOTE(review): static image export needs an export engine (e.g. kaleido)
    # installed alongside plotly - confirm it is available in the environment.
    fig.write_image(f"{target}.pdf", width=700, height=500, scale=2)

    print(f"✓ Saved to {output_path}/perplexity_model_comparison.html and .pdf")
    return fig

def main():
    """Parse the command-line options and build the comparison plot."""
    cli = ArgumentParser(description="Multi-model perplexity comparison")
    # Both options share the same default directory.
    option_help = (
        ("--base_path", "Base path with model directories"),
        ("--output_path", "Output directory"),
    )
    for flag, help_text in option_help:
        cli.add_argument(flag, default="utils/data", help=help_text)
    options = cli.parse_args()

    create_comparison_plot(options.base_path, options.output_path)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()