"""Plot training-loss curves and inference accuracy for the arithmetic reasoner."""
import re
import sys

import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, so the
# script can render PNGs without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)

# Matches lines such as "Step 21000 | ... | Loss: 0.0331" or
# "Step 21500 Loss: 3.2e-02".  The non-greedy gap picks the FIRST "Loss:" on
# the line (so a later "Val Loss:" is not mistaken for the training loss), and
# the number pattern accepts scientific notation instead of truncating it.
_STEP_LOSS_RE = re.compile(
    r'Step (\d+)\b.*?\bLoss:\s*'
    r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
)


def load_loss_data(log_path='train_output.log'):
    """Load (step, loss) pairs from a training log.

    Every line matching ``_STEP_LOSS_RE`` contributes one point; all other
    lines are ignored.

    Args:
        log_path: Path to the training log text file.

    Returns:
        ``(steps, losses)`` as two parallel lists, or ``(None, None)`` when the
        log does not exist or contains no matching lines.
    """
    try:
        log_file = open(log_path, encoding='utf-8', errors='replace')
    except FileNotFoundError:
        # No resumed-run log yet: callers plot only the EARLY_LOSS curve.
        return None, None

    steps_after, losses_after = [], []
    with log_file:
        for line in log_file:
            match = _STEP_LOSS_RE.search(line)
            if match:
                steps_after.append(int(match.group(1)))
                losses_after.append(float(match.group(2)))
    if not steps_after:
        return None, None
    return steps_after, losses_after


# (step, loss) checkpoints for steps 0-20500, hard-coded rather than parsed.
# NOTE(review): presumably recorded before the run was resumed; the log at
# log_path is expected to hold the later steps — confirm.
EARLY_LOSS = [
    (0, 2.9836), (500, 1.2863), (1000, 0.8944), (1500, 0.6346),
    (2000, 0.4688), (2500, 0.3735), (3000, 0.2973), (3500, 0.2215),
    (4000, 0.1777), (4500, 0.1588), (5000, 0.1440), (5500, 0.1289),
    (6000, 0.1050), (6500, 0.1028), (7000, 0.1009), (7500, 0.0914),
    (8000, 0.0778), (8500, 0.0769), (9000, 0.0704), (9500, 0.0686),
    (10000, 0.0640), (10500, 0.0696), (11000, 0.0676), (11500, 0.0663),
    (12000, 0.0492), (12500, 0.0590), (13000, 0.0515), (13500, 0.0495),
    (14000, 0.0507), (14500, 0.0522), (15000, 0.0402), (15500, 0.0414),
    (16000, 0.0484), (16500, 0.0444), (17000, 0.0380), (17500, 0.0399),
    (18000, 0.0384), (18500, 0.0359), (19000, 0.0379), (19500, 0.0362),
    (20000, 0.0339), (20500, 0.0338),
]

# (bar label, accuracy in percent, bar colour) for the accuracy panel.
ACCURACY_RESULTS = [
    ('Random 500\nExpressions', 91.6, '#2563eb'),
    ('Fixed\nBenchmark', 100.0, '#16a34a'),
]


def _plot_loss(ax, steps_early, losses_early, steps_after, losses_after):
    """Draw the early and resumed loss curves on *ax* with a log-scale y axis.

    ``steps_after``/``losses_after`` may be ``None`` (no resumed log), in
    which case only the early curve is drawn and no resume marker is added.
    """
    if steps_early:
        ax.plot(steps_early, losses_early, color='#2563eb', linewidth=1.5,
                alpha=0.7,
                label=f'Steps {steps_early[0]}–{steps_early[-1]}')
    if steps_after:
        ax.plot(steps_after, losses_after, color='#dc2626', linewidth=1.5,
                alpha=0.7,
                label=f'Steps {steps_after[0]}–{steps_after[-1]}')

    all_steps = (steps_early or []) + (steps_after or [])
    all_losses = (losses_early or []) + (losses_after or [])
    if all_steps:
        ax.fill_between(all_steps, all_losses, alpha=0.06, color='#2563eb')
    if steps_after:
        # Mark where the resumed run starts: its first logged step, rather
        # than a hard-coded value that may not match the log.
        resume_step = steps_after[0]
        ax.axvline(x=resume_step, color='#888', linewidth=0.8,
                   linestyle=':', alpha=0.6)
        ax.text(resume_step, max(all_losses) * 0.9, 'resume', fontsize=8,
                color='#888', ha='center', va='top', style='italic')

    ax.set_xlabel('Training Step', fontsize=10)
    ax.set_ylabel('Cross-Entropy Loss', fontsize=10)
    ax.set_title('Training Loss', fontsize=12, fontweight='bold', pad=10)
    ax.set_yscale('log')
    ax.grid(True, alpha=0.25, linestyle='--')
    ax.legend(fontsize=8, loc='upper right')


def _plot_accuracy(ax):
    """Draw the ACCURACY_RESULTS bar chart, each bar annotated with its value."""
    labels = [label for label, _, _ in ACCURACY_RESULTS]
    values = [value for _, value, _ in ACCURACY_RESULTS]
    colors = [color for _, _, color in ACCURACY_RESULTS]
    bars = ax.bar(labels, values, color=colors, width=0.5,
                  edgecolor='white', linewidth=1.5)
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1.2,
                f'{value:.1f}%', ha='center', va='bottom', fontsize=11,
                fontweight='bold')
    ax.set_ylim(0, 108)
    ax.set_ylabel('Accuracy', fontsize=10)
    ax.set_title('Inference Accuracy', fontsize=12, fontweight='bold', pad=10)
    ax.grid(True, alpha=0.25, linestyle='--', axis='y')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def plot(output_path='training_curves.png', log_path='train_output.log'):
    """Render the loss and accuracy panels and save them as an image.

    The left panel shows EARLY_LOSS plus any losses parsed from *log_path*.
    The right panel shows the hard-coded ACCURACY_RESULTS.

    Args:
        output_path: Destination image path (written at 150 dpi).
        log_path: Training log for the resumed run; if it is missing or has no
            matching lines, only the early curve is drawn.
    """
    steps_after, losses_after = load_loss_data(log_path)
    steps_early = [step for step, _ in EARLY_LOSS]
    losses_early = [loss for _, loss in EARLY_LOSS]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))
    for ax in (ax1, ax2):
        ax.set_facecolor('#f8f9fa')

    _plot_loss(ax1, steps_early, losses_early, steps_after, losses_after)
    _plot_accuracy(ax2)

    fig.suptitle('Arithmetic Reasoner — TrueACT 1-Layer', fontsize=14,
                 fontweight='bold', y=1.02)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f'Saved {output_path}')


if __name__ == '__main__':
    # Usage: script.py [output_path] [log_path]
    out = sys.argv[1] if len(sys.argv) > 1 else 'training_curves.png'
    log = sys.argv[2] if len(sys.argv) > 2 else 'train_output.log'
    plot(out, log)