File size: 3,308 Bytes
80603a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Generate publication-quality PPL convergence figure (white background, PDF-ready)."""
import json, glob, math
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Publication styling
rcParams['font.family'] = 'sans-serif'
rcParams['font.sans-serif'] = ['DejaVu Sans']
rcParams['font.size'] = 12
rcParams['axes.linewidth'] = 1.2

files = sorted(glob.glob('models/*_s[0-9]*_results.json'))
runs = []
for f in files:
    with open(f) as fh:
        d = json.load(fh)
        if len(d.get('log', [])) >= 10:
            runs.append(d)

# ── Figure 1: PPL Convergence ──
fig, ax = plt.subplots(figsize=(8, 5))

color_map = {
    ('pentanet', 42): '#1f77b4', ('pentanet', 1337): '#2ca02c', ('pentanet', 2026): '#9467bd',
    ('bitnet', 42): '#d62728',   ('bitnet', 1337): '#ff7f0e',   ('bitnet', 2026): '#8c564b',
}
style_map = {'pentanet': '-', 'bitnet': '--'}

for run in runs:
    logs = run['log'][1:]  # skip iter 0
    iters = [l['iter'] for l in logs]
    ppls  = [l['ppl'] for l in logs]
    mode  = run['mode']
    seed  = run['seed']
    color = color_map.get((mode, seed), '#333333')
    ax.plot(iters, ppls,
            linestyle=style_map[mode],
            color=color,
            linewidth=2.0,
            label=f'{mode.upper()} (seed {seed})')

ax.set_yscale('log')
ax.set_xlabel('Training Iterations', fontsize=13)
ax.set_ylabel('Validation Perplexity', fontsize=13)
ax.set_title('PentaNet vs BitNet β€” Perplexity Convergence on WikiText-103', fontsize=14, fontweight='bold')
ax.legend(fontsize=9, ncol=2, loc='upper right', framealpha=0.9)
ax.grid(True, which='both', ls='--', alpha=0.3)
ax.tick_params(labelsize=11)

fig.tight_layout()
fig.savefig('figure1_ppl_convergence.png', dpi=300, facecolor='white')
fig.savefig('figure1_ppl_convergence.pdf', facecolor='white')
print("βœ… figure1_ppl_convergence.png (300 dpi, white bg)")
print("βœ… figure1_ppl_convergence.pdf")

# ── Figure 2: Weight Distribution Evolution ──
fig2, ax2 = plt.subplots(figsize=(8, 5))

penta42 = [r for r in runs if r['mode'] == 'pentanet' and r['seed'] == 42]
if penta42:
    logs = penta42[0]['log']
    iters = [l['iter'] for l in logs]
    buckets = ['-2', '-1', '0', '1', '2']
    colors = ['#d62728', '#1f77b4', '#2ca02c', '#9467bd', '#ff7f0e']
    
    for b, c in zip(buckets, colors):
        pcts = []
        for l in logs:
            w = l['weights']
            total = sum(w.values())
            pcts.append(w[b] / total * 100)
        ax2.plot(iters, pcts, color=c, linewidth=2.5, label=f'Bucket [{b}]')

    ax2.set_xlabel('Training Iterations', fontsize=13)
    ax2.set_ylabel('Weight Distribution (%)', fontsize=13)
    ax2.set_title('PentaNet Weight Bucket Stability (Seed 42)', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=10, ncol=5, loc='upper center', bbox_to_anchor=(0.5, -0.12), framealpha=0.9)
    ax2.grid(True, ls='--', alpha=0.3)
    ax2.tick_params(labelsize=11)
    ax2.set_ylim(0, 40)

    fig2.tight_layout()
    fig2.savefig('figure2_weight_distribution.png', dpi=300, facecolor='white')
    fig2.savefig('figure2_weight_distribution.pdf', facecolor='white')
    print("βœ… figure2_weight_distribution.png (300 dpi, white bg)")
    print("βœ… figure2_weight_distribution.pdf")