# Maheep's picture
# Add files using upload-large-folder tool
# 713f666 verified
import json
import numpy as np
import plotly.graph_objects as go
from pathlib import Path
from argparse import ArgumentParser
import os
def load_model_data(base_path: str, model: str) -> dict:
    """Load the perplexity lists stored for one model.

    Reads ``harmful_test.json``, ``harmful.json`` and ``normal.json`` from
    ``<base_path>/<model>/perplexity``. Each file may hold either a bare JSON
    list or a JSON object whose ``perplexities`` entry holds the list.

    Args:
        base_path: Directory that contains one sub-directory per model.
        model: Name of the model sub-directory to read.

    Returns:
        A dict with the keys ``harmful_test``, ``harmful`` and ``normal``.
        A key maps to an empty list when its file is missing or does not
        contain a list in one of the two accepted layouts.
    """
    model_path = Path(base_path) / model / 'perplexity'
    data = {}
    for filename in ('harmful_test.json', 'harmful.json', 'normal.json'):
        filepath = model_path / filename
        # The result key is simply the file name without its extension.
        key = filepath.stem

        if not filepath.exists():
            print(f"❌ the {filepath} doesn't exists")
            data[key] = []
            continue

        print(f"✅ the {filepath} exists")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(filepath, encoding='utf-8') as f:
            file_data = json.load(f)

        if isinstance(file_data, dict):
            file_data = file_data.get('perplexities', [])
        if isinstance(file_data, list):
            data[key] = file_data
        else:
            # Anything else (null, number, string, nested object) would break
            # callers that take len() of the result, so fall back to empty.
            print(f"⚠️ the {filepath} has no perplexity list")
            data[key] = []
    return data
def compute_mean(values: list) -> float:
    """Return the arithmetic mean of ``values`` as a float.

    An empty (falsy) ``values`` yields ``0.0`` instead of NaN.
    """
    if not values:
        return 0.0
    average = np.mean(values)
    return float(average)
# Per-series bar styling, listed in the order the traces are added.
_BAR_STYLES = (
    {
        # Light purple fill, dot pattern in a darker purple.
        'key': 'harmful',
        'name': 'Harmful (Train Data)',
        'color': '#E1BEE7',
        'line_color': '#6A1B9A',
        'shape': '.',
        'fgcolor': '#BA68C8',
    },
    {
        # Light teal fill, crosshatch pattern in a darker teal.
        'key': 'harmful_test',
        'name': 'Harmful Test (Test Data)',
        'color': '#B2DFDB',
        'line_color': '#00695C',
        'shape': 'x',
        'fgcolor': '#4DB6AC',
    },
    {
        # Light blue fill, horizontal-line pattern in a darker blue.
        'key': 'normal',
        'name': 'Normal (Test Data)',
        'color': '#64B5F6',
        'line_color': '#2874A6',
        'shape': '-',
        'fgcolor': '#3498DB',
    },
)


def create_comparison_plot(base_path: str, output_path: str = './results',
                           models: "list[str] | None" = None):
    """Create a grouped bar chart of mean perplexity per model.

    For every model, three bars are drawn: the mean of the ``harmful``,
    ``harmful_test`` and ``normal`` perplexity lists returned by
    ``load_model_data``. A data set that is missing or empty is plotted as 0
    because ``compute_mean`` returns 0.0 for empty input.

    Args:
        base_path: Directory that contains one sub-directory per model.
        output_path: Directory the HTML and PDF files are written to; it is
            created if it does not exist.
        models: Model directory names to compare, in x-axis order. Defaults
            to ``['qwen', 'mistral', 'llama2', 'llama3']``.

    Returns:
        The plotly figure that was written to disk.
    """
    if models is None:
        models = ['qwen', 'mistral', 'llama2', "llama3"]
    else:
        models = list(models)

    # Load every model's data and reduce each list to its mean.
    all_data = {}
    for model in models:
        data = load_model_data(base_path, model)
        all_data[model] = {
            style['key']: compute_mean(data.get(style['key'], []))
            for style in _BAR_STYLES
        }
        print(f"{model}: {len(data.get('harmful_test', []))} harmful_test, {len(data.get('harmful', []))} harmful, {len(data.get('normal', []))} normal")

    # One bar trace per data set; only the styling differs between them.
    fig = go.Figure()
    for style in _BAR_STYLES:
        values = [all_data[model][style['key']] for model in models]
        fig.add_trace(go.Bar(
            x=models,
            y=values,
            name=style['name'],
            marker=dict(
                color=style['color'],
                line=dict(color=style['line_color'], width=1.5),
                pattern=dict(shape=style['shape'], fgcolor=style['fgcolor'], size=8),
            ),
            text=[f'{v:.2f}' for v in values],
            textposition='outside',
            textfont=dict(size=12, color='black'),
        ))

    fig.update_layout(
        title={'text': 'Perplexity Comparison Across Models', 'x': 0.5, 'font': {'size': 26, 'color': 'black'}},
        xaxis_title='Model',
        yaxis_title='Perplexity',
        xaxis_title_font_size=20,
        yaxis_title_font_size=20,
        font={'family': 'Times New Roman', 'size': 16, 'color': 'black'},
        plot_bgcolor='#FFFEF7',
        paper_bgcolor='white',
        barmode='group',
        bargap=0.2,
        bargroupgap=0.05,
        width=600,
        height=400,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5, font=dict(size=12)),
        yaxis=dict(tickfont=dict(size=16), gridcolor='lightgray', gridwidth=0.3, zeroline=True, zerolinecolor='gray', zerolinewidth=0.5)
    )

    # Save. The PDF export passes its own width/height, which differ from the
    # 600x400 layout size used by the HTML output.
    os.makedirs(output_path, exist_ok=True)
    fig.write_html(f"{output_path}/perplexity_model_comparison.html")
    fig.write_image(f"{output_path}/perplexity_model_comparison.pdf", width=700, height=500, scale=2)
    print(f"✓ Saved to {output_path}/perplexity_model_comparison.html and .pdf")
    return fig
def main():
    """Parse the command-line options and build the comparison plot."""
    cli = ArgumentParser(description="Multi-model perplexity comparison")
    # Both options share the same default directory.
    option_help = (
        ("--base_path", "Base path with model directories"),
        ("--output_path", "Output directory"),
    )
    for flag, help_text in option_help:
        cli.add_argument(flag, default="utils/data", help=help_text)

    options = cli.parse_args()
    create_comparison_plot(options.base_path, options.output_path)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()