"""DAgger online training loop. Orchestrates the full DAgger pipeline: collect data via model + oracle, train on buffer, evaluate periodically, and checkpoint. """ from __future__ import annotations import logging import random import time from pathlib import Path from types import SimpleNamespace import numpy as np import torch import torch.nn as nn import yaml from src.buffer import ReplayBuffer from src.config import make_run_dir from src.diffusion.forward import q_sample from src.diffusion.loss import auxiliary_goal_loss, mdlm_loss from src.diffusion.schedules import get_schedule from src.models.denoiser import ModelEMA, make_model, try_compile from src.planners.collect import DataCollector from src.planners.inference import Evaluator, save_eval_json from src.planners.logging import ( Logger, gpu_memory_mb, reset_gpu_memory_stats, compute_param_norm, compute_param_drift, ) from src.curriculum import DynamicCurriculum from src.envs.minihack_env import collect_oracle_trajectory logger = logging.getLogger(__name__) class Trainer: """Full DAgger training loop. Args: model: Denoising model. ema_model: EMA tracker. optimizer: Torch optimizer. scheduler: Optional LR scheduler. buffer: Replay buffer. collector: DAgger data collector. evaluator: Evaluation runner. log: Centralised logger. cfg: Config namespace. device: Torch device. """ def __init__( self, model: nn.Module, ema_model: ModelEMA, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler | None, buffer: ReplayBuffer, collector: DataCollector, evaluator: Evaluator, log: Logger, cfg: SimpleNamespace, device: torch.device | str, raw_model: nn.Module | None = None, ) -> None: self.model = model # raw_model is the uncompiled model used for eval deep-copies. # When torch.compile is off, raw_model is the same as model. self._raw_model = raw_model if raw_model is not None else model self.ema_model = ema_model self.optimizer = optimizer self.scheduler = scheduler self.buffer = buffer self.collector = collector self.evaluator = evaluator self.log = log self.cfg = cfg self.device = device self._schedule_fn = get_schedule(cfg.noise_schedule) # Snapshot of initial weights for param drift tracking self._init_state = { k: v.clone() for k, v in self._raw_model.state_dict().items() if v.is_floating_point() } # AMP scaler: enabled only when use_amp=true and on CUDA self._use_amp = ( getattr(cfg, "use_amp", False) and str(device).startswith("cuda") ) self._scaler = torch.amp.GradScaler("cuda", enabled=self._use_amp) # ── Main loop ──────────────────────────────────────────────── def train( self, start_iter: int = 0, start_env_steps: int = 0, ) -> None: """Run the DAgger training loop. The budget is ``cfg.total_timesteps`` — total env.step() calls across model + oracle rollouts. Iteration count is derived; it depends on how many env steps each iteration consumes (which in turn depends on episode length and efficiency filter outcomes). Args: start_iter: Iteration index to resume from (for logging). start_env_steps: Cumulative env steps already consumed. """ cfg = self.cfg env_steps_total = start_env_steps iteration = start_iter last_id_eval_step = start_env_steps last_ood_eval_step = start_env_steps last_ckpt_step = start_env_steps while env_steps_total < cfg.total_timesteps: reset_gpu_memory_stats() iter_start = time.perf_counter() # 1. Collect N episodes per iteration n_eps = getattr(cfg, "episodes_per_iteration", 1) num_workers = getattr(cfg, "num_collection_workers", 0) model_wins = 0 added_total = 0 # Accumulators across all n_eps episodes — must be summed, # NOT taken from a single (last) episode, otherwise the # unified env-step budget undercounts by ~n_eps×. model_steps_iter = 0 oracle_steps_iter = 0 last_env_id: str = "" collect_start = time.perf_counter() use_gpu_batch = ( str(self.device).startswith("cuda") and n_eps > 1 ) if use_gpu_batch: # GPU-batched collection (all envs in lockstep) batch_stats = self.collector.collect_batch_gpu(n_eps) for s in batch_stats: model_wins += int(s["model_won"]) added_total += int(s["added_to_buffer"]) model_steps_iter += int(s["model_steps"]) oracle_steps_iter += int(s["oracle_steps"]) last_env_id = s.get("env_id", last_env_id) elif num_workers > 0 and n_eps > 1: # Threaded CPU collection (fallback) batch_stats = self.collector.collect_batch_parallel( n_eps, ) for s in batch_stats: model_wins += int(s["model_won"]) added_total += int(s["added_to_buffer"]) model_steps_iter += int(s["model_steps"]) oracle_steps_iter += int(s["oracle_steps"]) last_env_id = s.get("env_id", last_env_id) else: # Sequential collection (reference behaviour) for _ in range(n_eps): s = self.collector.collect_one_iteration() model_wins += int(s["model_won"]) added_total += int(s["added_to_buffer"]) model_steps_iter += int(s["model_steps"]) oracle_steps_iter += int(s["oracle_steps"]) last_env_id = s.get("env_id", last_env_id) collect_time = time.perf_counter() - collect_start collect_stats = { "env_id": last_env_id, "model_won": model_wins, "added_to_buffer": added_total, "model_steps": model_steps_iter, "oracle_steps": oracle_steps_iter, } # Advance the unified env-step budget. Both model and oracle # rollouts consume real env.step() calls (the oracle rollout # runs in its own env instance in collect_oracle_trajectory), # so both contribute to the budget. iter_env_steps = model_steps_iter + oracle_steps_iter env_steps_total += iter_env_steps # 2. Gradient steps (EMA updated after each step) self.model.train() step_metrics: list[dict[str, float]] = [] train_start = time.perf_counter() for _ in range(cfg.grad_steps_per_iteration): m = self._train_step() step_metrics.append(m) self.ema_model.update(self._raw_model) train_time = time.perf_counter() - train_start iter_time = time.perf_counter() - iter_start # 4. Log n_steps = len(step_metrics) or 1 avg_loss = sum(m["loss"] for m in step_metrics) / n_steps avg_loss_diff = sum(m["loss_diff"] for m in step_metrics) / n_steps avg_loss_aux = sum(m["loss_aux"] for m in step_metrics) / n_steps avg_grad_norm = sum(m["grad_norm"] for m in step_metrics) / n_steps current_lr = ( self.scheduler.get_last_lr()[0] if self.scheduler is not None else self.cfg.dagger_lr ) # Global gate value (how open is the global stream) gate_val = None if hasattr(self._raw_model, "global_gate"): gate_val = torch.sigmoid( self._raw_model.global_gate ).item() # Buffer online fraction buf_total = len(self.buffer) buf_online_frac = ( (buf_total - self.buffer.offline_size) / max(buf_total, 1) if hasattr(self.buffer, "offline_size") else 0.0 ) # Samples per second total_samples = n_steps * cfg.dagger_batch_size samples_per_sec = total_samples / max(train_time, 1e-6) # Env steps per second (uses the iter-summed total, not a # single episode — same bug class as the env-step budget). env_steps_per_sec = iter_env_steps / max(collect_time, 1e-6) metrics = { "diffusion/loss": avg_loss, "diffusion/loss_diff": avg_loss_diff, "diffusion/loss_aux": avg_loss_aux, "train/buffer_size": buf_total, "train/buffer_online_frac": buf_online_frac, "train/model_won": int(collect_stats["model_won"]), "train/added_to_buffer": int( collect_stats["added_to_buffer"] ), "train/episodes_collected": n_eps, "train/model_steps": collect_stats["model_steps"], "train/oracle_steps": collect_stats["oracle_steps"], "train/efficiency_ratio": ( collect_stats["model_steps"] / max(collect_stats["oracle_steps"], 1) ), "train/lr": current_lr, "train/grad_norm": avg_grad_norm, "train/env_steps": env_steps_total, "train/progress": env_steps_total / cfg.total_timesteps, "speed/iter_time_sec": iter_time, "speed/collect_time_sec": collect_time, "speed/train_step_time_sec": train_time, "speed/samples_per_sec": samples_per_sec, "speed/env_steps_per_sec": env_steps_per_sec, "speed/gpu_memory_mb": gpu_memory_mb(), # Keep old perf/ keys for backward compat "perf/iter_time_s": iter_time, "perf/collect_time_s": collect_time, "perf/train_time_s": train_time, "perf/grad_steps_per_sec": ( cfg.grad_steps_per_iteration / max(train_time, 1e-6) ), } if gate_val is not None: metrics["train/global_gate"] = gate_val metrics["model/ema_gate_value"] = gate_val # Model health (every 10 iters to avoid overhead) if iteration % 10 == 0: metrics["model/param_norm"] = compute_param_norm( self._raw_model ) metrics["model/param_drift_from_init"] = compute_param_drift( self._raw_model, self._init_state ) # Profile breakdown from GPU-batched collection _profile = getattr(self.collector, "_last_profile", {}) for _pk, _pv in _profile.items(): metrics[f"profile/{_pk}"] = _pv self.log.log(metrics, step=iteration) # 5. ID eval — triggered when env-step delta crosses threshold if ( cfg.id_eval_every_timesteps > 0 and env_steps_total - last_id_eval_step >= cfg.id_eval_every_timesteps ): eval_model = self.ema_model.make_eval_model(self._raw_model) results = self.evaluator.evaluate( cfg.id_envs, eval_model, cfg.eval_episodes_per_env, cfg, self.device, ) self.log.log_eval(results, step=iteration, prefix="eval_id") mean_id_wr = float(np.mean( [s["win_rate"] for s in results.values()] )) if results else 0.0 self.log.log( { "eval_id/mean_win_rate": mean_id_wr, **{ f"curriculum/{env_id}/win_rate": self.collector.curriculum.win_rate(env_id) for env_id in self.cfg.id_envs }, }, step=iteration, ) last_id_eval_step = env_steps_total # 6. OOD eval — env-step-triggered if ( cfg.ood_eval_every_timesteps > 0 and env_steps_total - last_ood_eval_step >= cfg.ood_eval_every_timesteps ): eval_model = self.ema_model.make_eval_model(self._raw_model) results = self.evaluator.evaluate( cfg.ood_envs, eval_model, cfg.eval_episodes_per_env, cfg, self.device, ) self.log.log_eval(results, step=iteration, prefix="eval_ood") mean_ood_wr = float(np.mean( [s["win_rate"] for s in results.values()] )) if results else 0.0 self.log.log( {"eval_ood/mean_win_rate": mean_ood_wr}, step=iteration, ) last_ood_eval_step = env_steps_total # 7. Checkpoint — env-step-triggered if ( cfg.checkpoint_every_timesteps > 0 and env_steps_total - last_ckpt_step >= cfg.checkpoint_every_timesteps ): self.save_checkpoint(iteration, env_steps_total) last_ckpt_step = env_steps_total iteration += 1 # Final checkpoint if cfg.save_policy: self.save_checkpoint(iteration, env_steps_total) # ── Single gradient step ───────────────────────────────────── def _train_step(self) -> dict[str, float]: """One gradient step on a buffer sample. Uses AMP (mixed precision) when ``cfg.use_amp`` is ``True`` and training on CUDA. Returns: Dict with ``"loss"``, ``"loss_diff"``, ``"loss_aux"``, and ``"grad_norm"`` scalars. """ cfg = self.cfg batch = self.buffer.sample(cfg.dagger_batch_size) if batch is None: return {"loss": 0.0, "loss_diff": 0.0, "loss_aux": 0.0, "grad_norm": 0.0} local_np, global_np, actions_np = batch local_t = torch.from_numpy(local_np).long().to(self.device) global_t = torch.from_numpy(global_np).long().to(self.device) actions_t = torch.from_numpy(actions_np).long().to(self.device) B = actions_t.shape[0] t = torch.rand(B, device=self.device).clamp(1e-5, 1.0 - 1e-5) zt = q_sample( actions_t, t, cfg.mask_token, cfg.pad_token, self._schedule_fn, ) t_discrete = (t * cfg.num_diffusion_steps).long().clamp( 0, cfg.num_diffusion_steps - 1, ) self.optimizer.zero_grad() with torch.amp.autocast("cuda", enabled=self._use_amp): out = self.model(local_t, global_t, zt, t_discrete) loss_diff = mdlm_loss( out["actions"], actions_t, zt, t, cfg.mask_token, cfg.pad_token, self._schedule_fn, weight_clip=cfg.loss_weight_clip, label_smoothing=cfg.label_smoothing, use_importance_weighting=cfg.use_importance_weighting, ) loss_aux = torch.tensor(0.0, device=self.device) if "goal_pred" in out: loss_aux = auxiliary_goal_loss(out["goal_pred"], global_t) loss = loss_diff + cfg.aux_loss_weight * loss_aux self._scaler.scale(loss).backward() self._scaler.unscale_(self.optimizer) grad_norm = nn.utils.clip_grad_norm_( self.model.parameters(), cfg.dagger_grad_clip, ) self._scaler.step(self.optimizer) self._scaler.update() if self.scheduler is not None: self.scheduler.step() return { "loss": loss.item(), "loss_diff": loss_diff.item(), "loss_aux": loss_aux.item(), "grad_norm": grad_norm.item(), } # ── Checkpointing ──────────────────────────────────────────── def save_checkpoint( self, iteration: int, env_steps: int, ) -> None: """Save a training checkpoint. Args: iteration: Current iteration number (for filename + metadata). env_steps: Cumulative env.step() count consumed so far. """ ckpt_dir = Path(self.cfg.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) path = ckpt_dir / f"iter{iteration}.pth" # Capture W&B run ID for seamless resumption wandb_run_id: str | None = None if self.log._use_wandb and self.log._run is not None: wandb_run_id = self.log._run.id state = { "model_state_dict": self._raw_model.state_dict(), "ema_state_dict": self.ema_model.state_dict(), "optimizer_state_dict": self.optimizer.state_dict(), "scheduler_state_dict": ( self.scheduler.state_dict() if self.scheduler is not None else None ), "curriculum_state": self.collector.curriculum.state_dict(), "iteration": iteration, "env_steps": env_steps, "wandb_run_id": wandb_run_id, "rng_states": { "torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": random.getstate(), }, } try: torch.save(state, path) logger.info(f"Checkpoint saved: {path}") except Exception: logger.error( f"Failed to save checkpoint to {path}", exc_info=True, ) # Save config snapshot alongside checkpoint config_path = ckpt_dir / f"config_iter{iteration}.yaml" try: cfg_dict = { k: v for k, v in vars(self.cfg).items() if not k.startswith("_") } with open(config_path, "w") as f: yaml.dump(cfg_dict, f, default_flow_style=False) except Exception: logger.error("Failed to save config snapshot", exc_info=True) config_path = None # Run eval at checkpoint and save JSON try: eval_model = self.ema_model.make_eval_model(self._raw_model) id_results = self.evaluator.evaluate( self.cfg.id_envs, eval_model, self.cfg.checkpoint_eval_episodes, self.cfg, self.device, ) ood_results = self.evaluator.evaluate( self.cfg.ood_envs, eval_model, self.cfg.checkpoint_eval_episodes, self.cfg, self.device, ) id_winrate = float(np.mean( [s["win_rate"] for s in id_results.values()] )) if id_results else 0.0 ood_winrate = float(np.mean( [s["win_rate"] for s in ood_results.values()] )) if ood_results else 0.0 current_lr = ( self.scheduler.get_last_lr()[0] if self.scheduler is not None else self.cfg.dagger_lr ) training_meta = { "iteration": iteration, "env_steps": env_steps, "total_timesteps": self.cfg.total_timesteps, "lr": current_lr, "dagger_batch_size": self.cfg.dagger_batch_size, "aux_loss_weight": self.cfg.aux_loss_weight, "buffer_size": len(self.buffer), "buffer_capacity": self.cfg.buffer_capacity, "ema_decay": self.cfg.ema_decay, "grad_steps_per_iteration": self.cfg.grad_steps_per_iteration, "episodes_per_iteration": getattr( self.cfg, "episodes_per_iteration", 1 ), "id_winrate": id_winrate, "ood_winrate": ood_winrate, "per_env_id": { env_id: { "win_rate": s["win_rate"], "wins": s.get("wins", 0), "avg_reward": s["avg_reward"], "avg_steps": s["avg_steps"], "n_episodes": s["n_episodes"], } for env_id, s in id_results.items() }, "per_env_ood": { env_id: { "win_rate": s["win_rate"], "wins": s.get("wins", 0), "avg_reward": s["avg_reward"], "avg_steps": s["avg_steps"], "n_episodes": s["n_episodes"], } for env_id, s in ood_results.items() }, } json_path = ckpt_dir / f"eval_iter{iteration}.json" save_eval_json( {"id": id_results, "ood": ood_results}, str(json_path), metadata=training_meta, ) # W&B checkpoint log — per-env step metrics + aggregates self.log.log_eval( id_results, step=iteration, prefix="ckpt_eval_id", ) self.log.log_eval( ood_results, step=iteration, prefix="ckpt_eval_ood", ) self.log.log( { "ckpt_eval/id_winrate": id_winrate, "ckpt_eval/ood_winrate": ood_winrate, }, step=iteration, ) self.log.log_summary({ f"ckpt_{iteration}/id_winrate": id_winrate, f"ckpt_{iteration}/ood_winrate": ood_winrate, }) except Exception: logger.error("Checkpoint eval failed", exc_info=True) # HuggingFace Hub upload (no-op if HF_TOKEN or hub_run_id not set) try: from scripts.hf_upload import maybe_upload_checkpoint maybe_upload_checkpoint( str(ckpt_dir), getattr(self.cfg, "hub_run_id", None), getattr(self.cfg, "hub_repo_id", None), ) except Exception: logger.error("HF Hub upload failed", exc_info=True) # W&B artifact upload self.log.log_checkpoint_artifact( checkpoint_path=str(path), config_path=str(config_path) if config_path else None, iteration=iteration, metadata={ "iteration": iteration, "buffer_size": len(self.buffer), }, ) def load_checkpoint(self, path: str) -> tuple[int, int]: """Load a training checkpoint. Args: path: Path to ``.pth`` checkpoint file. Returns: ``(start_iter, start_env_steps)`` — the iteration and cumulative env-step count to resume from. """ ckpt = torch.load( path, map_location=self.device, weights_only=False, ) self._raw_model.load_state_dict(ckpt["model_state_dict"]) self.ema_model.load_state_dict(ckpt["ema_state_dict"]) self.optimizer.load_state_dict(ckpt["optimizer_state_dict"]) if ( self.scheduler is not None and ckpt.get("scheduler_state_dict") is not None ): self.scheduler.load_state_dict(ckpt["scheduler_state_dict"]) if "curriculum_state" in ckpt: self.collector.curriculum.load_state_dict( ckpt["curriculum_state"], ) # Restore RNG states (best-effort) rng = ckpt.get("rng_states", {}) try: if "torch" in rng: torch.set_rng_state(rng["torch"]) if "numpy" in rng: np.random.set_state(rng["numpy"]) if "python" in rng: random.setstate(rng["python"]) except Exception: logger.warning( "RNG state restore failed; continuing with fresh state", ) iteration = ckpt.get("iteration", 0) env_steps = ckpt.get("env_steps", 0) resume_from = iteration + 1 logger.info( f"Resumed from checkpoint: {path} (iter {iteration}, " f"env_steps={env_steps}), starting at iter {resume_from}" ) return resume_from, env_steps def run_dagger( cfg: SimpleNamespace, checkpoint_path: str | None, no_warm_start: bool, ) -> None: """DAgger online training loop.""" make_run_dir(cfg, tag="dagger") device = cfg.device logger.info(f"DAgger training on {device}") raw_model = make_model(cfg).to(device) # EMA and eval always use the raw (uncompiled) model — deep-copying # a compiled model breaks FX tracing. ema = ModelEMA(raw_model, decay=cfg.ema_decay) # torch.compile: wrap for training only; shares parameters with raw_model model = try_compile(raw_model, cfg) optimizer = torch.optim.AdamW( raw_model.parameters(), lr=cfg.dagger_lr, weight_decay=cfg.weight_decay, ) buffer = ReplayBuffer(cfg.buffer_capacity, cfg.seq_len, cfg.pad_token) curriculum = DynamicCurriculum( cfg.id_envs, cfg.curriculum_queue_size, cfg.curriculum_preseed, ) # Seed buffer with some oracle data for i, env_id in enumerate(cfg.id_envs): for s in range(3): traj = collect_oracle_trajectory(env_id, seed=i * 100 + s, cfg=cfg) if traj is not None: buffer.add(traj) logger.info(f"Buffer seeded with {len(buffer)} windows") # If resuming, extract W&B run ID from checkpoint before Logger init # so the same W&B run is continued (curve continuity). if checkpoint_path and not no_warm_start: resume_id = getattr(cfg, "wandb_resume_id", None) if not resume_id: ckpt_peek = torch.load( checkpoint_path, map_location="cpu", weights_only=False, ) saved_id = ckpt_peek.get("wandb_run_id") if saved_id: cfg.wandb_resume_id = saved_id logger.info( f"W&B run ID from checkpoint: {saved_id}" ) del ckpt_peek # DataCollector uses raw_model for eval copies (not compiled) collector = DataCollector(ema, raw_model, buffer, curriculum, cfg, device) evaluator = Evaluator() log = Logger(cfg) trainer = Trainer( model, ema, optimizer, None, buffer, collector, evaluator, log, cfg, device, raw_model=raw_model, ) start_iter = 0 start_env_steps = 0 if checkpoint_path and not no_warm_start: start_iter, start_env_steps = trainer.load_checkpoint( checkpoint_path, ) trainer.train( start_iter=start_iter, start_env_steps=start_env_steps, ) log.finish()