# Source: Hugging Face repository file view — uploaded by Pathikreet,
# commit ea9c69b ("sync training files", verified), 65.8 kB.
"""
AP Commander — GRPO Training Script
Tracks: overall reward, per-component rewards, decision distribution,
format compliance, env errors, sample generations, reward curve.
"""
import os, json, re, random, time, datetime, collections
# Reduce CUDA memory fragmentation — must be set before torch is imported
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
class _StopTraining(Exception):
    """Signal that a graceful stop of training was requested.

    Raised from inside the reward function once the ``/app/stop_requested``
    flag file has been found.
    """
import requests
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
# Base URL of the AP Commander environment's HTTP API (the /reset, /step and
# /oversight/* endpoints used below are all relative to it).
ENV_URL = 'https://pathikreet-ap-clerk-env.hf.space'
# The next three settings may be overridden through environment variables of the same name.
MODEL_NAME = os.environ.get('MODEL_NAME', 'Qwen/Qwen2.5-7B-Instruct')
NUM_EPOCHS = int(os.environ.get('NUM_EPOCHS', '6'))
NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', '16')) # 32 OOMs on 7B/A10G; 16 fits (~16 GB est.)
LOG_SAMPLES_EVERY = 20 # print a sample generation every N reward calls
# System prompt sent with every AP-clerk task; it pins the output to a single JSON object.
# NOTE(review): VALID_DECISIONS also accepts 'HOLD', which this prompt does not list — confirm intended.
SYSTEM_PROMPT = """You are an AI Accounts Payable Clerk. Review the invoice, PO, and GRN, then output ONLY a single valid JSON object. No prose, no markdown, no explanation outside the JSON.
Valid decisions: APPROVE_FULL | APPROVE_PARTIAL | REJECT | ESCALATE | QUERY_VENDOR
Valid reason codes: MATCH_CONFIRMED | QUANTITY_MISMATCH | PRICE_DISCREPANCY | POLICY_VIOLATION | NO_PO_FOUND | DUPLICATE_INVOICE | VENDOR_MISMATCH | TAX_DISCREPANCY | PENDING_CLARIFICATION | MANAGER_REVIEW
Example (do not copy — generate your own based on the actual invoice):
{"decision": "REJECT", "approved_amount": 0.0, "reason_code": "NO_PO_FOUND", "explanation": "Invoice INV-2024-5821 rejected: no open PO found for vendor TechCorp. Policy Rule 5 mandates a valid OPEN PO before payment."}
Your response must start with { and end with } with no other text."""
# Every task id starts with its difficulty tier ("easy_", "medium_", "hard_", "long_").
TRAIN_TASKS = [
    # easy
    'easy_perfect_match',
    'easy_no_po_found',
    # medium
    'medium_quantity_shortfall',
    'medium_price_discrepancy',
    'medium_split_delivery',
    'medium_vendor_mismatch',
    # hard
    'hard_policy_violation',
    'hard_duplicate_invoice',
    'hard_partial_po_match',
    'hard_tax_discrepancy',
    'hard_currency_conversion',
    'hard_manager_preapproval',
    'hard_credit_memo',
    # long-horizon
    'long_invoice_dispute',
    'long_policy_migration',
    'long_batch_reconciliation',
    'long_manager_chain',
    'long_fraud_investigation',
    'long_audit_trail',
    'long_multi_vendor_split',
]
# Evaluation runs over the very same list object as training.
EVAL_TASKS = TRAIN_TASKS

# Enum values accepted when validating a parsed model action.
VALID_DECISIONS = {
    'APPROVE_FULL',
    'APPROVE_PARTIAL',
    'REJECT',
    'ESCALATE',
    'QUERY_VENDOR',
    'HOLD',
}
VALID_REASON_CODES = {
    'MATCH_CONFIRMED',
    'QUANTITY_MISMATCH',
    'PRICE_DISCREPANCY',
    'POLICY_VIOLATION',
    'NO_PO_FOUND',
    'DUPLICATE_INVOICE',
    'VENDOR_MISMATCH',
    'TAX_DISCREPANCY',
    'PENDING_CLARIFICATION',
    'MANAGER_REVIEW',
}

# task_id -> difficulty tier, read off the prefix in front of the first underscore.
_TASK_DIFFICULTY = {task: task.split('_', 1)[0] for task in TRAIN_TASKS}
# Tiers from easiest to hardest; a tier's mean reward must reach its threshold
# before the next tier is unlocked ('long' is last, so it has no threshold).
_DIFFICULTY_ORDER = ['easy', 'medium', 'hard', 'long']
_UNLOCK_THRESHOLDS = {'easy': 0.70, 'medium': 0.65, 'hard': 0.60}
# ── Curriculum sampler ──────────────────────────────────────────────────────────
class CurriculumSampler:
    """Track rewards per task and unlock harder difficulty tiers as they are mastered.

    Tiers follow ``_DIFFICULTY_ORDER``. A tier's successor is unlocked once the
    mean of *all* rewards ever recorded for that tier (not a sliding window)
    reaches its entry in ``_UNLOCK_THRESHOLDS``. Only ``'easy'`` starts unlocked.
    """

    def __init__(self):
        self._rewards: dict = collections.defaultdict(list)  # task_id → [rewards]
        self.unlocked: set = {'easy'}
        # Running sum / count per difficulty so that a mean costs O(1) instead
        # of re-scanning the full reward history on every record() call.
        self._diff_sum: dict = collections.defaultdict(float)
        self._diff_count: dict = collections.defaultdict(int)

    def record(self, task_id: str, reward: float):
        """Store one episode reward and unlock the next tier if a threshold is now met."""
        self._rewards[task_id].append(reward)
        diff = _TASK_DIFFICULTY.get(task_id)
        # Rewards of unknown task ids are kept in _rewards but belong to no tier.
        if diff is not None:
            self._diff_sum[diff] += reward
            self._diff_count[diff] += 1
        self._try_unlock()

    def mean_for_difficulty(self, diff: str) -> float:
        """Return the all-time mean reward of tier ``diff`` (0.0 when nothing was recorded)."""
        count = self._diff_count.get(diff, 0)
        return self._diff_sum[diff] / count if count else 0.0

    def _try_unlock(self):
        """Unlock the successor of every unlocked tier whose mean reached its threshold."""
        for i, diff in enumerate(_DIFFICULTY_ORDER[:-1]):
            if diff not in self.unlocked:
                continue
            nxt = _DIFFICULTY_ORDER[i + 1]
            if nxt in self.unlocked:
                continue
            # One threshold value for both the comparison and the log line.
            threshold = _UNLOCK_THRESHOLDS.get(diff, 0.70)
            m = self.mean_for_difficulty(diff)
            if m >= threshold:
                self.unlocked.add(nxt)
                print(f'\n[CURRICULUM] Unlocked {nxt}! mean({diff})={m:.3f} '
                      f'>= threshold {threshold}')

    def gate_task(self, task_id: str) -> str:
        """Return ``task_id`` if its tier is unlocked, otherwise a random easy task.

        Task ids missing from ``_TASK_DIFFICULTY`` are treated as easy.
        """
        if _TASK_DIFFICULTY.get(task_id, 'easy') in self.unlocked:
            return task_id
        easiest = [t for t, d in _TASK_DIFFICULTY.items() if d == 'easy']
        return random.choice(easiest)

    def build_dataset_tasks(self) -> list:
        """Return (task_id, seed) pairs for all currently unlocked difficulties."""
        seeds_per_diff = {'easy': 10, 'medium': 5, 'hard': 2, 'long': 2}
        rows = []
        for task_id, diff in _TASK_DIFFICULTY.items():
            if diff in self.unlocked:
                n = seeds_per_diff[diff]
                rows.extend((task_id, s) for s in range(1, n + 1))
        return rows

    def status_line(self) -> str:
        """Return a one-line summary: per-tier mean plus ✓/✗ for unlocked/locked."""
        parts = []
        for d in _DIFFICULTY_ORDER:
            m = self.mean_for_difficulty(d)
            unlk = '✓' if d in self.unlocked else '✗'
            parts.append(f'{d}={m:.2f}{unlk}')
        return ' | '.join(parts)
# Module-level sampler instance; the reward function records every episode score into it.
CURRICULUM = CurriculumSampler()
# ── Per-step greedy follow-up policy ───────────────────────────────────────────
def _greedy_followup(obs_dict: dict) -> dict:
    """Choose the scripted closing action from the context notes revealed so far.

    The notes are joined and lower-cased, then tested against keyword groups
    in a fixed priority order; the first group that matches decides the
    action. If no group matches, the invoice is rejected.
    """
    text = ' '.join(obs_dict.get('context_notes', [])).lower()
    raw_total = obs_dict.get('invoice', {}).get('invoice_total', 0)
    amount = abs(float(raw_total or 0))

    def mentions(*keywords: str) -> bool:
        # Plain substring search over the lower-cased notes.
        return any(kw in text for kw in keywords)

    def build(decision: str, approved: float, reason: str, why: str) -> dict:
        return {'decision': decision, 'approved_amount': approved,
                'reason_code': reason, 'explanation': why}

    # 1. Sign-off from the approval chain -> pay in full.
    if mentions('manager approved', 'vp approved', 'cfo approved',
                'pre-approved', 'pre-approv', 'approved by'):
        return build('APPROVE_FULL', amount, 'MATCH_CONFIRMED',
                     f'Approval confirmed via escalation chain. Approving ${amount:.2f}.')

    # 2. Compliance gave a positive verdict -> pay in full.
    if 'compliance' in text and mentions('cleared', 'approved', 'pass'):
        return build('APPROVE_FULL', amount, 'MATCH_CONFIRMED',
                     f'Compliance review cleared. Approving ${amount:.2f}.')

    # 3. Fraud / duplicate / denial wording -> reject.
    if mentions('fraudulent', 'duplicate', 'already paid', 'deny',
                'invalid', 'false claim'):
        return build('REJECT', 0.0, 'DUPLICATE_INVOICE',
                     'Vendor response or audit confirms fraud/duplicate. Rejecting.')

    # 4. Compliance raised a violation -> reject.
    if mentions('flagged', 'violation', 'sox', 'gdpr', 'non-compliant'):
        return build('REJECT', 0.0, 'POLICY_VIOLATION',
                     'Compliance review flagged a violation. Rejecting.')

    # 5. Ambiguous vendor answer -> hand over to a manager.
    if mentions('confused', 'unclear', 'unable to confirm'):
        return build('ESCALATE', 0.0, 'MANAGER_REVIEW',
                     'Vendor response ambiguous. Escalating to manager.')

    # Nothing recognised: reject as the safe default.
    return build('REJECT', 0.0, 'PENDING_CLARIFICATION',
                 'Could not resolve after investigation. Rejecting for safety.')
# ── Metrics tracker ────────────────────────────────────────────────────────────
class Metrics:
    """Collect training statistics, mirror them to a live JSON file, and plot them.

    ``log_step`` is called once per reward-function batch: it appends to the
    history lists below and rewrites ``/app/metrics_live.json``.
    ``save_all_metrics_figures`` renders the collected histories as PNG files.
    """

    def __init__(self):
        self.step = 0
        self.reward_history = []  # (step, mean_reward) — overall
        self.loss_history = []  # (step, loss, grad_norm) — from TRL callback (filled outside this class)
        self.diff_reward_hist = collections.defaultdict(list)  # diff → [(step, mean)]
        self.format_history = []  # (step, format_rate) — compliance over time
        self.episode_len_hist = []  # all episode lengths (for histogram)
        self.ep_len_by_task = collections.defaultdict(list)  # task_id → [lengths]
        self.ep_len_history = []  # (step, mean_ep_len) — one entry per step
        self.decision_history = []  # [(step, {decision: count})] for stacked-bar over time
        self.decision_counts = collections.Counter()
        self.parse_failures = 0
        self.env_errors = 0
        self.format_scores = []
        self.reward_by_task = collections.defaultdict(list)
        self.total_calls = 0
        self._start_time = time.time()
        # Running decision tally. It is never reset, so each snapshot stored in
        # decision_history is cumulative up to that step.
        self._step_decisions = collections.Counter()
        self._flush_warned = False  # True once a failed live-JSON write has been reported

    def log_step(self, rewards, decisions, format_ok_list, task_ids, errors,
                 episode_lengths=None):
        """Record one reward batch and refresh the live JSON file.

        Args:
            rewards: per-completion scores of the batch.
            decisions: per-completion decision strings.
            format_ok_list: per-completion flags, True when the output parsed as a valid action.
            task_ids: per-completion task ids (paired positionally with ``rewards``).
            errors: number of environment errors seen in this batch.
            episode_lengths: optional per-completion episode lengths.
        """
        self.step += 1
        self.total_calls += len(rewards)
        mean_r = sum(rewards) / len(rewards) if rewards else 0.0
        self.reward_history.append((self.step, mean_r))

        diff_rewards: dict = collections.defaultdict(list)
        for tid, r in zip(task_ids, rewards):
            # Task ids without a known tier are booked under 'easy'.
            diff_rewards[_TASK_DIFFICULTY.get(tid, 'easy')].append(r)
            self.reward_by_task[tid].append(r)
        for d, rs in diff_rewards.items():
            self.diff_reward_hist[d].append((self.step, sum(rs) / len(rs)))

        for d in decisions:
            self.decision_counts[d] += 1
            self._step_decisions[d] += 1
        # Snapshot the (cumulative) decision distribution every step for the stacked-bar plot.
        self.decision_history.append((self.step, dict(self._step_decisions)))

        fmt_ok_count = sum(1 for ok in format_ok_list if ok)
        fmt_rate = fmt_ok_count / len(format_ok_list) if format_ok_list else 0.0
        self.format_history.append((self.step, fmt_rate))
        self.format_scores.extend(1.0 if ok else 0.0 for ok in format_ok_list)

        if episode_lengths:
            for tid, ep_len in zip(task_ids, episode_lengths):
                self.episode_len_hist.append(ep_len)
                self.ep_len_by_task[tid].append(ep_len)
            self.ep_len_history.append(
                (self.step, sum(episode_lengths) / len(episode_lengths)))

        self.env_errors += errors
        self._flush_live()

    def _flush_live(self):
        """Write the current metrics to ``/app/metrics_live.json`` (best effort).

        The payload goes to a temporary file first and is then moved into place
        with ``os.replace`` so that a concurrent reader never sees a
        half-written file. A failed write never interrupts training; the first
        failure is printed, later ones are silent.
        """
        recent = self.reward_history[-20:]
        recent_mean = sum(r for _, r in recent) / len(recent) if recent else 0.0
        fmt_rate = sum(self.format_scores) / len(self.format_scores) if self.format_scores else 0.0
        task_means = {t: round(sum(v) / len(v), 3) for t, v in self.reward_by_task.items()}
        elapsed = (time.time() - self._start_time) / 60
        payload = {
            'step': self.step,
            'total_calls': self.total_calls,
            'recent_mean': round(recent_mean, 4),
            'format_rate': round(fmt_rate, 4),
            'parse_failures': self.parse_failures,
            'env_errors': self.env_errors,
            'elapsed_min': round(elapsed, 1),
            'reward_history': [{'step': s, 'reward': r} for s, r in self.reward_history],
            'loss_history': [{'step': s, 'loss': l, 'grad_norm': g} for s, l, g in self.loss_history],
            'format_history': [{'step': s, 'rate': r} for s, r in self.format_history],
            'diff_reward_hist': {d: [{'step': s, 'reward': r} for s, r in v]
                                 for d, v in self.diff_reward_hist.items()},
            'ep_len_history': [{'step': s, 'mean_len': l} for s, l in self.ep_len_history],
            'decision_counts': dict(self.decision_counts),
            'task_means': task_means,
        }
        live_path = '/app/metrics_live.json'
        tmp_path = live_path + '.tmp'
        try:
            with open(tmp_path, 'w') as f:
                json.dump(payload, f)
            os.replace(tmp_path, live_path)
        except Exception as exc:
            if not self._flush_warned:
                self._flush_warned = True
                print(f'[METRICS] could not write {live_path}: {exc!r} '
                      f'(further write failures are not reported)')

    def print_summary(self):
        """Print a short textual summary of the metrics collected so far."""
        recent = self.reward_history[-10:] if self.reward_history else []
        recent_mean = sum(r for _, r in recent) / len(recent) if recent else 0.0
        fmt_rate = sum(self.format_scores) / len(self.format_scores) if self.format_scores else 0.0
        print(f'\n[METRICS] step={self.step} | recent_reward={recent_mean:.3f} | '
              f'format_ok={fmt_rate:.1%} | parse_fails={self.parse_failures} | '
              f'env_errors={self.env_errors} | total_calls={self.total_calls}')
        top_decisions = self.decision_counts.most_common(5)
        print(f'[METRICS] decisions: {dict(top_decisions)}')
        if self.reward_by_task:
            task_means = {t: round(sum(v) / len(v), 3) for t, v in self.reward_by_task.items()}
            print(f'[METRICS] per_task_reward: {task_means}')

    def save_all_metrics_figures(self, run_dir: str):
        """Save up to six training metric figures as PNG files into ``run_dir``.

        A figure is skipped when the data it needs has not been collected.
        ``run_dir`` is created if it does not exist yet.
        """
        PALETTE = {'easy': '#3fb950', 'medium': '#d29922', 'hard': '#f85149', 'long': '#a371f7'}
        BG = '#0d1117'
        PANEL = '#161b22'
        GRID = '#21262d'
        TEXT = '#e6edf3'
        SUBTEXT = '#8b949e'
        ACCENT = '#58a6ff'

        os.makedirs(run_dir, exist_ok=True)

        def _setup(ax, xlabel='', ylabel='', title=''):
            """Apply the shared dark theme, grid and labels to one axes."""
            ax.set_facecolor(PANEL)
            ax.tick_params(colors=TEXT, labelsize=8)
            for sp in ax.spines.values():
                sp.set_color('#30363d')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.yaxis.grid(True, color=GRID, linewidth=0.6, alpha=0.8)
            ax.xaxis.grid(True, color=GRID, linewidth=0.4, alpha=0.4)
            ax.set_axisbelow(True)
            if xlabel:
                ax.set_xlabel(xlabel, color=SUBTEXT, fontsize=9)
            if ylabel:
                ax.set_ylabel(ylabel, color=SUBTEXT, fontsize=9)
            if title:
                ax.set_title(title, color=TEXT, fontsize=10, fontweight='bold', pad=8)

        def _smooth(values, window=None):
            """Return ``(moving_average, window)`` for a uniform (simple) moving average.

            The result has ``len(values) - window + 1`` points, so it lines up
            with ``steps[window - 1:]``. Inputs shorter than 3 points are
            returned unsmoothed with ``window == 1`` so that callers can always
            unpack a pair.
            """
            n = len(values)
            if n < 3:
                return np.asarray(values, dtype=float), 1
            # Never let the window exceed the data, or 'valid' mode would
            # produce a result that no longer lines up with the step axis.
            w = min(n, window or max(3, n // 12))
            return np.convolve(values, np.ones(w) / w, mode='valid'), w

        def _caption(fig, text):
            """Place a small italic caption centred at the bottom of the figure."""
            fig.text(0.5, 0.01, text, ha='center', va='bottom',
                     color=SUBTEXT, fontsize=7, style='italic')

        def _save(fig, filename, rect=(0, 0.04, 1, 1)):
            """Lay out, write and close one figure, then print where it went."""
            fig.tight_layout(rect=list(rect))
            p = os.path.join(run_dir, filename)
            fig.savefig(p, dpi=130, bbox_inches='tight', facecolor=BG)
            plt.close(fig)
            print(f'[METRICS] {p}')

        ts = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')

        # ── Figure 1: Mean Episode Return (reward curve) ──────────────────────
        if self.reward_history:
            fig, ax = plt.subplots(figsize=(10, 4))
            fig.patch.set_facecolor(BG)
            steps = [s for s, _ in self.reward_history]
            rewards = [r for _, r in self.reward_history]
            ax.plot(steps, rewards, color=ACCENT, alpha=0.25, linewidth=1, label='Per-batch mean')
            if len(rewards) >= 5:
                sm, w = _smooth(rewards)
                ax.plot(steps[w - 1:], sm, color=ACCENT, linewidth=2, label=f'Moving avg (w={w})')
            ax.axhline(0.5, color=SUBTEXT, linestyle='--', linewidth=1, alpha=0.5,
                       label='Chance baseline (0.5)')
            recent_mean = sum(rewards[-20:]) / min(20, len(rewards))
            ax.axhline(recent_mean, color='#f78166', linestyle=':', linewidth=1.5,
                       label=f'Recent mean = {recent_mean:.3f}')
            ax.set_ylim(0, 1.05)
            ax.legend(fontsize=8, facecolor=PANEL, edgecolor='#30363d', labelcolor=TEXT)
            _setup(ax, xlabel='Training Step (reward function call batch)',
                   ylabel='Mean Episode Return [0.01 – 0.99]',
                   title='Training Reward Curve — AP Commander GRPO')
            _caption(fig, f'Each step = one GRPO batch. Reward = discounted accumulated score from AP Commander environment. | {ts}')
            _save(fig, 'fig1_reward_curve.png')

        # ── Figure 2: Per-Difficulty Learning Curves ──────────────────────────
        if self.diff_reward_hist:
            fig, ax = plt.subplots(figsize=(10, 4))
            fig.patch.set_facecolor(BG)
            for diff in _DIFFICULTY_ORDER:
                hist = self.diff_reward_hist.get(diff, [])
                if not hist:
                    continue
                steps_d = [s for s, _ in hist]
                rewards_d = [r for _, r in hist]
                color = PALETTE.get(diff, ACCENT)
                ax.plot(steps_d, rewards_d, color=color, alpha=0.20, linewidth=1)
                if len(rewards_d) >= 5:
                    sm, w = _smooth(rewards_d)
                    ax.plot(steps_d[w - 1:], sm, color=color, linewidth=2.5,
                            label=f'{diff} (n={len(steps_d)})')
                else:
                    ax.plot(steps_d, rewards_d, color=color, linewidth=2.5, label=diff)
            for thr_diff, thr_val in _UNLOCK_THRESHOLDS.items():
                ax.axhline(thr_val, color=PALETTE.get(thr_diff, SUBTEXT),
                           linestyle='--', linewidth=0.8, alpha=0.5)
            ax.set_ylim(0, 1.05)
            ax.legend(fontsize=9, facecolor=PANEL, edgecolor='#30363d', labelcolor=TEXT)
            _setup(ax, xlabel='Training Step',
                   ylabel='Mean Reward per Difficulty Tier [0.01 – 0.99]',
                   title='Curriculum Learning Curves — Easy / Medium / Hard / Long-Horizon')
            _caption(fig, f'Dashed lines = curriculum unlock thresholds. Each line = rolling mean of all tasks in that difficulty tier. | {ts}')
            _save(fig, 'fig2_difficulty_curves.png')

        # ── Figure 3: Episode Length Distribution ─────────────────────────────
        if self.episode_len_hist:
            fig, axes = plt.subplots(1, 2, figsize=(12, 4))
            fig.patch.set_facecolor(BG)
            # Overall histogram, one bin per integer episode length
            max_len = max(self.episode_len_hist)
            bins = range(1, max_len + 2)
            axes[0].hist(self.episode_len_hist, bins=bins, color=ACCENT, alpha=0.85,
                         edgecolor=BG, rwidth=0.8)
            _setup(axes[0], xlabel='Episode Length (number of env steps)',
                   ylabel='Count of Episodes',
                   title='Episode Length Distribution (all tasks)')
            axes[0].axvline(np.mean(self.episode_len_hist), color='#f78166',
                            linestyle='--', linewidth=1.5,
                            label=f'Mean = {np.mean(self.episode_len_hist):.1f}')
            axes[0].legend(fontsize=8, facecolor=PANEL, edgecolor='#30363d', labelcolor=TEXT)
            # Per-difficulty mean episode length bar
            diff_ep_means = {}
            for diff in _DIFFICULTY_ORDER:
                lens = []
                for tid, d in _TASK_DIFFICULTY.items():
                    if d == diff:
                        lens.extend(self.ep_len_by_task.get(tid, []))
                if lens:
                    diff_ep_means[diff] = np.mean(lens)
            if diff_ep_means:
                diffs = list(diff_ep_means.keys())
                means = list(diff_ep_means.values())
                colors = [PALETTE.get(d, ACCENT) for d in diffs]
                axes[1].bar(diffs, means, color=colors, alpha=0.85, edgecolor=BG, width=0.5)
                for i, m in enumerate(means):
                    axes[1].text(i, m + 0.05, f'{m:.1f}', ha='center', color=TEXT, fontsize=9,
                                 fontweight='bold')
                axes[1].set_ylim(0, max(means) * 1.3)
            _setup(axes[1], xlabel='Difficulty Tier',
                   ylabel='Mean Episode Length (steps)',
                   title='Mean Episode Length by Difficulty')
            fig.suptitle('Episode Length Analysis — Multi-Step Decision Behavior',
                         color=TEXT, fontsize=11, y=1.01)
            _caption(fig, f'Long-horizon tasks expected to have higher mean episode lengths as agent learns to use ESCALATE/QUERY_VENDOR. | {ts}')
            _save(fig, 'fig3_episode_lengths.png')

        # ── Figure 4: Format Compliance Rate Over Time ────────────────────────
        if self.format_history:
            fig, ax = plt.subplots(figsize=(10, 3.5))
            fig.patch.set_facecolor(BG)
            steps_f = [s for s, _ in self.format_history]
            fmt_vals = [r for _, r in self.format_history]
            ax.plot(steps_f, fmt_vals, color='#d29922', alpha=0.25, linewidth=1)
            if len(fmt_vals) >= 5:
                sm, w = _smooth(fmt_vals)
                ax.plot(steps_f[w - 1:], sm, color='#d29922', linewidth=2.5,
                        label=f'Moving avg (w={w})')
            final_rate = sum(self.format_scores) / max(1, len(self.format_scores))
            ax.axhline(final_rate, color='#3fb950', linestyle='--', linewidth=1.5,
                       label=f'Overall rate = {final_rate:.1%}')
            ax.set_ylim(0, 1.05)
            ax.legend(fontsize=8, facecolor=PANEL, edgecolor='#30363d', labelcolor=TEXT)
            _setup(ax, xlabel='Training Step',
                   ylabel='Format Compliance Rate [0 – 1]',
                   title='JSON Format Compliance Over Training')
            _caption(fig, f'Format compliance = fraction of completions producing valid JSON with correct fields. Parse failures = {self.parse_failures}. | {ts}')
            _save(fig, 'fig4_format_compliance.png')

        # ── Figure 5: Decision Distribution Over Time (stacked bar) ──────────
        if self.decision_history and len(self.decision_history) >= 3:
            all_decisions = sorted(set(self.decision_counts.keys()))
            # Sample ~20 evenly-spaced checkpoints for readability
            # (n_checkpoints >= 3 here, so the division below is safe).
            n_checkpoints = min(20, len(self.decision_history))
            idxs = [int(i * (len(self.decision_history) - 1) / (n_checkpoints - 1))
                    for i in range(n_checkpoints)]
            ckpt_steps = [self.decision_history[i][0] for i in idxs]
            ckpt_counts = [self.decision_history[i][1] for i in idxs]
            # Convert to fractions
            fracs = []
            for c in ckpt_counts:
                total_c = sum(c.values()) or 1
                fracs.append({d: c.get(d, 0) / total_c for d in all_decisions})
            fig, ax = plt.subplots(figsize=(12, 4))
            fig.patch.set_facecolor(BG)
            dec_colors = ['#3fb950', '#f85149', '#d29922', '#a371f7', '#58a6ff', '#f0883e']
            bottom = np.zeros(len(ckpt_steps))
            for j, dec in enumerate(all_decisions):
                vals = np.array([f[dec] for f in fracs])
                ax.bar(range(len(ckpt_steps)), vals, bottom=bottom,
                       label=dec, color=dec_colors[j % len(dec_colors)],
                       alpha=0.85, edgecolor=BG)
                bottom += vals
            ax.set_xticks(range(len(ckpt_steps)))
            ax.set_xticklabels([str(s) for s in ckpt_steps], rotation=45, fontsize=7)
            ax.set_ylim(0, 1.05)
            ax.legend(fontsize=7, facecolor=PANEL, edgecolor='#30363d', labelcolor=TEXT,
                      loc='upper right', bbox_to_anchor=(1.15, 1))
            _setup(ax, xlabel='Training Step (checkpoint)',
                   ylabel='Fraction of Decisions',
                   title='Decision Distribution Over Training (Stacked Bar)')
            _caption(fig, f'Each bar = cumulative decision distribution up to that checkpoint. Ideal: APPROVE_FULL grows for easy tasks, REJECT for fraud/duplicate tasks. | {ts}')
            # Leave room on the right for the legend placed outside the axes.
            _save(fig, 'fig5_decision_distribution.png', rect=(0, 0.04, 0.88, 1))

        # ── Figure 6: Per-Task Training Mean (horizontal bar) ─────────────────
        if self.reward_by_task:
            task_means = {t: sum(v) / len(v) for t, v in self.reward_by_task.items()}
            tasks = sorted(
                task_means,
                key=lambda t: (_DIFFICULTY_ORDER.index(_TASK_DIFFICULTY.get(t, 'easy')), t))
            means = [task_means[t] for t in tasks]
            colors = [PALETTE.get(_TASK_DIFFICULTY.get(t, 'easy'), ACCENT) for t in tasks]
            short = [t.replace('easy_', '').replace('medium_', '').replace('hard_', '')
                      .replace('long_', '').replace('_', ' ').title() for t in tasks]
            fig, ax = plt.subplots(figsize=(10, max(4, len(tasks) * 0.45)))
            fig.patch.set_facecolor(BG)
            yp = range(len(tasks))
            ax.barh(list(yp), means, color=colors, alpha=0.85, edgecolor=BG)
            ax.set_yticks(list(yp))
            ax.set_yticklabels(short, fontsize=8)
            ax.set_xlim(0, 1.05)
            overall_mean = sum(means) / len(means)
            ax.axvline(overall_mean, color='#f78166', linestyle='--', linewidth=1.5,
                       label=f'Overall mean = {overall_mean:.3f}')
            ax.axvline(0.5, color=SUBTEXT, linestyle=':', linewidth=1, alpha=0.5)
            for i, m in enumerate(means):
                ax.text(m + 0.01, i, f'{m:.3f}', va='center', color=TEXT, fontsize=7)
            from matplotlib.patches import Patch
            legend_els = [Patch(facecolor=PALETTE[d], label=d.title())
                          for d in _DIFFICULTY_ORDER if d in PALETTE]
            legend_els.append(plt.Line2D([0], [0], color='#f78166', linestyle='--',
                                         label=f'Mean {overall_mean:.3f}'))
            ax.legend(handles=legend_els, fontsize=8, facecolor=PANEL, edgecolor='#30363d',
                      labelcolor=TEXT)
            _setup(ax, xlabel='Mean Training Reward [0.01 – 0.99]',
                   ylabel='Task',
                   title='Per-Task Training Mean Reward (all episodes)')
            # Bars are coloured by difficulty tier (see legend), not by score band.
            _caption(fig, f'Tasks ordered by difficulty; bar colour = difficulty tier. Dashed line = mean over tasks, dotted line = 0.5. | {ts}')
            _save(fig, 'fig6_per_task_means.png')
# Module-level metrics instance shared by parse_action and the reward function.
METRICS = Metrics()
# While this is empty, env_reward_fn does not write per-episode JSONL records.
_EPISODE_LOG_PATH: str = '' # set to run_dir/episodes.jsonl once run_dir is known
# ── Helpers ────────────────────────────────────────────────────────────────────
def obs_to_prompt(obs: dict) -> str:
    """Render an environment observation as the plain-text user prompt.

    Sections, in order: task header, invoice with line items and freight,
    purchase orders, goods receipts, then the optional paid-invoice ledger and
    context notes (each omitted when empty), and finally the company policy.
    """
    invoice = obs['invoice']

    item_rows = [
        f" {item['description']}: qty={item['quantity']}, unit_price=${item['unit_price']:.2f}"
        for item in invoice.get('line_items', [])
    ]
    items_text = '\n'.join(item_rows)

    po_rows = []
    for po in obs.get('purchase_orders', []):
        header = f" PO {po['po_number']} ({po['status']}) {po['vendor_name']}: "
        detail = ', '.join(
            f"{ln['description']} qty={ln['ordered_quantity']} @${ln['agreed_unit_price']:.2f}"
            for ln in po.get('lines', [])
        )
        po_rows.append(header + detail)
    po_text = '\n'.join(po_rows)

    grn_rows = []
    for grn in obs.get('goods_receipts', []):
        header = f" GRN {grn['grn_id']} (PO {grn['po_number']}): "
        detail = ', '.join(
            f"{ln['description']} recv={ln['received_quantity']}"
            for ln in grn.get('lines', [])
        )
        grn_rows.append(header + detail)
    grn_text = '\n'.join(grn_rows)

    notes_text = '\n'.join(f' {note}' for note in obs.get('context_notes', []))
    ledger_text = ', '.join(obs.get('paid_invoice_ids', []))

    sections = [
        f"TASK: {obs['task_name']}\n{obs['task_description']}\n\n",
        f"INVOICE {invoice['invoice_id']} | {invoice['vendor_name']} | ${invoice['invoice_total']:,.2f}\n{items_text}\n",
        f"Freight: ${invoice.get('freight_charge',0):.2f}\n\n",
        f"PURCHASE ORDERS:\n{po_text}\n\nGOODS RECEIPTS:\n{grn_text}\n",
    ]
    if ledger_text:
        sections.append(f"PAID LEDGER: {ledger_text}\n")
    if notes_text:
        sections.append(f"CONTEXT:\n{notes_text}\n")
    sections.append(f"\nPOLICY:\n{obs['company_policy']}\n\nOutput JSON decision.")
    return ''.join(sections)
def parse_action(raw: str) -> tuple[dict, bool]:
    """Extract and validate the JSON action contained in a model completion.

    Markdown code fences are stripped and the outermost ``{...}`` span is
    parsed as JSON. The action is accepted only if

    * ``decision`` is one of ``VALID_DECISIONS``,
    * ``reason_code`` is one of ``VALID_REASON_CODES``,
    * ``approved_amount`` is a finite number (booleans are rejected even
      though ``bool`` is a subclass of ``int``; NaN/Infinity are rejected
      because they are not representable in strict JSON),
    * ``explanation`` is a string longer than 10 characters.

    Returns:
        ``(action, True)`` on success, otherwise ``(fallback, False)`` where
        the fallback is a fixed REJECT / NO_PO_FOUND action.

    Side effect:
        Every failed parse increments ``METRICS.parse_failures``. Callers that
        parse the same completion more than once therefore count it once per
        call.
    """
    clean = re.sub(r'```(?:json)?\s*|\s*```', '', raw).strip()
    m = re.search(r'\{.*\}', clean, re.DOTALL)
    if m:
        try:
            action = json.loads(m.group())
        except (ValueError, RecursionError):
            # Malformed or pathologically nested JSON: fall through to the fallback.
            action = None
        if isinstance(action, dict):
            decision = action.get('decision')
            reason = action.get('reason_code')
            amount = action.get('approved_amount')
            explanation = action.get('explanation')
            amount_ok = (
                isinstance(amount, (int, float))
                and not isinstance(amount, bool)
                and amount == amount                 # rejects NaN
                and abs(amount) != float('inf')      # rejects ±Infinity
            )
            # The isinstance(str) guards keep unhashable values (lists, dicts)
            # away from the set membership tests.
            if (isinstance(decision, str) and decision in VALID_DECISIONS
                    and isinstance(reason, str) and reason in VALID_REASON_CODES
                    and amount_ok
                    and isinstance(explanation, str) and len(explanation) > 10):
                return action, True
    METRICS.parse_failures += 1
    return {'decision': 'REJECT', 'approved_amount': 0.0,
            'reason_code': 'NO_PO_FOUND', 'explanation': 'parse error fallback'}, False
def run_episode(task_id: str, action_json: dict, seed=None) -> float:
    """Reset the environment for one task, submit a single action, return its score.

    Any failure (network error, HTTP error status, unexpected payload) yields
    the floor score 0.01 instead of an exception.
    """
    reset_payload = {'task_id': task_id, 'seed': seed}
    try:
        reset_resp = requests.post(f'{ENV_URL}/reset', json=reset_payload, timeout=20)
        reset_resp.raise_for_status()
        session = reset_resp.json()

        step_payload = {'session_id': session['session_id'], 'action': action_json}
        step_resp = requests.post(f'{ENV_URL}/step', json=step_payload, timeout=20)
        step_resp.raise_for_status()

        score = step_resp.json()['reward']['score']
        return float(score)
    except Exception:
        return 0.01
def run_episode_accumulated(task_id: str, first_action: dict, seed=None,
                            discount: float = 0.9, max_steps: int = 20,
                            episode_log: list | None = None) -> tuple[float, int]:
    """
    Run a full multi-step episode accumulating discounted per-step rewards.

    Model's first action starts the episode; _greedy_followup() handles
    subsequent steps so multi-step sequences earn full accumulated credit.
    E.g. QUERY_VENDOR→REJECT = 0.01 + 0.9*0.99 = 0.901 > shortcut REJECT = ~0.4

    Args:
        task_id: environment task to reset to.
        first_action: action dict submitted as step 0.
        seed: seed forwarded to the environment's /reset endpoint.
        discount: per-step discount factor; step n is weighted by discount**n.
        max_steps: upper bound on the number of env steps.
        episode_log: if provided, appended with one dict per env step for JSONL logging.

    Returns:
        (score, episode_length). The score is clamped to [0.01, 0.99]. On any
        failure (network, HTTP status, unexpected payload) the result is
        (0.01, 1) and the error is printed; steps already appended to
        ``episode_log`` stay there.
    """
    try:
        r = requests.post(f'{ENV_URL}/reset',
                          json={'task_id': task_id, 'seed': seed}, timeout=20)
        r.raise_for_status()
        session_id = r.json()['session_id']
        action = first_action
        total = 0.0
        steps_taken = 0
        for step_n in range(max_steps):
            step_r = requests.post(f'{ENV_URL}/step',
                                   json={'session_id': session_id, 'action': action},
                                   timeout=20)
            step_r.raise_for_status()
            result = step_r.json()
            r_score = float(result['reward']['score'])
            done = result['done']
            # A missing or null observation is treated as an empty one so the
            # scripted follow-up policy still gets a dict.
            obs_back = result.get('observation') or {}
            total += (discount ** step_n) * r_score
            steps_taken = step_n + 1
            if episode_log is not None:
                episode_log.append({
                    'step_n': step_n,
                    'decision': action.get('decision'),
                    'approved_amount': action.get('approved_amount'),
                    'reason_code': action.get('reason_code'),
                    'explanation': (action.get('explanation') or '')[:120],
                    'step_score': round(r_score, 4),
                    'done': done,
                    'context_notes': obs_back.get('context_notes', []),
                    'action_history': obs_back.get('action_history', []),
                })
            if done:
                break
            action = _greedy_followup(obs_back)
        return min(0.99, max(0.01, total)), steps_taken
    except Exception as e:
        # Keep the best-effort contract (floor score, no exception), but make
        # the failure visible: a silent 0.01 is indistinguishable from a
        # genuinely bad action.
        print(f'[ENV] episode failed (task={task_id} seed={seed}): {e!r}')
        return 0.01, 1
def env_reward_fn(completions, task_id=None, seed=None, **kwargs):
    """
    Environment reward: accumulated discounted per-step reward from AP Commander.

    For every completion the action is parsed, a full episode is played, and
    the score is recorded in METRICS and CURRICULUM. One JSONL record per
    episode is appended to _EPISODE_LOG_PATH (when set) for full verifiability.

    Args:
        completions: model outputs of the batch.
        task_id: per-completion task ids; defaults to 'easy_perfect_match' for all.
        seed: per-completion seeds; when omitted, ONE random seed is drawn and
            shared by the whole batch.
        **kwargs: ignored.

    Returns:
        List of rewards, one per completion.

    Raises:
        _StopTraining: if the /app/stop_requested flag file exists after the
            batch has been processed (the flag file is removed first).
    """
    task_ids = task_id if task_id is not None else ['easy_perfect_match'] * len(completions)
    seeds = seed if seed is not None else [random.randint(1, 999)] * len(completions)
    rewards, decisions, format_ok_list, ep_lengths, errors = [], [], [], [], 0
    for completion, tid, s in zip(completions, task_ids, seeds):
        # Curriculum gating disabled: all 20 tasks train simultaneously from step 1.
        # Gating caused mode collapse in Run 2 — hard tasks never trained.
        action, fmt_ok = parse_action(completion)
        episode_steps = []
        try:
            score, ep_len = run_episode_accumulated(
                tid, action, seed=int(s), episode_log=episode_steps)
            # run_episode_accumulated reports failures as (0.01, 1) instead of
            # raising. An empty step log means not even the first env step
            # completed, which is the one failure mode detectable from here.
            failed = not episode_steps
        except Exception:
            score, ep_len = 0.01, 1
            failed = True
        if failed:
            errors += 1
        rewards.append(score)
        ep_lengths.append(ep_len)
        decisions.append(action.get('decision', 'UNKNOWN'))
        format_ok_list.append(fmt_ok)
        CURRICULUM.record(tid, score)
        # 1-based running number of this reward call. METRICS.total_calls only
        # advances in log_step() after the loop, so the position inside the
        # batch has to be added here.
        call_n = METRICS.total_calls + len(rewards)
        if _EPISODE_LOG_PATH:
            try:
                record = {
                    'reward_step': METRICS.step + 1,
                    'call_n': call_n,
                    'task_id': tid,
                    'seed': int(s),
                    'format_ok': fmt_ok,
                    'score': round(score, 4),
                    'episode_len': ep_len,
                    'final_decision': action.get('decision'),
                    'steps': episode_steps,
                    'ts': datetime.datetime.now().isoformat(),
                }
                with open(_EPISODE_LOG_PATH, 'a') as _f:
                    _f.write(json.dumps(record) + '\n')
            except Exception:
                # Episode logging is best effort and must never break training.
                pass
        # Print one sample every LOG_SAMPLES_EVERY reward calls.
        if call_n % LOG_SAMPLES_EVERY == 0:
            print(f'\n[SAMPLE] task={tid} seed={s} fmt={fmt_ok} '
                  f'score={score:.3f} ep_len={ep_len}')
            print(f' {action.get("decision")} ${action.get("approved_amount")} '
                  f'{action.get("reason_code")}')
            print(f' {str(action.get("explanation",""))[:100]}')
            print(f' curriculum: {CURRICULUM.status_line()}')
            if episode_steps:
                actor_notes = [n for step in episode_steps
                               for n in step.get('context_notes', [])]
                if actor_notes:
                    print(f' actor_responses: {actor_notes[:2]}')
    METRICS.log_step(rewards, decisions, format_ok_list, list(task_ids), errors,
                     episode_lengths=ep_lengths)
    if METRICS.step % 5 == 0:
        METRICS.print_summary()
        print(f'[CURRICULUM] {CURRICULUM.status_line()}')
    # Graceful stop: app.py writes this flag when Stop button is clicked
    if os.path.exists('/app/stop_requested'):
        print(f'\n[STOP] Stop flag detected at step {METRICS.step}. Saving weights before exit...')
        os.remove('/app/stop_requested')
        raise _StopTraining()
    return rewards
def format_reward_fn(completions, **kwargs):
    """Format reward: +0.15 for a completion that parses into a valid action, else -0.15.

    Increased from ±0.05 — format failures were 55% in Run 2, drowning env signal.
    Validity is decided by parse_action(), which also bumps
    METRICS.parse_failures for every completion that fails here.
    """
    return [0.15 if parse_action(text)[1] else -0.15 for text in completions]
# System prompt for the oversight evaluation: the model reviews finished
# AP-clerk episodes and must answer with one JSON verdict per episode.
OVERSIGHT_SYSTEM_PROMPT = """You are a Fleet AI Oversight Agent reviewing completed AP Clerk decisions.
For each episode summary provided, decide: CLEAR, FLAG_FOR_REVIEW, or ESCALATE_TO_AUDIT.
Output valid JSON only: {"episode_id": "...", "verdict": "CLEAR|FLAG_FOR_REVIEW|ESCALATE_TO_AUDIT", "signal": "specific reason with $ amounts or %", "confidence": 0.0-1.0}
Your response must start with { and end with } with no other text."""
def eval_oversight(model, tokenizer, seed: int = 99) -> float:
    """Run one oversight session (5 episodes) and return mean reward across all steps.

    For every episode summary returned by ``/oversight/reset`` the model
    generates a verdict, which is posted to ``/oversight/step``. Output that
    contains no JSON object is replaced by a default CLEAR verdict; missing
    fields of a parsed verdict are filled with defaults.

    Any error (network, HTTP status, unexpected payload, generation failure)
    aborts the session and returns 0.01. The model is switched to eval mode
    and is not switched back.
    """
    import torch
    model.eval()
    # Run on whatever device the model lives on; fall back to the previous
    # hard-coded 'cuda' for model objects that expose no .device attribute.
    device = getattr(model, 'device', 'cuda')
    try:
        reset_resp = requests.post(f'{ENV_URL}/oversight/reset',
                                   json={'num_episodes': 5, 'seed': seed}, timeout=20)
        reset_resp.raise_for_status()
        reset = reset_resp.json()
        session_id = reset['session_id']
        observation = reset['observation']
        summaries = observation['episode_summaries']
        scores = []
        for ep in summaries:
            # NOTE(review): fraud patterns and audit budget come from the reset
            # observation and are not refreshed after each step — confirm the
            # environment does not expect the updated budget here.
            prompt = (
                f"Episode ID: {ep['episode_id']}\n"
                f"Vendor: {ep['vendor_name']} Invoice: {ep['invoice_id']} "
                f"Total: ${ep['invoice_total']:.2f}\n"
                f"Decision: {ep['final_decision']} Reason: {ep['reason_code']}\n"
                f"Explanation: {ep['explanation']}\n"
                f"Known fraud patterns: {observation.get('known_fraud_patterns', [])}\n"
                f"Audit budget remaining: {observation['audit_budget']}"
            )
            messages = [{'role': 'system', 'content': OVERSIGHT_SYSTEM_PROMPT},
                        {'role': 'user', 'content': prompt}]
            text = tokenizer.apply_chat_template(messages, tokenize=False,
                                                 add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors='pt').to(device)
            with torch.no_grad():
                out = model.generate(**inputs, max_new_tokens=150, temperature=0.1,
                                     do_sample=True)
            raw = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:],
                                   skip_special_tokens=True)

            action = {'episode_id': ep['episode_id'], 'verdict': 'CLEAR',
                      'signal': 'parse error', 'confidence': 0.5}
            # Same tolerant extraction as parse_action(): drop markdown fences,
            # then take the outermost {...} span.
            cleaned = re.sub(r'```(?:json)?\s*|\s*```', '', raw).strip()
            match = re.search(r'\{.*\}', cleaned, re.DOTALL)
            if match:
                try:
                    parsed = json.loads(match.group())
                except (ValueError, RecursionError):
                    parsed = None
                if isinstance(parsed, dict):
                    parsed.setdefault('episode_id', ep['episode_id'])
                    parsed.setdefault('verdict', 'CLEAR')
                    parsed.setdefault('signal', 'no signal')
                    parsed.setdefault('confidence', 0.5)
                    action = parsed

            step_resp = requests.post(f'{ENV_URL}/oversight/step',
                                      json={'session_id': session_id, 'action': action},
                                      timeout=20)
            step_resp.raise_for_status()
            scores.append(float(step_resp.json()['reward']['score']))
        mean = sum(scores) / len(scores) if scores else 0.01
        print(f' oversight session mean: {mean:.3f} scores={[round(s,2) for s in scores]}')
        return mean
    except Exception as e:
        print(f' oversight eval error: {e}')
        return 0.01
# ── Eval helper ────────────────────────────────────────────────────────────────
def eval_task(model, tokenizer, task_id: str, seed: int = 99) -> float:
    """Run one evaluation episode for ``task_id`` and return its score.

    Resets the environment for the given task/seed, prompts the model once
    with the system prompt plus the rendered observation, submits the parsed
    action to ``/step`` and returns the ``reward.score`` from the response.

    Args:
        model: Causal LM used for generation (put into eval mode here).
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        task_id: Task identifier sent to the environment's ``/reset``.
        seed: Seed sent to the environment's ``/reset``.

    Returns:
        The score reported by the environment, or ``0.01`` when any step
        fails (the error is printed, never raised).
    """
    import torch
    model.eval()
    try:
        reset_resp = requests.post(
            f'{ENV_URL}/reset',
            json={'task_id': task_id, 'seed': seed},
            timeout=20,
        )
        # Surface HTTP failures as such instead of as a KeyError further down.
        reset_resp.raise_for_status()
        reset = reset_resp.json()
        obs, session_id = reset['observation'], reset['session_id']

        messages = [{'role': 'system', 'content': SYSTEM_PROMPT},
                    {'role': 'user', 'content': obs_to_prompt(obs)}]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Follow the model's own placement rather than hard-coding 'cuda';
        # fall back to the previous behaviour if the wrapper hides `.device`.
        device = getattr(model, 'device', 'cuda')
        inputs = tokenizer(text, return_tensors='pt').to(device)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=250, temperature=0.1, do_sample=True)

        # Decode only the newly generated tokens (skip the prompt prefix).
        prompt_len = inputs['input_ids'].shape[1]
        raw = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        # Print before talking to the env so the output is visible even if
        # parsing or the /step call fails.
        print(f' output: {raw[:120].strip()}')

        action, _ = parse_action(raw)
        step_resp = requests.post(
            f'{ENV_URL}/step',
            json={'session_id': session_id, 'action': action},
            timeout=20,
        )
        step_resp.raise_for_status()
        return float(step_resp.json()['reward']['score'])
    except Exception as e:
        # Evaluation is best-effort: report and return the floor score.
        print(f' eval error: {e}')
        return 0.01
# ── Main ───────────────────────────────────────────────────────────────────────
def _make_run_dir() -> str:
    """Create and return this run's artifact directory.

    The directory lives under ``/app/runs/grpo`` and is named from the
    lower-cased model name (dots replaced by dashes), the epoch count and a
    minute-resolution timestamp. An already existing directory is reused.
    """
    slug = MODEL_NAME.rsplit('/', 1)[-1].lower().replace('.', '-')
    stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H%M')
    target = f'/app/runs/grpo/{slug}-{NUM_EPOCHS}ep-{stamp}'
    os.makedirs(target, exist_ok=True)
    return target
def _save_loss_curve(trainer, run_dir: str):
    """Plot loss and gradient norm from the trainer log and save ``loss_curve.png``.

    Reads ``trainer.state.log_history``, keeps the entries that contain a
    ``loss`` key and draws two stacked panels: per-step loss (plus a trailing
    moving average once more than 10 points exist) and gradient norm for the
    entries that logged one.

    Args:
        trainer: Object exposing ``state.log_history`` (a list of dicts).
        run_dir: Directory the PNG is written into.

    Any failure is printed and swallowed so plotting can never abort a run;
    the figure is always closed, even on error.
    """
    fig = None
    try:
        logged = [e for e in trainer.state.log_history if 'loss' in e]
        steps = [e['step'] for e in logged]
        losses = [e['loss'] for e in logged]
        grads = [e.get('grad_norm') for e in logged]

        BG, TEXT, BLUE, ORANGE = '#0d1117', '#e6edf3', '#58a6ff', '#f0883e'
        fig, axes = plt.subplots(2, 1, figsize=(10, 6), facecolor=BG)
        fig.suptitle(
            f'AP Commander GRPO — Loss & Gradient Norm | {MODEL_NAME.split("/")[-1]}',
            color=TEXT, fontsize=11, fontweight='bold'
        )

        # Top panel: raw loss plus a trailing moving average.
        ax1 = axes[0]
        ax1.set_facecolor(BG)
        ax1.plot(steps, losses, color=BLUE, linewidth=1.2, alpha=0.6, label='per-step loss')
        if len(steps) > 10:
            smooth = []
            for i in range(len(losses)):
                window = losses[max(0, i - 10):i + 1]
                smooth.append(sum(window) / len(window))
            ax1.plot(steps, smooth, color=BLUE, linewidth=2, label='smooth (w=10)')
        ax1.axhline(0, color='white', linewidth=0.5, alpha=0.3)
        ax1.set_ylabel('Loss', color=TEXT)
        ax1.set_xlabel('')
        ax1.tick_params(colors=TEXT)
        ax1.legend(facecolor=BG, labelcolor=TEXT, fontsize=8)
        for spine in ax1.spines.values():
            spine.set_color('#30363d')

        # Bottom panel: gradient norm, only for steps that logged one.
        ax2 = axes[1]
        ax2.set_facecolor(BG)
        valid_grads = [(s, g) for s, g in zip(steps, grads) if g is not None]
        if valid_grads:
            gs, gv = zip(*valid_grads)
            ax2.plot(gs, gv, color=ORANGE, linewidth=1.2, alpha=0.7)
        ax2.set_ylabel('Grad Norm', color=TEXT)
        ax2.set_xlabel('Training Step', color=TEXT)
        ax2.tick_params(colors=TEXT)
        for spine in ax2.spines.values():
            spine.set_color('#30363d')

        fig.text(0.5, 0.01,
                 'GRPO loss (group-relative policy gradient). Negative values normal mid-training. '
                 'Grad norm collapse (<0.1) indicates entropy saturation.',
                 ha='center', color='#8b949e', fontsize=7)
        fig.tight_layout(rect=[0, 0.04, 1, 1])

        out = os.path.join(run_dir, 'loss_curve.png')
        fig.savefig(out, dpi=130, bbox_inches='tight', facecolor=BG)
        print(f'[LOSS CURVE] Saved {out} ({len(steps)} steps)')
    except Exception as e:
        print(f'[LOSS CURVE] skipped: {e}')
    finally:
        # Close explicitly so a failure mid-plot does not leak the figure.
        if fig is not None:
            plt.close(fig)
def _save_reward_curve_png(run_dir: str):
    """Save the final reward curve from ``METRICS.reward_history`` as ``reward_curve.png``.

    Draws the per-step reward, a trailing moving average once at least 10
    points exist, and a dashed line at the mean of the last (up to) 20
    rewards. Does nothing when no reward history has been recorded.

    Args:
        run_dir: Directory the PNG is written into.

    Any failure is printed and swallowed; the figure is always closed.
    """
    fig = None
    try:
        history = METRICS.reward_history
        if not history:
            return

        BG, TEXT, BLUE = '#0d1117', '#e6edf3', '#58a6ff'
        # Each history entry is unpacked as a (step, reward) pair.
        steps = [s for s, _ in history]
        rewards = [r for _, r in history]

        fig, ax = plt.subplots(figsize=(10, 4), facecolor=BG)
        ax.set_facecolor(BG)
        ax.plot(steps, rewards, color=BLUE, alpha=0.4, linewidth=1, label='Per-step')

        if len(rewards) >= 10:
            # Window grows with run length but never drops below 5.
            w = max(5, len(rewards) // 15)
            sm = []
            for i in range(len(rewards)):
                window = rewards[max(0, i - w):i + 1]
                sm.append(sum(window) / len(window))
            ax.plot(steps, sm, color=BLUE, linewidth=2.5, label=f'Smooth (w={w})')

        recent = rewards[-20:]
        mean_r = sum(recent) / len(recent)
        ax.axhline(mean_r, color='#f78166', linestyle='--', linewidth=1.2,
                   label=f'Recent mean: {mean_r:.3f}')

        ax.set_ylim(0, 1.05)
        ax.set_xlabel('Training Step', color=TEXT, fontsize=10)
        ax.set_ylabel('Mean Reward', color=TEXT, fontsize=10)
        ax.set_title(f'GRPO Reward Curve — {MODEL_NAME.split("/")[-1]} | {NUM_EPOCHS} epochs',
                     color=TEXT, fontsize=11, fontweight='bold')
        ax.tick_params(colors=TEXT)
        for spine in ax.spines.values():
            spine.set_color('#30363d')
        ax.legend(facecolor=BG, labelcolor=TEXT, fontsize=9)

        fig.text(0.5, 0.01,
                 f'Mean reward over {len(steps)} training steps. '
                 f'Higher = better decision quality across {len(METRICS.reward_by_task)} tasks.',
                 ha='center', color='#8b949e', fontsize=7)
        fig.tight_layout(rect=[0, 0.04, 1, 1])

        out = os.path.join(run_dir, 'reward_curve.png')
        fig.savefig(out, dpi=130, bbox_inches='tight', facecolor=BG)
        print(f'[REWARD CURVE] Saved {out}')
    except Exception as e:
        print(f'[REWARD CURVE] skipped: {e}')
    finally:
        # Close explicitly so a failure mid-plot does not leak the figure.
        if fig is not None:
            plt.close(fig)
def main():
    """Run the full GRPO training pipeline end to end.

    Steps, in order:
      1. Log in to the HF Hub when a token is present in the environment.
      2. Create the run directory and point ``_EPISODE_LOG_PATH`` into it.
      3. Health-check the environment, load the 4-bit model and attach LoRA.
      4. Evaluate the untrained model on ``EVAL_TASKS`` plus the oversight task.
      5. Build the prompt dataset by resetting the env for each (task, seed).
      6. Train with ``GRPOTrainer``; ``_StopTraining`` ends training early
         but the remaining steps still run.
      7. Save curves, metrics figures and LoRA adapters; optionally push the
         adapter to the Hub.
      8. Re-evaluate, print the comparison, write ``results.png`` and
         ``training_results.json``, and optionally upload the run directory.

    Hyper-parameters that are also written to ``training_results.json`` are
    held in local variables so the recorded values cannot drift from the
    values actually used.
    """
    global _EPISODE_LOG_PATH

    # Hyper-parameters shared between the config objects and the results JSON.
    grad_accum_steps = 1
    learning_rate = 1e-5
    lora_r, lora_alpha = 16, 32

    # Authenticate with HF Hub if token provided (needed for gated models like Llama-3)
    hf_token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    if hf_token:
        from huggingface_hub import login
        login(token=hf_token, add_to_git_credential=False)
        print('[AUTH] Logged in to HF Hub.')
    else:
        print('[AUTH] No HF_TOKEN set — using public models only (Qwen recommended).')

    # All run artifacts go into this timestamped dir — never overwrite a previous run
    RUN_DIR = _make_run_dir()
    print(f'[RUN] Artifacts → {RUN_DIR}')

    # Point the global episode log path so env_reward_fn can write structured logs
    _EPISODE_LOG_PATH = os.path.join(RUN_DIR, 'episodes.jsonl')
    print(f'[RUN] Episode log → {_EPISODE_LOG_PATH}')

    print(f'[ENV] Checking {ENV_URL}...')
    h = requests.get(f'{ENV_URL}/health', timeout=30).json()
    print(f"[ENV] status={h['status']} tasks={h.get('total_tasks')}")

    print(f'[MODEL] Loading {MODEL_NAME}...')
    # Heavy ML imports are deferred until here, after the cheap env check.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, get_peft_model, TaskType
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map='auto',
        trust_remote_code=True,
    )
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()
    lora_cfg = LoraConfig(
        r=lora_r, lora_alpha=lora_alpha,
        target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
        lora_dropout=0, bias='none',
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()

    # Baseline eval (before training)
    print('\n[BASELINE] Before training:')
    baseline = {}
    for t in EVAL_TASKS:
        s = eval_task(model, tokenizer, t)
        baseline[t] = s
        print(f' {t}: {s:.3f}')
    baseline['oversight_agent'] = eval_oversight(model, tokenizer, seed=99)
    print(f' oversight_agent: {baseline["oversight_agent"]:.3f}')
    print(f' Mean: {sum(baseline.values())/len(baseline):.3f}')

    # Hard/long get more seeds so multi-step sequences see enough variation to learn from
    _SEEDS_PER_DIFF = {'easy': 5, 'medium': 8, 'hard': 20, 'long': 20}
    task_seed_pairs = [
        (tid, s)
        for tid in TRAIN_TASKS
        for s in range(1, _SEEDS_PER_DIFF.get(_TASK_DIFFICULTY.get(tid, 'easy'), 5) + 1)
    ]
    total_prompts = len(task_seed_pairs)
    print(f'\n[DATASET] Building prompts ({total_prompts} total: easy×5 medium×8 hard/long×20 seeds)...')
    rows = []
    for task_id, seed in task_seed_pairs:
        try:
            reset = requests.post(f'{ENV_URL}/reset', json={'task_id': task_id, 'seed': seed}, timeout=20).json()
            obs = reset['observation']
            messages = [{'role': 'system', 'content': SYSTEM_PROMPT},
                        {'role': 'user', 'content': obs_to_prompt(obs)}]
            rows.append({
                'prompt': tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True),
                'task_id': task_id,
                'seed': seed,
            })
        except Exception as e:
            # A single failing reset only drops that prompt, not the run.
            print(f' skip {task_id} seed={seed}: {e}')
    if not rows:
        # Fail with a clear message instead of an opaque trainer error later.
        raise RuntimeError('[DATASET] No training prompts could be built — is the environment reachable?')
    dataset = Dataset.from_list(rows)
    by_diff = {
        d: sum(1 for r in rows if _TASK_DIFFICULTY.get(r['task_id'], 'easy') == d)
        for d in ['easy', 'medium', 'hard', 'long']
    }
    print(f'[DATASET] {len(dataset)} samples — easy:{by_diff["easy"]} medium:{by_diff["medium"]} hard:{by_diff["hard"]} long:{by_diff["long"]} | curriculum: {CURRICULUM.status_line()}')

    # Train
    print(f'\n[TRAIN] {NUM_EPOCHS} epochs | {NUM_GENERATIONS} generations/prompt | {len(dataset)} samples')
    model.train()
    _cuda_ok = torch.cuda.is_available()
    _use_bf16 = _cuda_ok and torch.cuda.get_device_capability()[0] >= 8  # Ampere+ (A100, H100)
    _use_fp16 = _cuda_ok and not _use_bf16
    print(f'[DTYPE] cuda={_cuda_ok} bf16={_use_bf16} fp16={_use_fp16}')

    # per_device_train_batch_size must equal num_generations (TRL GRPO requirement).
    config = GRPOConfig(
        output_dir='./ap_commander_grpo',
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=NUM_GENERATIONS,
        num_generations=NUM_GENERATIONS,
        gradient_accumulation_steps=grad_accum_steps,
        learning_rate=learning_rate,
        max_completion_length=200,
        temperature=0.7,
        beta=0.1,
        bf16=_use_bf16,
        fp16=_use_fp16,
        logging_steps=1,
        save_steps=999,
        report_to='none',
        remove_unused_columns=False,
    )

    from transformers import TrainerCallback

    class _LossCallback(TrainerCallback):
        """Mirror each logged loss/grad-norm into METRICS and flush the live file."""

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs and 'loss' in logs:
                METRICS.loss_history.append((
                    state.global_step,
                    round(float(logs['loss']), 5),
                    round(float(logs.get('grad_norm', 0)), 5),
                ))
                METRICS._flush_live()

    trainer = GRPOTrainer(
        model=model, processing_class=tokenizer,
        reward_funcs=[env_reward_fn, format_reward_fn],
        args=config, train_dataset=dataset,
        callbacks=[_LossCallback()],
    )
    stopped_early = False
    try:
        result = trainer.train()
        print(f'\n[TRAIN] Done. Loss: {result.training_loss:.4f}')
    except _StopTraining:
        # Early stop still falls through to saving and evaluation below.
        stopped_early = True
        print(f'\n[STOP] Early stop at step {METRICS.step}. Saving weights now...')

    # Save curves and metrics figures
    _save_loss_curve(trainer, RUN_DIR)
    _save_reward_curve_png(RUN_DIR)
    METRICS.print_summary()
    METRICS.save_all_metrics_figures(RUN_DIR)

    # Save LoRA adapters (guide point 16: save adapters directly, do NOT merge 4-bit naively)
    adapter_dir = os.path.join(RUN_DIR, 'adapter')
    label = 'early-stop' if stopped_early else 'final'
    print(f'[SAVE] Saving LoRA adapters ({label}) to {adapter_dir}...')
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)

    # Upload adapter to HF Hub — repo is {username}/ap-commander-adapter, auto-detected from token
    if hf_token:
        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)
            username = api.whoami(token=hf_token)['name']
            adapter_repo = f'{username}/ap-commander-adapter'
            api.create_repo(repo_id=adapter_repo, repo_type='model', exist_ok=True, token=hf_token)
            api.upload_folder(
                folder_path=adapter_dir,
                repo_id=adapter_repo,
                repo_type='model',
                # Separate date and model name; they were previously fused.
                commit_message=f'GRPO {datetime.datetime.now().strftime("%Y-%m-%d")} {MODEL_NAME} {NUM_EPOCHS}ep',
            )
            print(f'[SAVE] Adapter pushed to HF Hub: {adapter_repo}')
        except Exception as e:
            print(f'[SAVE] HF Hub upload skipped: {e}')
    else:
        print('[SAVE] HF Hub upload skipped: HF_TOKEN not set')

    # Post-training eval (all tasks + oversight)
    print('\n[POST-EVAL] After training:')
    post = {}
    model.eval()
    for t in EVAL_TASKS:
        s = eval_task(model, tokenizer, t)
        post[t] = s
        print(f' {t}: {s:.3f}')
    post['oversight_agent'] = eval_oversight(model, tokenizer, seed=99)
    print(f' oversight_agent: {post["oversight_agent"]:.3f}')
    print(f' Mean: {sum(post.values())/len(post):.3f}')

    print('\n[COMPARE]')
    for t in list(EVAL_TASKS) + ['oversight_agent']:
        d = post[t] - baseline[t]
        print(f' {t:<35} {baseline[t]:.3f} -> {post[t]:.3f} ({d:+.3f})')

    # ── Before/After comparison figure (results.png — key result for demo) ────
    BG, TEXT, SUBTEXT = '#0d1117', '#e6edf3', '#8b949e'
    PANEL, GRID = '#161b22', '#21262d'
    fmt_rate = sum(METRICS.format_scores) / max(1, len(METRICS.format_scores))
    eval_tasks_sorted = sorted(
        EVAL_TASKS,
        key=lambda t: (_DIFFICULTY_ORDER.index(_TASK_DIFFICULTY.get(t, 'easy')), t)
    )
    DIFF_COLORS = {'easy': '#3fb950', 'medium': '#d29922', 'hard': '#f85149', 'long': '#a371f7'}
    fig = plt.figure(figsize=(18, max(8, len(eval_tasks_sorted) * 0.45 + 2)))
    fig.patch.set_facecolor(BG)
    gs = fig.add_gridspec(1, 2, wspace=0.38)

    # Panel left: before/after horizontal bars
    ax_l = fig.add_subplot(gs[0, 0])
    ax_l.set_facecolor(PANEL)
    yp = np.arange(len(eval_tasks_sorted))
    short = [t.replace('easy_', '').replace('medium_', '').replace('hard_', '').replace('long_', '')
             .replace('_', ' ').title() for t in eval_tasks_sorted]
    bar_h = 0.35
    ax_l.barh(yp - bar_h / 2, [baseline.get(t, 0) for t in eval_tasks_sorted],
              bar_h, label='Before GRPO', color='#f85149', alpha=0.85, edgecolor=BG)
    ax_l.barh(yp + bar_h / 2, [post.get(t, 0) for t in eval_tasks_sorted],
              bar_h, label='After GRPO', color='#3fb950', alpha=0.85, edgecolor=BG)
    ax_l.set_yticks(yp)
    ax_l.set_yticklabels(short, fontsize=8, color=TEXT)
    ax_l.set_xlim(0, 1.15)
    ax_l.axvline(0.5, color=SUBTEXT, linestyle='--', linewidth=1, alpha=0.5)
    # Color-code y-tick labels by difficulty
    for tick_label, t in zip(ax_l.get_yticklabels(), eval_tasks_sorted):
        tick_label.set_color(DIFF_COLORS.get(_TASK_DIFFICULTY.get(t, 'easy'), TEXT))
    ax_l.legend(fontsize=9, facecolor=PANEL, edgecolor='#30363d', labelcolor=TEXT)
    ax_l.set_xlabel('Task Score [0.01 – 0.99]', color=SUBTEXT, fontsize=9)
    ax_l.set_ylabel('Task (color = difficulty tier)', color=SUBTEXT, fontsize=9)
    ax_l.set_title(f'Before vs After GRPO — {NUM_EPOCHS} Epochs', color=TEXT, fontsize=11,
                   fontweight='bold', pad=10)
    ax_l.tick_params(colors=TEXT, labelsize=8)
    for sp in ax_l.spines.values():
        sp.set_color('#30363d')
    ax_l.spines['top'].set_visible(False)
    ax_l.spines['right'].set_visible(False)
    ax_l.xaxis.grid(True, color=GRID, linewidth=0.6, alpha=0.7)
    ax_l.set_axisbelow(True)

    # Panel right: delta (improvement) per task
    ax_r = fig.add_subplot(gs[0, 1])
    ax_r.set_facecolor(PANEL)
    deltas = [post.get(t, 0) - baseline.get(t, 0) for t in eval_tasks_sorted]
    d_colors = ['#3fb950' if d >= 0 else '#f85149' for d in deltas]
    ax_r.barh(yp, deltas, color=d_colors, alpha=0.85, edgecolor=BG)
    ax_r.set_yticks(yp)
    ax_r.set_yticklabels(short, fontsize=8, color=TEXT)
    ax_r.axvline(0, color=SUBTEXT, linewidth=1)
    for i, d in enumerate(deltas):
        # Nudge the label away from the bar end, in the bar's direction.
        ax_r.text(d + 0.005 * np.sign(d + 1e-9), i, f'{d:+.3f}',
                  va='center', color=TEXT, fontsize=7)
    ax_r.set_xlabel('Score Delta (After − Before)', color=SUBTEXT, fontsize=9)
    ax_r.set_ylabel('Task', color=SUBTEXT, fontsize=9)
    ax_r.set_title('GRPO Improvement per Task', color=TEXT, fontsize=11,
                   fontweight='bold', pad=10)
    ax_r.tick_params(colors=TEXT, labelsize=8)
    for sp in ax_r.spines.values():
        sp.set_color('#30363d')
    ax_r.spines['top'].set_visible(False)
    ax_r.spines['right'].set_visible(False)
    ax_r.xaxis.grid(True, color=GRID, linewidth=0.6, alpha=0.7)
    ax_r.set_axisbelow(True)

    # Overall means cover EVAL_TASKS only (the oversight score is excluded).
    n_eval = max(1, len(eval_tasks_sorted))
    mean_before = sum(baseline.get(t, 0) for t in eval_tasks_sorted) / n_eval
    mean_after = sum(post.get(t, 0) for t in eval_tasks_sorted) / n_eval
    fig.suptitle(
        f'AP Commander GRPO — {MODEL_NAME.split("/")[-1]} | {NUM_EPOCHS} epochs | '
        f'{NUM_GENERATIONS} generations | {len(TRAIN_TASKS)} tasks\n'
        # Explicit arrow and signed delta: avoids "0.4500.620" and "+-0.050".
        f'Overall: {mean_before:.3f} → {mean_after:.3f} ({mean_after - mean_before:+.3f}) '
        f'| format={fmt_rate:.1%} | parse_fails={METRICS.parse_failures} '
        f'| {datetime.datetime.now().strftime("%Y-%m-%d")}',
        color=TEXT, fontsize=10, y=1.01
    )
    fig.text(0.5, -0.01,
             'Task colors: green=easy, yellow=medium, red=hard, purple=long-horizon. '
             'Score range [0.01, 0.99] as per AP Commander environment specification.',
             ha='center', color=SUBTEXT, fontsize=8, style='italic')
    results_png = os.path.join(RUN_DIR, 'results.png')
    fig.savefig(results_png, dpi=130, bbox_inches='tight', facecolor=BG)
    plt.close(fig)
    print(f'[DONE] Saved {results_png}')

    # Post-training mean score per difficulty tier (0 when a tier has no evaluated task).
    by_difficulty_post = {}
    for d in _DIFFICULTY_ORDER:
        tier_scores = [post[t] for t, diff in _TASK_DIFFICULTY.items() if diff == d and t in post]
        by_difficulty_post[d] = round(sum(tier_scores) / max(1, len(tier_scores)), 4)

    # Save JSON
    output = {
        'timestamp': datetime.datetime.now().isoformat(),
        'run_dir': RUN_DIR,
        'model': MODEL_NAME,
        'epochs': NUM_EPOCHS,
        'num_generations': NUM_GENERATIONS,
        'per_device_train_batch_size': NUM_GENERATIONS,
        # Recorded from the values actually passed to GRPOConfig / LoraConfig.
        'gradient_accumulation_steps': grad_accum_steps,
        'learning_rate': learning_rate,
        'lora_r': lora_r,
        'lora_alpha': lora_alpha,
        # Actual per-tier seed schedule used to build the dataset.
        'seeds_per_difficulty': dict(_SEEDS_PER_DIFF),
        'total_train_prompts': len(dataset),
        'train_tasks': TRAIN_TASKS,
        'eval_tasks': list(EVAL_TASKS),
        'hardware': 'A10G (HF Spaces)',
        'baseline': baseline,
        'post_training': post,
        'delta': {t: round(post.get(t, 0) - baseline.get(t, 0), 4) for t in EVAL_TASKS},
        'overall_baseline': round(mean_before, 4),
        'overall_post': round(mean_after, 4),
        'overall_delta': round(mean_after - mean_before, 4),
        'episode_log': _EPISODE_LOG_PATH,
        'metrics': {
            'total_reward_calls': METRICS.total_calls,
            'parse_failures': METRICS.parse_failures,
            'env_errors': METRICS.env_errors,
            'format_rate': round(fmt_rate, 4),
            'decision_counts': dict(METRICS.decision_counts),
            'per_task_mean': {t: round(sum(v) / max(1, len(v)), 4)
                              for t, v in METRICS.reward_by_task.items()},
            'mean_episode_length': round(
                sum(METRICS.episode_len_hist) / max(1, len(METRICS.episode_len_hist)), 2),
            'by_difficulty_post': by_difficulty_post,
        },
        'figures': [
            'fig1_reward_curve.png',
            'fig2_difficulty_curves.png',
            'fig3_episode_lengths.png',
            'fig4_format_compliance.png',
            'fig5_decision_distribution.png',
            'fig6_per_task_means.png',
            'results.png',
        ],
    }
    results_json = os.path.join(RUN_DIR, 'training_results.json')
    with open(results_json, 'w') as f:
        json.dump(output, f, indent=2)
    print(f'[DONE] Saved {results_json}')

    # Copy live metrics into run dir as snapshot (best-effort, but say so on failure)
    try:
        import shutil
        shutil.copy('/app/metrics_live.json', os.path.join(RUN_DIR, 'metrics_live.json'))
    except Exception as e:
        print(f'[DONE] metrics_live.json snapshot skipped: {e}')

    # Persist entire run dir to HF Space repo (runs/grpo/MODEL-NEP-DATETIME/)
    # so artifacts survive container restarts and each run is independently addressable
    repo_run_path = RUN_DIR.replace('/app/', '', 1)  # strip /app/ prefix for repo path
    if hf_token:
        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)
            username = api.whoami(token=hf_token)['name']
            training_repo = f'{username}/ap-commander-training'
            api.upload_folder(
                folder_path=RUN_DIR,
                path_in_repo=repo_run_path,
                repo_id=training_repo,
                repo_type='space',
                commit_message=f'Run artifacts: {os.path.basename(RUN_DIR)}',
                # Adapter weights are stored separately and are not uploaded here.
                ignore_patterns=['adapter/*'],
            )
            print(f'[UPLOAD] Run folder → {repo_run_path} in {training_repo}')
        except Exception as e:
            print(f'[UPLOAD] artifact upload failed: {e}')
    else:
        print('[UPLOAD] artifact upload skipped: HF_TOKEN not set')
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()