AntiAtropos / training /train.py
div18
show script
85b91e2
#!/usr/bin/env python3
"""
train.py — AntiAtropos QLoRA Reward-Based Training (HF Jobs Edition)
=====================================================================
Training loop: generate → evaluate → reward → update → log → checkpoint
This is NOT supervised fine-tuning. The model generates actions, the OpenEnv
environment (running on HF Spaces) evaluates them, and we use the reward
signal to update the policy via REINFORCE/GRPO/RLOO.
Architecture (from training.md):
- GPU = compute only (ephemeral)
- Hub = source of truth (persistent)
- Training = reproducible + resumable
- Metrics = structured + queryable
Usage:
# RECOMMENDED: Launch via HF Jobs (auto-provisions GPU, pushes to Hub):
python training/launch_train.py --run-id my_run
# Or run directly (requires a running AntiAtropos server):
ANTIATROPOS_ENV_URL=http://localhost:8000 \
ANTIATROPOS_HUB_MODEL_REPO=Keshav051/antiatropos-qlora \
python training/train.py --run-id my_run --num-iterations 15
# Override defaults:
python training/train.py --run-id run_007 --num-iterations 500 --num-episodes 6
# See all options:
python training/train.py --help
"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import yaml
# ────────────────────────────────────────────────────────────
# Module path setup — allow imports from training/ package
# ────────────────────────────────────────────────────────────
TRAINING_DIR = Path(__file__).resolve().parent
PROJECT_DIR = TRAINING_DIR.parent
if str(TRAINING_DIR) not in sys.path:
sys.path.insert(0, str(TRAINING_DIR))
if str(PROJECT_DIR) not in sys.path:
sys.path.insert(0, str(PROJECT_DIR))
from model_utils import (
attach_lora,
detect_gpu_tier,
find_latest_checkpoint,
gpu_scaled_config,
load_base_model,
push_adapter_to_hub,
push_to_hub,
save_checkpoint,
)
from openenv_loop import (
OpenEnvClient,
rollout_batch,
rollout_episode,
rollout_heuristic_episode,
)
from eval import evaluate
from plotting import (
generate_all_plots,
push_plots_to_hub,
episodes_to_plot_data,
)
# ────────────────────────────────────────────────────────────
# Config Loading
# ────────────────────────────────────────────────────────────
def load_config(config_path: str) -> Dict[str, Any]:
    """Load the training config from YAML and apply runtime overrides.

    Processing order:
      1. Parse ``config_path`` with ``yaml.safe_load``.
      2. For every environment variable named ``ANTIATROPOS_<KEY>`` whose
         lower-cased ``<key>`` already exists in the config, replace the
         value, coercing the string to the type of the existing value
         (bool, int, float, JSON-decoded list, otherwise kept as ``str``).
         Variables whose key is not present in the YAML are ignored.
      3. Pass the result through ``gpu_scaled_config``.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The final config mapping.

    Raises:
        ValueError: If the YAML root is not a mapping, or an override value
            cannot be converted to the type of the value it replaces.
    """
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # An empty YAML document parses to None; treat it as an empty config
    # instead of failing later on `key in cfg`.
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config root in {config_path} must be a mapping, "
            f"got {type(cfg).__name__}"
        )

    # Collect env var overrides (ANTIATROPOS_<KEY>=<value>)
    prefix = "ANTIATROPOS_"
    env_overrides = {
        name[len(prefix):].lower(): value
        for name, value in os.environ.items()
        if name.startswith(prefix)
    }

    for key, value in env_overrides.items():
        if key not in cfg:
            continue
        orig = cfg[key]
        try:
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(orig, bool):
                cfg[key] = value.lower() in ("true", "1", "yes")
            elif isinstance(orig, int):
                cfg[key] = int(value)
            elif isinstance(orig, float):
                cfg[key] = float(value)
            elif isinstance(orig, list):
                cfg[key] = json.loads(value)
            else:
                cfg[key] = value
        except ValueError as err:  # includes json.JSONDecodeError
            raise ValueError(
                f"Invalid value {value!r} for env override "
                f"{prefix}{key.upper()} (expected {type(orig).__name__})"
            ) from err
        print(f"[config] Env override: {key} = {cfg[key]}")

    # GPU auto-scaling (only if not explicitly overridden)
    cfg = gpu_scaled_config(cfg)
    return cfg
# ────────────────────────────────────────────────────────────
# REINFORCE Loss (PyTorch)
# ────────────────────────────────────────────────────────────
def compute_returns(rewards: List[float], gamma: float) -> List[float]:
    """Compute discounted returns for a reward sequence.

    ``returns[t] = rewards[t] + gamma * returns[t + 1]``, with the return
    after the final step taken as 0.

    Args:
        rewards: Per-step rewards in chronological order.
        gamma: Discount factor applied per step.

    Returns:
        A list the same length as ``rewards`` holding the discounted return
        from each step onward. Empty input yields an empty list.
    """
    returns: List[float] = []
    g = 0.0
    # Accumulate back-to-front with O(1) appends, then reverse once;
    # inserting at index 0 on every step would be quadratic.
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return returns
def reinforce_baseline_loss_fn(
    model,
    tokenizer,
    episodes: List,
    cfg: Dict[str, Any],
) -> torch.Tensor:
    """Compute the REINFORCE-with-baseline loss and accumulate its gradients.

    The function works in three phases:
      1. Collect every (transition, discounted return) pair whose transition
         carries ``input_ids``.
      2. Turn the returns into advantages; when ``advantage_normalize`` is
         set and there is more than one pair, they are standardised across
         ALL pairs (global normalisation).
      3. For each mini-batch of ``loss_batch_size`` pairs: left-pad, run one
         forward pass, build the loss and call ``backward()`` immediately so
         only one mini-batch graph is alive at a time.

    Per mini-batch the loss is
    ``(-(advantage * mean action-token log-prob).mean()
       - entropy_coef * mean action-token entropy) / n_batches``.
    Only tokens at or after ``transition.prompt_len`` (the generated action)
    contribute to the log-prob and entropy terms.

    Args:
        model: Causal LM exposing ``.device`` and returning ``.logits``.
        tokenizer: Provides ``pad_token_id`` / ``eos_token_id`` for padding.
        episodes: Objects with ``.transitions``; each transition provides
            ``reward``, ``input_ids``, ``attention_mask`` and ``prompt_len``.
        cfg: Reads ``reward_gamma``, ``advantage_normalize``,
            ``loss_batch_size``, ``max_seq_length`` and ``entropy_coef``.

    Returns:
        A detached scalar tensor (mean of the per-mini-batch losses) for
        logging. Gradients have already been accumulated into the model
        parameters, so the caller must NOT call ``backward()`` on it.
    """
    import math as _math

    gamma = cfg.get("reward_gamma", 0.99)
    normalize_adv = cfg.get("advantage_normalize", True)
    loss_batch_size = cfg.get("loss_batch_size", 1)
    max_seq_len_cap = cfg.get("max_seq_length", 512)
    ent_coef = cfg.get("entropy_coef", 0.001)
    # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # ── Phase 1: Collect all (transition, return) pairs (CPU) ──────────────
    all_pairs: List[Tuple] = []
    for ep in episodes:
        if not ep.transitions:
            continue
        rewards = [t.reward for t in ep.transitions]
        returns = compute_returns(rewards, gamma)
        for trans, ret in zip(ep.transitions, returns):
            if trans.input_ids is not None:
                all_pairs.append((trans, ret))
    if not all_pairs:
        return torch.tensor(0.0, device=model.device)

    # ── Phase 2: Normalised advantages on CPU (global normalisation) ───────
    raw_returns = torch.tensor([p[1] for p in all_pairs], dtype=torch.float32)
    if normalize_adv and len(raw_returns) > 1:
        advantages = (raw_returns - raw_returns.mean()) / (raw_returns.std() + 1e-8)
    else:
        advantages = raw_returns
    # `advantages` stays on CPU; per-batch slices are moved to the model device.

    # ── Phase 3: One forward/backward per mini-batch ───────────────────────
    n_batches = _math.ceil(len(all_pairs) / loss_batch_size)
    total_loss_val = 0.0
    CHUNK_V = 4096  # vocab columns processed per entropy chunk
    for batch_idx, batch_start in enumerate(range(0, len(all_pairs), loss_batch_size)):
        batch_end = batch_start + loss_batch_size
        batch = all_pairs[batch_start:batch_end]
        batch_advs = advantages[batch_start:batch_end]  # CPU tensor

        # Truncate over-long sequences from the left (the tail holds the
        # action tokens) and build the matching action mask.
        batch_ids, batch_masks, batch_action_masks = [], [], []
        for trans, _ in batch:
            ids = trans.input_ids
            mask = trans.attention_mask
            plen = trans.prompt_len  # prompt length before truncation
            original_len = ids.shape[0]
            if original_len > max_seq_len_cap:
                ids = ids[-max_seq_len_cap:]
                # Left truncation removed this many prompt tokens.
                plen = max(0, plen - (original_len - max_seq_len_cap))
            if mask.shape[0] > max_seq_len_cap:
                mask = mask[-max_seq_len_cap:]
            seq_len = ids.shape[0]
            amask = torch.zeros(seq_len, dtype=torch.long)
            if plen < seq_len:
                amask[plen:] = 1  # action tokens follow the prompt
            batch_ids.append(ids)
            batch_masks.append(mask)
            batch_action_masks.append(amask)

        # Left-pad to a common length within the mini-batch
        max_len = max(ids.shape[0] for ids in batch_ids)
        padded_ids, padded_masks, padded_action_masks = [], [], []
        for ids, mask, amask in zip(batch_ids, batch_masks, batch_action_masks):
            pad_len = max_len - ids.shape[0]
            if pad_len > 0:
                padded_ids.append(torch.cat(
                    [torch.full((pad_len,), pad_id, device=ids.device), ids]))
                padded_masks.append(torch.cat(
                    [torch.zeros(pad_len, device=mask.device, dtype=mask.dtype), mask]))
                # Padding positions are never action tokens.
                padded_action_masks.append(torch.cat(
                    [torch.zeros(pad_len, dtype=torch.long), amask]))
            else:
                padded_ids.append(ids)
                padded_masks.append(mask)
                padded_action_masks.append(amask)
        input_ids = torch.stack(padded_ids)
        attention_mask = torch.stack(padded_masks)

        # ── Forward pass ────────────────────────────────────────────────────
        torch.cuda.empty_cache()
        if torch.cuda.is_available():
            alloc = torch.cuda.memory_allocated() / 1024**3
            free, total_mem = torch.cuda.mem_get_info()
            torch.cuda.reset_peak_memory_stats()
            print(f" [loss_fwd b{batch_idx+1}/{n_batches}] "
                  f"shape={input_ids.shape} alloc={alloc:.2f}GiB "
                  f"free={free/1024**3:.1f}/{total_mem/1024**3:.1f}GiB", flush=True)
        outputs = model(
            input_ids=input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
            use_cache=False,  # no KV-cache needed for a single loss forward
        )
        if torch.cuda.is_available():
            peak = torch.cuda.max_memory_allocated() / 1024**3
            free2, _ = torch.cuda.mem_get_info()
            print(f" [loss_fwd b{batch_idx+1}/{n_batches}] "
                  f"post-fwd peak={peak:.2f}GiB free={free2/1024**3:.1f}GiB", flush=True)

        # ── Per-token NLL via cross_entropy(reduction="none") ───────────────
        # Position j of the shifted tensors predicts token j + 1.
        logits = outputs.logits
        shift_logits = logits[:, :-1, :].contiguous()  # (B, S-1, V)
        shift_labels = input_ids[:, 1:].contiguous()  # (B, S-1)
        shift_labels = shift_labels.to(model.device)
        shift_mask = attention_mask[:, 1:].to(model.device)  # (B, S-1)
        token_nll = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.clamp(min=0).view(-1),
            reduction="none",
        ).view(shift_labels.shape)  # (B, S-1)
        token_nll = token_nll * shift_mask  # zero for padded positions

        # ── Action-only log-probs ───────────────────────────────────────────
        # Sum NLL over generated tokens only: the policy gradient needs
        # log pi(action | prompt), not the likelihood of the prompt itself.
        stacked_action_masks = torch.stack(padded_action_masks)  # (B, S)
        shift_action_mask = stacked_action_masks[:, 1:].to(model.device)  # (B, S-1)
        action_nll = token_nll * shift_action_mask
        seq_log_probs = -(action_nll.sum(dim=1))  # (B,)
        n_action_tokens = shift_action_mask.sum(dim=1).clamp(min=1)  # (B,)

        # ── Chunked vocab entropy ───────────────────────────────────────────
        # logsumexp gives the log-normaliser; -sum(p * log p) is then
        # accumulated CHUNK_V vocab columns at a time so no full-vocab
        # probability tensor is created in one piece.
        log_Z = shift_logits.logsumexp(dim=-1, keepdim=True)  # (B, S-1, 1)
        entropy_per_token = torch.zeros(shift_logits.shape[:2], device=model.device)
        for v_start in range(0, shift_logits.size(-1), CHUNK_V):
            log_p_chunk = shift_logits[:, :, v_start:v_start + CHUNK_V] - log_Z
            p_chunk = log_p_chunk.exp()
            entropy_per_token += -(p_chunk * log_p_chunk).sum(dim=-1)
        del log_Z

        # Drop references to the large logits tensors before backward
        del outputs, logits, shift_logits, token_nll
        torch.cuda.empty_cache()
        if torch.cuda.is_available():
            peak = torch.cuda.max_memory_allocated() / 1024**3
            free2, _ = torch.cuda.mem_get_info()
            print(f" [loss_fwd b{batch_idx+1}/{n_batches}] "
                  f"post-fwd peak={peak:.2f}GiB free={free2/1024**3:.1f}GiB", flush=True)

        # ── Per-mini-batch loss: REINFORCE + entropy bonus ──────────────────
        # Entropy is averaged over the same action-token region as the log-probs.
        n_action_valid = (shift_action_mask * shift_mask).sum(dim=1).clamp(min=1)  # (B,)
        avg_token_entropy = (
            (entropy_per_token * shift_action_mask * shift_mask).sum(dim=1) / n_action_valid
        ).mean()
        batch_advs_gpu = batch_advs.to(model.device)
        # Divide by the action-token count so longer actions do not receive
        # proportionally larger gradients.
        norm_seq_log_probs = seq_log_probs / n_action_tokens  # (B,)
        reinforce_term = -(batch_advs_gpu * norm_seq_log_probs).mean()
        # Log the same length-normalised term that enters the loss.
        print(f" [entropy b{batch_idx+1}/{n_batches}] "
              f"avg_token_entropy={avg_token_entropy.item():.3f}nats "
              f"ent_coef={ent_coef} "
              f"reinforce={reinforce_term.item():.4f} "
              f"ent_bonus={ent_coef * avg_token_entropy.item():.4f}", flush=True)
        batch_loss = (reinforce_term - ent_coef * avg_token_entropy) / n_batches

        # ── Backward immediately so this mini-batch's graph can be freed ────
        batch_loss.backward()
        total_loss_val += batch_loss.item() * n_batches
        del batch_loss, reinforce_term, norm_seq_log_probs, seq_log_probs
        del batch_advs_gpu, avg_token_entropy, entropy_per_token
        torch.cuda.empty_cache()

    # Detached scalar for logging only (requires_grad=False)
    return torch.tensor(total_loss_val / n_batches, device=model.device)
def grpo_loss_fn(
    model,
    tokenizer,
    episodes: List,
    cfg: Dict[str, Any],
) -> torch.Tensor:
    """GRPO (Group Relative Policy Optimization) loss.

    Episodes are grouped by ``(task_id, seed)``. Each episode's discounted
    return from step 0 is standardised against its own group's mean and
    (Bessel-corrected) standard deviation, and that single advantage is
    applied to every transition of the episode. A group with one episode
    therefore gets advantage 0 and contributes no policy-gradient signal.

    The forward/backward scheme matches ``reinforce_baseline_loss_fn``:
    one forward pass and an immediate ``backward()`` per mini-batch, with
    log-probs and entropy restricted to the generated action tokens.

    Args:
        model: Causal LM exposing ``.device`` and returning ``.logits``.
        tokenizer: Provides ``pad_token_id`` / ``eos_token_id`` for padding.
        episodes: Objects with ``task_id``, ``seed`` and ``transitions``;
            each transition provides ``reward``, ``input_ids``,
            ``attention_mask`` and ``prompt_len``.
        cfg: Reads ``reward_gamma``, ``loss_batch_size``, ``max_seq_length``
            and ``entropy_coef``.

    Returns:
        A detached scalar tensor (mean of the per-mini-batch losses) for
        logging. Gradients have already been accumulated into the model
        parameters, so the caller must NOT call ``backward()`` on it.
    """
    import math as _math

    gamma = cfg.get("reward_gamma", 0.99)
    loss_batch_size = cfg.get("loss_batch_size", 1)
    max_seq_len_cap = cfg.get("max_seq_length", 512)
    # `pad_token_id or eos_token_id` would discard a legitimate pad id of 0.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # ── Phase 1: Group episodes by (task_id, seed), compute group advantages ─
    groups: Dict[tuple, List] = {}
    for ep in episodes:
        groups.setdefault((ep.task_id, ep.seed), []).append(ep)

    all_pairs: List[Tuple] = []  # (transition, advantage)
    for key, group in groups.items():
        if len(group) == 1:
            # Group of 1: advantage is 0, so no gradient signal.
            print(f" [grpo] WARNING: group {key} has only 1 episode — "
                  f"num_episodes must be grpo_k × num_tasks", flush=True)
        # Episode-level return: discounted sum from step 0 (0.0 if the
        # episode has no transitions).
        group_returns = []
        for ep in group:
            returns = compute_returns([t.reward for t in ep.transitions], gamma)
            group_returns.append(returns[0] if returns else 0.0)
        group_mean = sum(group_returns) / len(group_returns)
        # Bessel-corrected std; the epsilon keeps the division finite when
        # every return in the group is identical.
        group_std = (sum((r - group_mean) ** 2 for r in group_returns)
                     / max(len(group_returns) - 1, 1)) ** 0.5 + 1e-8
        for ep, ep_return in zip(group, group_returns):
            advantage = (ep_return - group_mean) / group_std
            for trans in ep.transitions:
                if trans.input_ids is not None:
                    all_pairs.append((trans, advantage))
    if not all_pairs:
        return torch.tensor(0.0, device=model.device)

    # ── Phase 2: Advantage tensor on CPU ────────────────────────────────────
    advantages = torch.tensor([p[1] for p in all_pairs], dtype=torch.float32)

    # ── Phase 3: One forward/backward per mini-batch ────────────────────────
    n_batches = _math.ceil(len(all_pairs) / loss_batch_size)
    total_loss_val = 0.0
    ent_coef = cfg.get("entropy_coef", 0.001)
    CHUNK_V = 4096  # vocab columns processed per entropy chunk
    for batch_idx, batch_start in enumerate(range(0, len(all_pairs), loss_batch_size)):
        batch_end = batch_start + loss_batch_size
        batch = all_pairs[batch_start:batch_end]
        batch_advs = advantages[batch_start:batch_end]

        # Truncate from the left (the tail holds the action tokens) and
        # build the matching action mask.
        batch_ids, batch_masks, batch_action_masks = [], [], []
        for trans, _ in batch:
            ids = trans.input_ids
            mask = trans.attention_mask
            plen = trans.prompt_len
            original_len = ids.shape[0]
            if original_len > max_seq_len_cap:
                ids = ids[-max_seq_len_cap:]
                # Left truncation removed this many prompt tokens.
                plen = max(0, plen - (original_len - max_seq_len_cap))
            if mask.shape[0] > max_seq_len_cap:
                mask = mask[-max_seq_len_cap:]
            seq_len = ids.shape[0]
            amask = torch.zeros(seq_len, dtype=torch.long)
            if plen < seq_len:
                amask[plen:] = 1
            batch_ids.append(ids)
            batch_masks.append(mask)
            batch_action_masks.append(amask)

        # Left-pad to a common length within the mini-batch
        max_len = max(ids.shape[0] for ids in batch_ids)
        padded_ids, padded_masks, padded_action_masks = [], [], []
        for ids, mask, amask in zip(batch_ids, batch_masks, batch_action_masks):
            pad_len = max_len - ids.shape[0]
            if pad_len > 0:
                padded_ids.append(torch.cat(
                    [torch.full((pad_len,), pad_id, device=ids.device), ids]))
                padded_masks.append(torch.cat(
                    [torch.zeros(pad_len, device=mask.device, dtype=mask.dtype), mask]))
                padded_action_masks.append(torch.cat(
                    [torch.zeros(pad_len, dtype=torch.long), amask]))
            else:
                padded_ids.append(ids)
                padded_masks.append(mask)
                padded_action_masks.append(amask)
        input_ids = torch.stack(padded_ids)
        attention_mask = torch.stack(padded_masks)

        torch.cuda.empty_cache()
        outputs = model(
            input_ids=input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
            use_cache=False,
        )

        # Per-token NLL; position j of the shifted tensors predicts token j + 1.
        logits = outputs.logits
        shift_logits = logits[:, :-1, :].contiguous()  # (B, S-1, V)
        shift_labels = input_ids[:, 1:].contiguous().to(model.device)
        shift_mask_g = attention_mask[:, 1:].to(model.device)
        token_nll = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.clamp(min=0).view(-1),
            reduction="none",
        ).view(shift_labels.shape)
        token_nll = token_nll * shift_mask_g

        # ── Action-only log-probs ───────────────────────────────────────────
        stacked_action_masks = torch.stack(padded_action_masks)  # (B, S)
        shift_action_mask = stacked_action_masks[:, 1:].to(model.device)  # (B, S-1)
        action_nll = token_nll * shift_action_mask
        seq_log_probs = -(action_nll.sum(dim=1))  # (B,)

        # Chunked entropy: accumulate -sum(p * log p) CHUNK_V columns at a time
        log_Z = shift_logits.logsumexp(dim=-1, keepdim=True)
        entropy_per_token = torch.zeros(shift_logits.shape[:2], device=model.device)
        for v_start in range(0, shift_logits.size(-1), CHUNK_V):
            chunk = shift_logits[:, :, v_start:v_start + CHUNK_V] - log_Z
            p_chunk = chunk.exp()
            entropy_per_token += -(p_chunk * chunk).sum(dim=-1)
        del log_Z, outputs, logits, shift_logits, token_nll
        torch.cuda.empty_cache()

        n_valid = (shift_action_mask * shift_mask_g).sum(dim=1).clamp(min=1)
        avg_entropy = (
            (entropy_per_token * shift_action_mask * shift_mask_g).sum(dim=1) / n_valid
        ).mean()
        batch_advs_gpu = batch_advs.to(model.device)
        # Length-normalised log-probs so longer actions do not dominate.
        n_action_tokens_grpo = shift_action_mask.sum(dim=1).clamp(min=1)  # (B,)
        norm_seq_log_probs = seq_log_probs / n_action_tokens_grpo
        batch_loss = (
            -(batch_advs_gpu * norm_seq_log_probs).mean()
            - ent_coef * avg_entropy
        ) / n_batches

        # Backward immediately so this mini-batch's graph can be freed.
        batch_loss.backward()
        total_loss_val += batch_loss.item() * n_batches
        del batch_loss, seq_log_probs, batch_advs_gpu, avg_entropy, entropy_per_token
        torch.cuda.empty_cache()

    # Detached scalar for logging only (requires_grad=False)
    return torch.tensor(total_loss_val / n_batches, device=model.device)
# ────────────────────────────────────────────────────────────
# Run Files Push (to hub_model_repo/<run_id>/)
# ────────────────────────────────────────────────────────────
def push_run_files_to_hub(
    run_id: str,
    output_dir: Path,
    hub_model_repo: str,
    iteration: int,
) -> None:
    """Mirror the run's metric/log files into ``<run_id>/`` of the model repo.

    Always considered: step_metrics.jsonl, iter_metrics.jsonl, training.log
    and run_info.json. The periodic and final eval results are added when
    they exist on disk. Files missing locally are skipped. A failure on one
    file is printed and does not stop the remaining uploads; a failure to
    import or construct the Hub client is printed and ends the push.
    Does nothing when ``hub_model_repo`` is empty.
    """
    if not hub_model_repo:
        return

    always = ("step_metrics.jsonl", "iter_metrics.jsonl", "training.log", "run_info.json")
    uploads = [(name, f"{run_id}/{name}") for name in always]

    # Eval artefacts are optional: (local relative path, file name in the repo)
    optional = (
        ("eval/eval_results.json", "eval_results.json"),
        ("final_eval/eval_results.json", "final_eval_results.json"),
    )
    for rel_path, remote_name in optional:
        if (output_dir / rel_path).exists():
            uploads.append((rel_path, f"{run_id}/{remote_name}"))

    try:
        from huggingface_hub import HfApi

        hub = HfApi()
        done = []
        for local_name, hub_path in uploads:
            source = output_dir / local_name
            if source.exists():
                try:
                    hub.upload_file(
                        path_or_fileobj=str(source),
                        path_in_repo=hub_path,
                        repo_id=hub_model_repo,
                        repo_type="model",
                        commit_message=f"[{run_id}] iter {iteration}: {local_name}",
                    )
                except Exception as e:
                    print(f" [push] Failed to push {local_name}: {e}")
                else:
                    done.append(local_name)
        if done:
            print(f" [push] \u2192 HF model {hub_model_repo}/{run_id}/: {', '.join(done)}",
                  flush=True)
    except Exception as e:
        print(f"[train] Hub file push failed: {e}")
# ────────────────────────────────────────────────────────────
# Local JSONL Metrics Writers
# ────────────────────────────────────────────────────────────
def write_step_metrics(
    run_id: str,
    iteration: int,
    episode_idx: int,
    task_id: str,
    step: int,
    transition,  # Transition dataclass from openenv_loop
    output_dir: Path,
) -> None:
    """Append one per-step row to ``step_metrics.jsonl`` in ``output_dir``.

    The row records the run/iteration/episode/step identity, the chosen
    action, the reward, the cluster-level metrics found in
    ``transition.obs_dict`` and one flat set of columns per node
    (``<node>_status``, ``<node>_queue``, ...). Metrics missing from the
    observation default to 0 (or an empty status).

    Args:
        run_id: Identifier of the training run.
        iteration: Training iteration the step belongs to.
        episode_idx: Index of the episode within the iteration.
        task_id: Task the episode was run on.
        step: Step number within the episode.
        transition: Object exposing ``obs_dict``, ``action`` (with
            ``action_type``, ``target_node_id``, ``parameter``, ``is_valid``)
            and ``reward``.
        output_dir: Directory containing ``step_metrics.jsonl``.
    """
    from datetime import datetime, timezone

    obs = transition.obs_dict or {}
    action = transition.action
    # Timezone-aware replacement for the deprecated datetime.utcnow();
    # tzinfo is stripped so the "<iso>Z" text format stays unchanged.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    row: Dict[str, Any] = {
        # ── Identity ──────────────────────────────────────────────
        "run_id": run_id,
        "iteration": iteration,
        "episode_idx": episode_idx,
        "task_id": task_id,
        "step": step,
        "ts": ts,
        # ── Action ────────────────────────────────────────────────
        "action_type": action.action_type,
        "target_node": action.target_node_id,
        "parameter": round(action.parameter, 4),
        "is_valid": action.is_valid,
        # ── Reward ────────────────────────────────────────────────
        "reward": round(transition.reward, 6),
        # ── Cluster-level metrics ─────────────────────────────────
        "avg_latency_ms": round(obs.get("average_latency_ms", 0.0), 3),
        "error_rate": round(obs.get("error_rate", 0.0), 6),
        "total_queue_backlog": round(obs.get("total_queue_backlog", 0.0), 4),
        "cost_per_hour": round(obs.get("current_cost_per_hour", 0.0), 4),
        "sla_violations": obs.get("sla_violations", 0),
    }
    # ── Per-node metrics as flat columns ──────────────────────────
    # `or []` / `or ""` tolerate an explicit null for nodes or status.
    for node in obs.get("nodes") or []:
        nid = node.get("node_id", "")
        key = nid.replace("-", "")  # "node-0" -> "node0"
        row[f"{key}_status"] = (node.get("status") or "")[:1]  # first letter only
        row[f"{key}_queue"] = round(node.get("queue_depth", 0.0), 4)
        row[f"{key}_latency"] = round(node.get("latency_ms", 0.0), 2)
        row[f"{key}_inflow"] = round(node.get("incoming_request_rate", 0.0), 2)
        row[f"{key}_outflow"] = round(node.get("outflow_rate", 0.0), 2)
        row[f"{key}_capacity"] = round(node.get("capacity", 0.0), 4)
        row[f"{key}_pending"] = round(node.get("pending_capacity", 0.0), 4)

    path = output_dir / "step_metrics.jsonl"
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(row) + "\n")
def write_iter_metrics(
    run_id: str,
    iteration: int,
    loss: float,
    avg_reward: float,
    grad_norm: float,
    total_invalid: int,
    num_episodes: int,
    iter_time_s: float,
    output_dir: Path,
) -> None:
    """Append one per-iteration row to ``iter_metrics.jsonl`` in ``output_dir``.

    Args:
        run_id: Identifier of the training run.
        iteration: Training iteration index.
        loss: Loss value for the iteration (rounded to 6 places).
        avg_reward: Average reward (rounded to 6 places).
        grad_norm: Gradient norm (rounded to 4 places).
        total_invalid: Count of invalid actions, stored as ``invalid_actions``.
        num_episodes: Number of episodes in the iteration.
        iter_time_s: Iteration duration in seconds (rounded to 2 places).
        output_dir: Directory containing ``iter_metrics.jsonl``.
    """
    from datetime import datetime, timezone

    # Timezone-aware replacement for the deprecated datetime.utcnow();
    # tzinfo is stripped so the "<iso>Z" text format stays unchanged.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    row = {
        "run_id": run_id,
        "iteration": iteration,
        "ts": ts,
        "loss": round(loss, 6),
        "avg_reward": round(avg_reward, 6),
        "grad_norm": round(grad_norm, 4),
        "invalid_actions": total_invalid,
        "num_episodes": num_episodes,
        "iter_time_s": round(iter_time_s, 2),
    }
    path = output_dir / "iter_metrics.jsonl"
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(row) + "\n")
class _TeeLogger:
"""Duplicates writes to both the original stream and a log file.
Activated at the start of train() so that every print() β€” VRAM stats,
step logs, entropy, iteration summaries, tracebacks β€” goes to both
the HF job terminal stream AND a persistent training.log on disk.
"""
def __init__(self, stream, log_path: Path):
self._stream = stream
self._file = open(log_path, "a", buffering=1, encoding="utf-8") # line-buffered
def write(self, data: str) -> None:
self._stream.write(data)
self._file.write(data)
def flush(self) -> None:
self._stream.flush()
self._file.flush()
def fileno(self) -> int:
return self._stream.fileno() # subprocess / os compatibility
def isatty(self) -> bool:
return False
def close(self) -> None:
try:
self._file.flush()
self._file.close()
except Exception:
pass
@property
def original_stream(self):
return self._stream
def _log_vram(where: str) -> None:
"""Print CUDA memory usage at key diagnostic points."""
if not torch.cuda.is_available():
return
free, total = torch.cuda.mem_get_info()
alloc = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
peak = torch.cuda.max_memory_allocated() / (1024 ** 3)
print(f" [VRAM @{where}] "
f"alloc={alloc:6.2f}GiB reserved={reserved:6.2f}GiB "
f"peak={peak:6.2f}GiB free={free/1024**3:.1f}/{total/1024**3:.1f}GiB",
flush=True)
def train(cfg: Dict[str, Any]) -> None:
"""Main training loop."""
# ---- Reproducibility ----
seed = cfg.get("seed", 42)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
run_id = cfg.get("run_id", "exp_001")
# ── Run-specific output directory ───────────────────────────────────────
# Structure: <base_output_dir>/<run_id>/
# checkpoint-0010/ ← saved every checkpoint_interval iters
# checkpoint-0020/
# ...
# final_adapter/ ← saved at end of training
# run_info.json ← written at startup; identifies this run
#
# Using run_id as a subfolder means multiple runs never overwrite each other.
base_output_dir = Path(cfg.get("output_dir", "/workspace/antiatropos_checkpoints"))
output_dir = base_output_dir / run_id
output_dir.mkdir(parents=True, exist_ok=True)
# ── Activate full-run logging to disk ────────────────────────────────────
# Every print() from here on is tee'd to training.log (line-buffered).
# This mirrors the HF job terminal stream to a persistent file so you
# can inspect the full log even after the job completes or crashes.
log_path = output_dir / "training.log"
_orig_stdout = sys.stdout
_orig_stderr = sys.stderr
sys.stdout = _TeeLogger(sys.stdout, log_path)
sys.stderr = _TeeLogger(sys.stderr, log_path)
print(f"[train] Full log: {log_path}")
# Write run manifest so checkpoints are always identifiable
import json as _json
run_info = {
"run_id": run_id,
"started_at": __import__("datetime").datetime.utcnow().isoformat() + "Z",
"config": {k: v for k, v in cfg.items() if not k.startswith("_")},
}
run_info_path = output_dir / "run_info.json"
run_info_path.write_text(_json.dumps(run_info, indent=2, default=str))
print(f"[train] Run directory: {output_dir}")
print(f"[train] Run manifest: {run_info_path}")
hub_model_repo = cfg.get("hub_model_repo", "")
push_to_hub_flag = cfg.get("push_to_hub", True)
# ---- Verify environment ----
env_url = cfg.get("env_url", "https://pranavkk-antiatropos.hf.space")
client = OpenEnvClient(env_url)
if not client.verify():
print("[train] FATAL: Cannot reach environment. Aborting.")
sys.exit(1)
# ---- Load model ----
print("\n[train] Loading model...")
model, tokenizer = load_base_model(cfg)
_log_vram("model_loaded")
# ---- Check for resume ----
start_iteration = 0
ckpt_path = find_latest_checkpoint(hub_model_repo) if hub_model_repo else None
if ckpt_path:
local_ckpt = download_checkpoint(hub_model_repo, ckpt_path)
model = load_from_checkpoint(model, tokenizer, local_ckpt)
try:
start_iteration = int(ckpt_path.split("-")[1])
except (ValueError, IndexError):
start_iteration = 0
print(f"[train] Resuming from iteration {start_iteration}")
else:
model = attach_lora(model, cfg, seed=seed)
# Unsloth's attach_lora already enables gradient checkpointing via
# use_gradient_checkpointing="unsloth" in get_peft_model().
# Do NOT call gradient_checkpointing_enable() again β€” it conflicts
# with Unsloth's custom implementation and can increase VRAM usage.
# ---- Optimizer ----
lr = cfg.get("learning_rate", 2e-4)
weight_decay = cfg.get("weight_decay", 0.01)
optim_name = cfg.get("optim", "adamw_8bit")
optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=lr,
weight_decay=weight_decay,
)
# ---- Loss function ----
loss_type = cfg.get("loss_type", "reinforce_baseline")
loss_fns = {
"reinforce_baseline": reinforce_baseline_loss_fn,
"grpo": grpo_loss_fn,
}
loss_fn = loss_fns.get(loss_type, reinforce_baseline_loss_fn)
print(f"[train] Loss function: {loss_type}")
# ---- Config ----
num_iterations = cfg.get("num_iterations", 500)
num_episodes = cfg.get("num_episodes_per_iteration", 4)
max_steps = cfg.get("max_steps_per_episode", 60)
tasks = cfg.get("tasks", ["task-1", "task-2", "task-3"])
max_grad_norm = cfg.get("max_grad_norm", 1.0)
checkpoint_interval = cfg.get("checkpoint_interval", 10) # default: every 10 iters
eval_interval = cfg.get("eval_interval", 50)
plot_interval = cfg.get("plot_interval", 25)
# ---- Training loop ----
print(f"\n{'='*70}")
print(f"ANTIATROPOS QLORA TRAINING")
print(f"{'='*70}")
print(f" Run ID: {run_id}")
print(f" Loss type: {loss_type}")
print(f" Iterations: {num_iterations}")
print(f" Episodes/iter: {num_episodes}")
print(f" Tasks: {tasks}")
print(f" Max steps: {max_steps}")
print(f" Learning rate: {lr}")
print(f" Hub model: {hub_model_repo or '(not configured)'}")
print(f" Output dir: {output_dir}")
print(f"{'='*70}\n")
# Keep model in eval mode during rollout to minimise VRAM pressure.
# for_training() is called only right before the loss forward pass.
model.eval()
_log_vram("eval_after_attach")
metrics_buffer: List[Dict] = []
eval_metrics_history: List[Dict] = []
recent_episodes_data: List[Dict] = [] # For plotting action distributions
for iteration in range(start_iteration, num_iterations):
iter_start = time.time()
# ---- Collect rollouts (parallel batch) ----
# GRPO requires K episodes per task from the SAME seed so grpo_loss_fn
# can group by (task_id, seed) and compute within-group advantages.
# REINFORCE uses unique seeds per episode for diversity.
if loss_type == "grpo":
k = cfg.get("grpo_k", 2)
# Validate: num_episodes must be k * num_tasks
expected = k * len(tasks)
if num_episodes != expected:
print(f" [grpo] WARNING: num_episodes={num_episodes} β‰  "
f"grpo_k({k}) Γ— num_tasks({len(tasks)})={expected}. "
f"Forcing to {expected}.", flush=True)
num_episodes = expected
# Each task gets k copies with the same per-task seed
task_ids = [tasks[t] for t in range(len(tasks)) for _ in range(k)]
task_seeds = [seed + iteration * 100 + t for t in range(len(tasks))]
seeds = [task_seeds[t] for t in range(len(tasks)) for _ in range(k)]
else:
task_ids = [tasks[ep_idx % len(tasks)] for ep_idx in range(num_episodes)]
seeds = [seed + iteration * 1000 + ep_idx for ep_idx in range(num_episodes)]
_log_vram(f"i{iteration}_pre_rollout")
try:
use_parallel = cfg.get("parallel_episodes", True)
if use_parallel and num_episodes > 1:
episodes = rollout_batch(
env_url, model, tokenizer, task_ids,
max_steps, cfg, seeds,
)
else:
# Fallback: sequential rollout (for debugging)
episodes = []
for ep_idx in range(num_episodes):
task_id = tasks[ep_idx % len(tasks)]
seed_ep = seed + iteration * 1000 + ep_idx
ep = rollout_episode(
client, model, tokenizer, task_id,
max_steps, cfg, seed=seed_ep,
)
episodes.append(ep)
except Exception as e:
print(f" [iter {iteration}] Batch rollout failed: {e}")
continue
# ---- Clear VRAM before loss (generation KV-cache on GPU) ----
torch.cuda.empty_cache()
import gc
gc.collect()
# Move rollout tensors to CPU — loss will move them back in batches
for ep in episodes:
for t in ep.transitions:
if t.input_ids is not None:
t.input_ids = t.input_ids.cpu()
if t.attention_mask is not None:
t.attention_mask = t.attention_mask.cpu()
_log_vram(f"i{iteration}_after_offload")
# ---- Compute loss (standard train mode — base 4-bit stays frozen, only LoRA needs gradients) ----
model.train()
_log_vram(f"i{iteration}_after_train")
loss = loss_fn(model, tokenizer, episodes, cfg)
# ---- Optimizer step (loss_fn already called .backward() per mini-batch) ----
grad_norm = torch.nn.utils.clip_grad_norm_(
filter(lambda p: p.requires_grad, model.parameters()),
max_grad_norm,
)
optimizer.step()
optimizer.zero_grad()
# Clear training intermediates and return to eval for next rollout
torch.cuda.empty_cache()
model.eval()
_log_vram(f"i{iteration}_post_grad")
# ---- Compute iteration metrics ----
avg_reward = sum(ep.avg_reward for ep in episodes) / len(episodes)
total_invalid = sum(ep.num_invalid for ep in episodes)
iter_time = time.time() - iter_start
# ---- Write per-step metrics (one row per episode step) ----
# Done here (post-training) because training may mutate episode objects.
for ep_idx, ep in enumerate(episodes):
for step_idx, t in enumerate(ep.transitions):
write_step_metrics(
run_id=run_id,
iteration=iteration,
episode_idx=ep_idx,
task_id=ep.task_id,
step=step_idx + 1,
transition=t,
output_dir=output_dir,
)
# ---- Write per-iteration metrics ----
_grad_norm_val = (
grad_norm.item() if torch.is_tensor(grad_norm) else float(grad_norm)
)
write_iter_metrics(
run_id=run_id,
iteration=iteration,
loss=loss.item(),
avg_reward=avg_reward,
grad_norm=_grad_norm_val,
total_invalid=total_invalid,
num_episodes=len(episodes),
iter_time_s=iter_time,
output_dir=output_dir,
)
print(f" [iter {iteration:4d}] loss={loss.item():.4f} "
f"avg_reward={avg_reward:.4f} "
f"invalid={total_invalid} "
f"grad_norm={_grad_norm_val:.4f} "
f"time={iter_time:.1f}s")
# Store episode data for plotting (keep recent window)
ep_data = episodes_to_plot_data(episodes)
recent_episodes_data.extend(ep_data)
if len(recent_episodes_data) > 200: # Keep last ~200 episodes
recent_episodes_data = recent_episodes_data[-200:]
# ---- Checkpoint + push run files ----
if (iteration + 1) % checkpoint_interval == 0:
# Pad iteration number so ls sorts correctly: checkpoint-0010, checkpoint-0050, ...
ckpt_name = f"checkpoint-{iteration + 1:04d}"
ckpt_dir = output_dir / ckpt_name
ckpt_dir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(ckpt_dir))
tokenizer.save_pretrained(str(ckpt_dir))
# Write a small metadata file so you know exactly what's in each checkpoint
ckpt_meta = {
"run_id": run_id,
"iteration": iteration + 1,
"avg_reward": avg_reward,
"loss": loss.item(),
"saved_at": __import__("datetime").datetime.utcnow().isoformat() + "Z",
}
(ckpt_dir / "checkpoint_meta.json").write_text(
_json.dumps(ckpt_meta, indent=2)
)
print(f" [ckpt] Saved \u2192 {ckpt_dir} "
f"(reward={avg_reward:.4f} loss={loss.item():.4f})", flush=True)
if push_to_hub_flag and hub_model_repo:
push_to_hub(
str(ckpt_dir),
hub_model_repo,
commit_message=f"[{run_id}] {ckpt_name}",
path_in_repo=f"{run_id}/{ckpt_name}",
)
# Push run files (metrics, logs) alongside checkpoint
push_run_files_to_hub(run_id, output_dir, hub_model_repo, iteration + 1)
# ---- Evaluation ----
if (iteration + 1) % eval_interval == 0:
eval_results = evaluate(
client, model, tokenizer, cfg,
output_dir=str(output_dir / "eval"),
)
eval_row = {
"run_id": run_id,
"step": iteration,
"type": "eval",
}
# Flatten eval results for plotting
for k, v in eval_results.items():
if not isinstance(v, dict):
eval_row[f"eval_{k}"] = v
for tid, tv in eval_results.get("per_task", {}).items():
for mk, mv in tv.items():
eval_row[f"eval_{tid}_{mk}"] = mv
eval_metrics_history.append(eval_row)
# Re-enable training mode
model.train()
# ---- Plotting ----
if (iteration + 1) % plot_interval == 0:
try:
plot_paths = generate_all_plots(
train_metrics=metrics_buffer,
eval_metrics=eval_metrics_history,
episodes_data=recent_episodes_data,
output_dir=str(output_dir),
cfg=cfg,
)
if push_to_hub_flag and hub_model_repo:
push_plots_to_hub(plot_paths, hub_model_repo, iteration, run_id=run_id)
except Exception as e:
print(f" [iter {iteration}] Plotting failed: {e}")
# ────────────────────────────────────────────────────────
# Final save + push
# ────────────────────────────────────────────────────────
print(f"\n{'='*70}")
print(f"TRAINING COMPLETE")
print(f"{'='*70}")
# Save final adapter
final_dir = str(output_dir / "final_adapter")
Path(final_dir).mkdir(parents=True, exist_ok=True)
model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)
print(f"[train] Final adapter saved to {final_dir}")
# Push to Hub — scoped under run_id/final_adapter/ so it never overwrites other runs
if push_to_hub_flag and hub_model_repo:
push_to_hub(
final_dir,
hub_model_repo,
commit_message=f"[{run_id}] final_adapter",
path_in_repo=f"{run_id}/final_adapter",
)
# Final evaluation
final_eval = evaluate(
client, model, tokenizer, cfg,
output_dir=str(output_dir / "final_eval"),
)
# Final plots (full training history)
try:
final_eval_row = {
"run_id": run_id,
"step": num_iterations,
"type": "eval",
}
for k, v in final_eval.items():
if not isinstance(v, dict):
final_eval_row[f"eval_{k}"] = v
for tid, tv in final_eval.get("per_task", {}).items():
for mk, mv in tv.items():
final_eval_row[f"eval_{tid}_{mk}"] = mv
eval_metrics_history.append(final_eval_row)
plot_paths = generate_all_plots(
train_metrics=metrics_buffer,
eval_metrics=eval_metrics_history,
episodes_data=recent_episodes_data,
output_dir=str(output_dir),
cfg=cfg,
)
if push_to_hub_flag and hub_model_repo:
push_plots_to_hub(plot_paths, hub_model_repo, num_iterations, run_id=run_id)
except Exception as e:
print(f"[train] Final plotting failed: {e}")
print(f"\n[train] All done. Final adapter: {final_dir}")
if hub_model_repo:
print(f"[train] Hub repo: https://huggingface.co/{hub_model_repo}")
# ── Final push of all run files
if hub_model_repo:
push_run_files_to_hub(run_id, output_dir, hub_model_repo, num_iterations)
# ── Flush and close the TeeLogger ────────────────────────────────────────
# Restore original stdout/stderr so any code after train() works normally.
print(f"[train] Full training log saved to: {log_path}", flush=True)
if isinstance(sys.stdout, _TeeLogger):
sys.stdout.close()
sys.stdout = _orig_stdout
if isinstance(sys.stderr, _TeeLogger):
sys.stderr.close()
sys.stderr = _orig_stderr
# ────────────────────────────────────────────────────────────
# CLI Entry Point
# ────────────────────────────────────────────────────────────
def main():
    """Parse CLI arguments, layer them over the YAML config, and start training.

    Precedence, lowest to highest: values loaded from the config file, the
    ``--smoke`` preset, then any explicitly supplied command-line flag.
    ``--no-push`` is applied after all of those and disables Hub pushes.
    The resulting config dict is handed to ``train``.
    """
    # Single source of truth for the smoke preset, so the --help text and the
    # banner printed at startup cannot drift away from the values applied.
    smoke_preset = {
        "num_iterations": 10,
        "num_episodes_per_iteration": 2,
        "max_steps_per_episode": 40,
        "eval_interval": 5,
        "checkpoint_interval": 5,
        "plot_interval": 5,
        "push_to_hub": False,
        "eval_episodes": 1,
    }
    smoke_iters = smoke_preset["num_iterations"]
    smoke_episodes = smoke_preset["num_episodes_per_iteration"]
    smoke_steps = smoke_preset["max_steps_per_episode"]
    smoke_every = smoke_preset["eval_interval"]

    parser = argparse.ArgumentParser(
        description="AntiAtropos QLoRA Training — HF Jobs Edition"
    )
    parser.add_argument(
        "--config", type=str,
        default=str(TRAINING_DIR / "config.yaml"),
        help="Path to config.yaml (default: training/config.yaml)",
    )
    # ---- Quick overrides for smoke runs ----
    parser.add_argument("--num-iterations", type=int, default=None,
                        help="Total training iterations (default: from config)")
    parser.add_argument("--num-episodes", type=int, default=None,
                        help="Episodes per iteration (default: from config)")
    parser.add_argument("--max-steps", type=int, default=None,
                        help="Max steps per episode (default: from config)")
    parser.add_argument("--loss-type", type=str, default=None,
                        choices=["reinforce_baseline", "grpo"],
                        help="Loss function type")
    parser.add_argument("--env-mode", type=str, default=None,
                        choices=["simulated", "hybrid", "live"],
                        help="Environment mode (default: from config)")
    parser.add_argument("--eval-interval", type=int, default=None,
                        help="Evaluate every N iterations")
    parser.add_argument("--checkpoint-interval", type=int, default=None,
                        help="Checkpoint every N iterations")
    parser.add_argument("--plot-interval", type=int, default=None,
                        help="Generate plots every N iterations")
    parser.add_argument("--run-id", type=str, default=None,
                        help="Unique run identifier")
    parser.add_argument("--output-dir", type=str, default=None,
                        help="Local output directory")
    parser.add_argument("--no-push", action="store_true",
                        help="Disable all Hub pushes (model + metrics + plots)")
    parser.add_argument("--smoke", action="store_true",
                        help=f"Quick smoke run: {smoke_iters} iters, "
                             f"{smoke_episodes} episodes, {smoke_steps} steps, "
                             f"no push, eval/ckpt/plot every {smoke_every}")
    args = parser.parse_args()

    # Load config
    cfg = load_config(args.config)

    # ---- Smoke run preset ----
    if args.smoke:
        cfg.update(smoke_preset)
        # Only fall back to the smoke defaults when the user gave no value.
        if not args.run_id:
            cfg["run_id"] = "smoke_test"
        if not args.output_dir:
            cfg["output_dir"] = "/tmp/antiatropos_smoke"
        print(f"[SMOKE MODE] {smoke_iters} iters x {smoke_episodes} episodes "
              f"x {smoke_steps} steps — no Hub push")

    # ---- CLI overrides (explicit args beat smoke preset) ----
    # (argparse attribute, config key); a flag left at None keeps the
    # config/preset value untouched.
    cli_overrides = (
        ("num_iterations", "num_iterations"),
        ("num_episodes", "num_episodes_per_iteration"),
        ("max_steps", "max_steps_per_episode"),
        ("loss_type", "loss_type"),
        ("env_mode", "env_mode"),
        ("eval_interval", "eval_interval"),
        ("checkpoint_interval", "checkpoint_interval"),
        ("plot_interval", "plot_interval"),
        ("run_id", "run_id"),
        ("output_dir", "output_dir"),
    )
    for arg_name, cfg_key in cli_overrides:
        value = getattr(args, arg_name)
        if value is not None:
            cfg[cfg_key] = value

    if args.no_push:
        cfg["push_to_hub"] = False
        cfg["hub_model_repo"] = ""

    # Allow HF_TOKEN from env: whichever of the two variable names is set,
    # export the token under both names.
    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN")
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token
        os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token

    train(cfg)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()