import json
import numpy as np
import plotly.graph_objects as go
from pathlib import Path
from argparse import ArgumentParser
import os
def load_model_data(base_path: str, model: str) -> dict:
    """Load the perplexity lists stored for one model.

    Looks in ``<base_path>/<model>/perplexity`` for ``harmful_test.json``,
    ``harmful.json`` and ``normal.json``. Each file may contain either a
    bare JSON list or a JSON object whose ``perplexities`` entry is used.

    Args:
        base_path: Directory that holds one sub-directory per model.
        model: Name of the model sub-directory to read.

    Returns:
        A dict with the keys ``harmful_test``, ``harmful`` and ``normal``.
        A key maps to ``[]`` when its file is missing, when the JSON object
        has no ``perplexities`` entry, or when the payload is neither a
        list nor an object.

    Raises:
        json.JSONDecodeError: If an existing file is not valid JSON.
    """
    model_path = Path(base_path) / model / 'perplexity'
    data = {}
    for filename in ('harmful_test.json', 'harmful.json', 'normal.json'):
        filepath = model_path / filename
        # The result key is simply the file name without its extension.
        key = filepath.stem

        if not filepath.exists():
            print(f"❌ the {filepath} doesn't exists")
            data[key] = []
            continue

        print(f"✅ the {filepath} exists")
        with open(filepath, encoding='utf-8') as f:
            file_data = json.load(f)

        if isinstance(file_data, list):
            data[key] = file_data
        elif isinstance(file_data, dict):
            data[key] = file_data.get('perplexities', [])
        else:
            # A scalar or null payload carries no perplexity list; treat it
            # like a missing file instead of failing on `.get`.
            print(f"❌ the {filepath} has an unexpected JSON layout, using an empty list")
            data[key] = []
    return data
def compute_mean(values: list) -> float:
    """Return the arithmetic mean of *values*, or 0.0 for an empty list."""
    if not values:
        return 0.0
    average = np.mean(values)
    return float(average)
# Visual style of each bar series, in legend order:
# (data key, legend label, fill colour, outline colour, pattern shape, pattern colour)
_BAR_SERIES = (
    # Light purple fill with a dot pattern.
    ('harmful', 'Harmful (Train Data)', '#E1BEE7', '#6A1B9A', '.', '#BA68C8'),
    # Light teal fill with a cross pattern.
    ('harmful_test', 'Harmful Test (Test Data)', '#B2DFDB', '#00695C', 'x', '#4DB6AC'),
    # Light blue fill with a horizontal-line pattern.
    ('normal', 'Normal (Test Data)', '#64B5F6', '#2874A6', '-', '#3498DB'),
)


def _build_bar_trace(models, values, label, fill, outline, shape, pattern_color):
    """Build one grouped-bar trace with the value printed above each bar."""
    return go.Bar(
        x=models,
        y=values,
        name=label,
        marker=dict(
            color=fill,
            line=dict(color=outline, width=1.5),
            pattern=dict(shape=shape, fgcolor=pattern_color, size=8),
        ),
        text=[f'{v:.2f}' for v in values],
        textposition='outside',
        textfont=dict(size=12, color='black'),
    )


def create_comparison_plot(base_path: str, output_path: str = './results', models=None):
    """Create a grouped bar chart of mean perplexity per model and save it.

    For every model the mean of the ``harmful``, ``harmful_test`` and
    ``normal`` perplexity lists is computed (0.0 when a list is empty) and
    drawn as three side-by-side bars. The figure is written to
    ``perplexity_model_comparison.html`` and ``perplexity_model_comparison.pdf``
    inside *output_path*, which is created if needed.

    Args:
        base_path: Directory that holds one sub-directory per model.
        output_path: Directory the HTML and PDF files are written to.
        models: Optional iterable of model directory names to compare.
            Defaults to ``['qwen', 'mistral', 'llama2', 'llama3']``.

    Returns:
        The ``plotly.graph_objects.Figure`` that was saved.
    """
    if models is None:
        models = ['qwen', 'mistral', 'llama2', 'llama3']
    else:
        models = list(models)

    # Load all data and reduce each list to its mean.
    all_data = {}
    for model in models:
        data = load_model_data(base_path, model)
        all_data[model] = {
            key: compute_mean(data.get(key, []))
            for key in ('harmful', 'harmful_test', 'normal')
        }
        print(f"{model}: {len(data.get('harmful_test', []))} harmful_test, {len(data.get('harmful', []))} harmful, {len(data.get('normal', []))} normal")

    # One trace per data type; the trace order fixes bar and legend order.
    fig = go.Figure()
    for data_type, label, fill, outline, shape, pattern_color in _BAR_SERIES:
        values = [all_data[model][data_type] for model in models]
        fig.add_trace(
            _build_bar_trace(models, values, label, fill, outline, shape, pattern_color)
        )

    fig.update_layout(
        title={'text': 'Perplexity Comparison Across Models', 'x': 0.5, 'font': {'size': 26, 'color': 'black'}},
        xaxis_title='Model',
        yaxis_title='Perplexity',
        xaxis_title_font_size=20,
        yaxis_title_font_size=20,
        font={'family': 'Times New Roman', 'size': 16, 'color': 'black'},
        plot_bgcolor='#FFFEF7',
        paper_bgcolor='white',
        barmode='group',
        bargap=0.2,
        bargroupgap=0.05,
        width=600,
        height=400,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5, font=dict(size=12)),
        yaxis=dict(tickfont=dict(size=16), gridcolor='lightgray', gridwidth=0.3, zeroline=True, zerolinecolor='gray', zerolinewidth=0.5)
    )

    # Save
    os.makedirs(output_path, exist_ok=True)
    fig.write_html(f"{output_path}/perplexity_model_comparison.html")
    # NOTE(review): the PDF is exported at 700x500 while the layout (and so
    # the HTML) uses 600x400 - confirm the size difference is intended.
    fig.write_image(f"{output_path}/perplexity_model_comparison.pdf", width=700, height=500, scale=2)
    print(f"✓ Saved to {output_path}/perplexity_model_comparison.html and .pdf")
    return fig
def main():
    """Read the command-line options and render the comparison plot."""
    cli = ArgumentParser(description="Multi-model perplexity comparison")
    # (flag, default, help text) for every supported option.
    option_specs = (
        ("--base_path", "utils/data", "Base path with model directories"),
        ("--output_path", "utils/data", "Output directory"),
    )
    for flag, default, help_text in option_specs:
        cli.add_argument(flag, default=default, help=help_text)
    options = cli.parse_args()
    create_comparison_plot(options.base_path, options.output_path)
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()