# Source: Marketing-Memory-Routing-8B / assets / generate_charts.py
# Commit 685d968 by MuratcanKoylan: "Upload folder using huggingface_hub"
"""
Generate training visualization charts for README
"""
import json
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import numpy as np
import os
# Output directory that receives every generated PNG.
os.makedirs('assets', exist_ok=True)

# Training run whose metric logs are plotted below.
RUN_DIR = os.path.join('training', 'logs', 'run_20251124_200256')


def load_jsonl(path):
    """Return the records of a JSON Lines file as a list of parsed objects.

    The file is decoded as UTF-8 regardless of the platform's default
    encoding. Blank lines (for example a trailing newline) are skipped rather
    than passed to ``json.loads``, which would raise on them.
    """
    with open(path, 'r', encoding='utf-8') as fh:
        return [json.loads(line) for line in fh if line.strip()]


# SFT metrics: the charts below read 'step' and 'train_loss' from every
# record, and 'test_loss' from the records that have it.
sft_metrics = load_jsonl(os.path.join(RUN_DIR, 'sft_metrics.jsonl'))
# RL metrics: the charts below read 'iteration', 'mean_reward', 'std_reward'
# and the 'mean_r_*' reward components from every record.
rl_metrics = load_jsonl(os.path.join(RUN_DIR, 'rl_metrics.jsonl'))
# Style settings shared by every chart.
# Matplotlib 3.6 renamed the bundled seaborn styles with a "seaborn-v0_8-"
# prefix, and 3.8 removed the old names. Use whichever spelling this
# installation provides, so the script runs on both older and newer versions.
_GRID_STYLE = next(
    (name for name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')
     if name in plt.style.available),
    'default',
)
plt.style.use(_GRID_STYLE)

# Palette keyed by series role:
#   'train' / 'test'       -> Chart 1 loss curves
#   'reward'               -> Chart 2 reward curve and band
#   'our_model' / 'cohere' -> Charts 4-5 bar groups
# 'f1' is not referenced by any chart in this script.
colors = {
    'train': '#2563eb',
    'test': '#dc2626',
    'reward': '#059669',
    'f1': '#7c3aed',
    'our_model': '#2563eb',
    'cohere': '#dc2626'
}
# ============ Chart 1: SFT Loss Curve ============
# Every SFT record carries 'train_loss'; only some carry 'test_loss'. The test
# series is therefore sparse and is drawn as markers joined by a dashed line.
fig, ax = plt.subplots(figsize=(10, 5))
steps = [m['step'] for m in sft_metrics]
train_loss = [m['train_loss'] for m in sft_metrics]
# (step, test_loss) pairs for the records that include an evaluation result.
test_points = [(m['step'], m['test_loss']) for m in sft_metrics if 'test_loss' in m]
test_steps = [step for step, _ in test_points]
test_loss = [loss for _, loss in test_points]
ax.plot(steps, train_loss, color=colors['train'], linewidth=2, label='Train Loss', alpha=0.8)
if test_points:
    ax.scatter(test_steps, test_loss, color=colors['test'], s=80, zorder=5, label='Test Loss', marker='o')
    ax.plot(test_steps, test_loss, color=colors['test'], linewidth=2, linestyle='--', alpha=0.5)
ax.set_xlabel('Training Step', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('SFT Training: Loss Convergence', fontsize=14, fontweight='bold')
ax.legend(loc='upper right', fontsize=10)
ax.set_ylim(0, 6)
# Annotate the last training loss and the best (lowest) test loss. Both are
# guarded so that an empty log, or a run without evaluations, still produces
# a chart instead of raising IndexError / ValueError.
if train_loss:
    ax.annotate(f'Final: {train_loss[-1]:.3f}', xy=(steps[-1], train_loss[-1]),
                xytext=(steps[-1] - 15, train_loss[-1] + 0.5),
                fontsize=9, color=colors['train'])
if test_points:
    # min() keeps the first of several equal minima, like list.index(min(...)).
    best_step, best_loss = min(test_points, key=lambda point: point[1])
    ax.annotate(f'Best Test: {best_loss:.3f}', xy=(best_step, best_loss),
                xytext=(best_step + 5, best_loss + 0.3),
                fontsize=9, color=colors['test'])
plt.tight_layout()
plt.savefig('assets/sft_loss.png', dpi=150, bbox_inches='tight')
plt.close()
print("Saved: assets/sft_loss.png")
# ============ Chart 2: RL Reward Progression ============
fig, ax = plt.subplots(figsize=(10, 5))
iterations = [m['iteration'] for m in rl_metrics]
mean_reward = [m['mean_reward'] for m in rl_metrics]
std_reward = [m['std_reward'] for m in rl_metrics]
# Shaded band spans mean +/- one standard deviation at each iteration.
ax.fill_between(iterations,
                [r - s for r, s in zip(mean_reward, std_reward)],
                [r + s for r, s in zip(mean_reward, std_reward)],
                alpha=0.2, color=colors['reward'])
ax.plot(iterations, mean_reward, color=colors['reward'], linewidth=2.5, label='Mean Reward')
ax.set_xlabel('RL Iteration', fontsize=12)
ax.set_ylabel('Reward', fontsize=12)
ax.set_title('RL Training: Reward Progression', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.set_ylim(0.5, 1.0)
# Annotate the first and the peak mean reward. The x coordinates must be the
# logged iteration numbers rather than list positions. The two only coincide
# when iterations start at 0 and increase by 1.
if mean_reward:
    start_iter, start_reward = iterations[0], mean_reward[0]
    peak_reward = max(mean_reward)
    peak_iter = iterations[mean_reward.index(peak_reward)]  # first occurrence of the peak
    ax.annotate(f'Start: {start_reward:.3f}', xy=(start_iter, start_reward),
                xytext=(start_iter + 2, start_reward - 0.05),
                fontsize=9, color=colors['reward'])
    ax.annotate(f'Peak: {peak_reward:.3f}', xy=(peak_iter, peak_reward),
                xytext=(peak_iter + 2, peak_reward + 0.02),
                fontsize=9, color=colors['reward'])
plt.tight_layout()
plt.savefig('assets/rl_reward.png', dpi=150, bbox_inches='tight')
plt.close()
print("Saved: assets/rl_reward.png")
# ============ Chart 3: Reward Components ============
# One line per reward term, plotted against the RL iterations from Chart 2.
fig, ax = plt.subplots(figsize=(10, 5))
component_keys = ('mean_r_f1', 'mean_r_temp', 'mean_r_parity', 'mean_r_eff')
r_f1, r_temp, r_parity, r_eff = (
    [record[key] for record in rl_metrics] for key in component_keys
)
# (legend label, line colour) per component, in the same order as above.
# The percentages look like each term's weight in the combined reward; they
# are only label text here (confirm against the reward definition).
component_styles = (
    ('R_F1 (60%)', '#2563eb'),
    ('R_temp (20%)', '#7c3aed'),
    ('R_parity (10%)', '#059669'),
    ('R_eff (10%)', '#f59e0b'),
)
for series, (label_text, line_color) in zip((r_f1, r_temp, r_parity, r_eff), component_styles):
    ax.plot(iterations, series, label=label_text, linewidth=2, color=line_color)
ax.set_xlabel('RL Iteration', fontsize=12)
ax.set_ylabel('Reward Component', fontsize=12)
ax.set_title('RL Training: Reward Components', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.set_ylim(0.5, 1.05)
plt.tight_layout()
components_png = 'assets/rl_components.png'
plt.savefig(components_png, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved: {components_png}")
# ============ Chart 4: Model Comparison ============
# Grouped bars comparing hard-coded headline scores (not read from the logs).
fig, ax = plt.subplots(figsize=(8, 5))
metrics = ['Avg F1', 'Exact Match', 'Any Match']
our_model = [0.68, 0.60, 0.72]
cohere = [0.61, 0.26, 0.82]
x = np.arange(len(metrics))
width = 0.35
half = width / 2
bars1 = ax.bar(x - half, our_model, width, label='Ours (8B)', color=colors['our_model'])
bars2 = ax.bar(x + half, cohere, width, label='Cohere (104B)', color=colors['cohere'])
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Model Comparison: 50 Marketing Scenarios', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(metrics, fontsize=11)
ax.legend(loc='upper right', fontsize=10)
ax.set_ylim(0, 1.0)
# Percentage label 3 pt above each bar; only our model's labels are bold.
for group, extra_text_kwargs in ((bars1, {'fontweight': 'bold'}), (bars2, {})):
    for bar in group:
        bar_top = bar.get_height()
        bar_center = bar.get_x() + bar.get_width() / 2
        ax.annotate(f'{bar_top:.0%}',
                    xy=(bar_center, bar_top),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=10, **extra_text_kwargs)
plt.tight_layout()
comparison_png = 'assets/model_comparison.png'
plt.savefig(comparison_png, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved: {comparison_png}")
# ============ Chart 5: Performance by Difficulty ============
# Grouped bars of hard-coded F1 scores per difficulty bucket.
fig, ax = plt.subplots(figsize=(8, 5))
difficulties = ['Easy', 'Medium', 'Hard']
our_f1 = [0.86, 0.65, 0.50]
cohere_f1 = [0.48, 0.64, 0.72]
x = np.arange(len(difficulties))
width = 0.35
offset = width / 2
bars1 = ax.bar(x - offset, our_f1, width, label='Ours (8B)', color=colors['our_model'])
bars2 = ax.bar(x + offset, cohere_f1, width, label='Cohere (104B)', color=colors['cohere'])
ax.set_ylabel('F1 Score', fontsize=12)
ax.set_title('F1 Score by Difficulty Level', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(difficulties, fontsize=11)
ax.legend(loc='upper right', fontsize=10)
ax.set_ylim(0, 1.0)
# Two-decimal value 3 pt above each bar; only our model's labels are bold.
for group, extra_text_kwargs in ((bars1, {'fontweight': 'bold'}), (bars2, {})):
    for bar in group:
        bar_top = bar.get_height()
        bar_center = bar.get_x() + bar.get_width() / 2
        ax.annotate(f'{bar_top:.2f}',
                    xy=(bar_center, bar_top),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=10, **extra_text_kwargs)
plt.tight_layout()
difficulty_png = 'assets/difficulty_comparison.png'
plt.savefig(difficulty_png, dpi=150, bbox_inches='tight')
plt.close()
print(f"Saved: {difficulty_png}")
print("\nAll charts generated successfully!")