pentanet-124m / scripts /export_figures.py
Kyworn's picture
Upload folder using huggingface_hub
80603a0 verified
"""Generate publication-quality figures (white background, 300-dpi PNG + PDF).

Reads every ``models/*_s[0-9]*_results.json`` file whose ``log`` list has at
least 10 entries, then writes to the current directory:

* ``figure1_ppl_convergence.{png,pdf}`` -- validation perplexity vs. training
  iteration (log-scale y-axis) for each PentaNet / BitNet run.
* ``figure2_weight_distribution.{png,pdf}`` -- per-bucket weight share (%)
  over training for the PentaNet seed-42 run.
"""
import json, glob, math
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import rcParams
# Publication styling: a single sans-serif face (DejaVu Sans), 12 pt base
# text, and a slightly heavier axes frame, applied in one batch.
rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['DejaVu Sans'],
    'font.size': 12,
    'axes.linewidth': 1.2,
})
def load_runs(paths, min_log_entries=10):
    """Load training-result JSON files, keeping only runs with enough log entries.

    Args:
        paths: Iterable of paths to ``*_results.json`` files.
        min_log_entries: Minimum length a run's ``'log'`` list must have for
            the run to be kept (default 10, the original hard-coded cutoff).

    Returns:
        List of the parsed result dicts, in the same order as ``paths``.
    """
    loaded = []
    for path in paths:
        # JSON text is UTF-8; don't let the platform's locale encoding decide.
        with open(path, encoding='utf-8') as fh:
            result = json.load(fh)
        # ``or []`` treats both a missing key and an explicit ``"log": null``
        # as an empty log instead of crashing on len(None).
        if len(result.get('log') or []) >= min_log_entries:
            loaded.append(result)
    return loaded


files = sorted(glob.glob('models/*_s[0-9]*_results.json'))
runs = load_runs(files)
# ── Figure 1: PPL Convergence ──
# One line per loaded run: colour distinguishes (mode, seed), line style
# distinguishes the mode, and the y-axis is logarithmic.
fig, ax = plt.subplots(figsize=(8, 5))
color_map = {
    ('pentanet', 42): '#1f77b4', ('pentanet', 1337): '#2ca02c', ('pentanet', 2026): '#9467bd',
    ('bitnet', 42): '#d62728', ('bitnet', 1337): '#ff7f0e', ('bitnet', 2026): '#8c564b',
}
style_map = {'pentanet': '-', 'bitnet': '--'}
for run in runs:
    logs = run['log'][1:]  # skip iter 0 (the first log entry)
    iters = [entry['iter'] for entry in logs]
    ppls = [entry['ppl'] for entry in logs]
    mode = run['mode']
    seed = run['seed']
    # Unlisted (mode, seed) pairs fall back to dark grey and a solid line
    # rather than raising, so an extra results file never aborts the export.
    color = color_map.get((mode, seed), '#333333')
    ax.plot(iters, ppls,
            linestyle=style_map.get(mode, '-'),
            color=color,
            linewidth=2.0,
            label=f'{mode.upper()} (seed {seed})')
ax.set_yscale('log')
ax.set_xlabel('Training Iterations', fontsize=13)
ax.set_ylabel('Validation Perplexity', fontsize=13)
ax.set_title('PentaNet vs BitNet — Perplexity Convergence on WikiText-103', fontsize=14, fontweight='bold')
if runs:  # legend() on an axes with no labelled lines only emits a warning
    ax.legend(fontsize=9, ncol=2, loc='upper right', framealpha=0.9)
ax.grid(True, which='both', ls='--', alpha=0.3)
ax.tick_params(labelsize=11)
fig.tight_layout()
fig.savefig('figure1_ppl_convergence.png', dpi=300, facecolor='white')
fig.savefig('figure1_ppl_convergence.pdf', facecolor='white')
plt.close(fig)  # release the figure once both files are written
print("✅ figure1_ppl_convergence.png (300 dpi, white bg)")
print("✅ figure1_ppl_convergence.pdf")
# ── Figure 2: Weight Distribution Evolution ──
# Share of weights in each bucket (-2..2) over training for the first
# PentaNet seed-42 run, as a percentage of that log entry's total.
fig2, ax2 = plt.subplots(figsize=(8, 5))
penta42 = [r for r in runs if r['mode'] == 'pentanet' and r['seed'] == 42]
peak_pct = 0.0  # largest plotted percentage; used to size the y-axis below
if penta42:
    logs = penta42[0]['log']
    iters = [entry['iter'] for entry in logs]
    buckets = ['-2', '-1', '0', '1', '2']  # JSON object keys are always strings
    colors = ['#d62728', '#1f77b4', '#2ca02c', '#9467bd', '#ff7f0e']
    for b, c in zip(buckets, colors):
        pcts = []
        for entry in logs:
            w = entry['weights']
            total = sum(w.values())
            # A bucket absent from this entry counts as 0, and an all-zero
            # histogram plots as 0 % instead of raising ZeroDivisionError.
            pcts.append(w.get(b, 0) / total * 100 if total else 0.0)
        peak_pct = max(peak_pct, max(pcts, default=0.0))
        ax2.plot(iters, pcts, color=c, linewidth=2.5, label=f'Bucket [{b}]')
    ax2.legend(fontsize=10, ncol=5, loc='upper center', bbox_to_anchor=(0.5, -0.12), framealpha=0.9)
else:
    print("WARNING: no PentaNet seed-42 run found; figure2 will contain no data")
ax2.set_xlabel('Training Iterations', fontsize=13)
ax2.set_ylabel('Weight Distribution (%)', fontsize=13)
ax2.set_title('PentaNet Weight Bucket Stability (Seed 42)', fontsize=14, fontweight='bold')
ax2.grid(True, ls='--', alpha=0.3)
ax2.tick_params(labelsize=11)
# Keep the usual 0-40 % window, but widen it rather than clip a bucket above 40 %.
ax2.set_ylim(0, 40 if peak_pct <= 40 else peak_pct * 1.05)
fig2.tight_layout()
fig2.savefig('figure2_weight_distribution.png', dpi=300, facecolor='white')
fig2.savefig('figure2_weight_distribution.pdf', facecolor='white')
plt.close(fig2)  # release the figure once both files are written
print("✅ figure2_weight_distribution.png (300 dpi, white bg)")
print("✅ figure2_weight_distribution.pdf")