| """Data collection with DAgger and oracle replay. |
| |
| Implements model episode rollout with replanning and DAgger-style |
| data collection using the BFS oracle and efficiency filter. |
| Supports parallel episode collection via ``ThreadPoolExecutor``. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import logging |
| import os |
| import random |
| import time |
| from concurrent.futures import ThreadPoolExecutor |
| from typing import TYPE_CHECKING |
| from types import SimpleNamespace |
|
|
| import numpy as np |
| import torch |
|
|
| from src.buffer import ReplayBuffer |
| from src.curriculum import DynamicCurriculum, efficiency_filter |
| from src.diffusion.sampling import greedy_sample, remdm_sample |
| from src.envs.minihack_env import collect_oracle_trajectory, make_env |
|
|
| if TYPE_CHECKING: |
| from src.models.denoiser import ModelEMA |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@torch.no_grad()
def run_model_episode(
    model: torch.nn.Module,
    env_id: str,
    cfg: SimpleNamespace,
    device: torch.device | str,
    seed: int | None = None,
    max_steps: int = 500,
    des_file: str | None = None,
    blind_global: bool = False,
    stochastic: bool = False,
) -> dict:
    """Roll out the diffusion model on a single episode.

    Keeps the most recently sampled plan and replans every
    ``cfg.replan_every`` steps, or earlier if the sampled plan is shorter
    than that, so the rollout never indexes past the end of a plan.

    Args:
        model: Denoising model (put into eval mode here).
        env_id: MiniHack registry ID.
        cfg: Config namespace (reads ``replan_every``, ``action_dim`` and,
            optionally, ``physics_aware_sampling``).
        device: Torch device for model inference.
        seed: Optional RNG seed; a random one is drawn when ``None``.
        max_steps: Maximum episode length. ``0`` yields an empty trajectory.
        des_file: Optional ``.des`` file content for custom scenarios.
        blind_global: If ``True``, zero out global map (local-only ablation).
        stochastic: If ``True``, use stochastic ReMDM sampling (evaluation).
            If ``False`` (default), use greedy argmax (DAgger collection).

    Returns:
        Dict with ``"local"`` ``[T,9,9]`` and ``"global"`` ``[T,21,79]``
        (the observation *before* each action, as ``int16``),
        ``"actions"`` ``[T]`` (``int64``), ``"won"`` bool, ``"steps"`` int
        (equal to ``T``), ``"total_reward"`` float and ``"seed"`` int (the
        seed actually used).
    """
    if seed is None:
        seed = random.randint(0, 2**31 - 1)

    env = make_env(env_id, des_file, cfg)
    try:
        (local, glb), _info = env.reset(seed=seed)

        locals_list = [local]
        globals_list = [glb]
        actions_list: list[int] = []
        won = False
        total_reward = 0.0
        plan: torch.Tensor | None = None
        step_in_plan = 0
        # Number of plan steps to execute before sampling a new plan.
        replan_at = cfg.replan_every

        model.eval()
        for _ in range(max_steps):
            if plan is None or step_in_plan >= replan_at:
                local_t = torch.from_numpy(local[np.newaxis]).long().to(device)
                glb_t = torch.from_numpy(glb[np.newaxis]).long().to(device)
                if stochastic:
                    plan = remdm_sample(
                        model, local_t, glb_t, cfg, device,
                        physics_aware=getattr(
                            cfg, "physics_aware_sampling", False,
                        ),
                        blind_global=blind_global,
                    )
                else:
                    plan = greedy_sample(
                        model, local_t, glb_t, cfg, device,
                        blind_global=blind_global,
                    )
                step_in_plan = 0
                # Replan early if the plan is shorter than replan_every;
                # otherwise plan[0, step_in_plan] would raise IndexError.
                replan_at = min(cfg.replan_every, plan.shape[1])

            # Clamp the sampled token into the valid action range.
            action = plan[0, step_in_plan].item()
            action = max(0, min(action, cfg.action_dim - 1))
            actions_list.append(action)
            step_in_plan += 1

            (local, glb), reward, terminated, truncated, info = env.step(
                action,
            )
            total_reward += reward
            locals_list.append(local)
            globals_list.append(glb)

            if info.get("won", False):
                won = True
            if terminated or truncated:
                break
    finally:
        env.close()

    # Stack all T+1 observations, then drop the final post-action one so
    # the arrays align with the T actions. Stacking before slicing keeps
    # the per-frame shape even when T == 0 (np.stack([]) would raise).
    locals_arr = np.stack(locals_list, axis=0)[:-1].astype(np.int16)
    globals_arr = np.stack(globals_list, axis=0)[:-1].astype(np.int16)
    actions_arr = np.array(actions_list, dtype=np.int64)

    return {
        "local": locals_arr,
        "global": globals_arr,
        "actions": actions_arr,
        "won": won,
        "steps": len(actions_list),
        "total_reward": total_reward,
        "seed": seed,
    }
|
|
|
|
def _collect_episode_thread(
    model: torch.nn.Module,
    env_id: str,
    seed: int,
    cfg: SimpleNamespace,
) -> dict | None:
    """Thread worker: run one paired (model + oracle) episode.

    Both NLE (C code) and PyTorch CPU inference release the GIL,
    so true parallelism is achieved with threads. Each call creates its
    own env instance. The caller must make sure ``model`` is not used by
    another thread while this call runs.

    Args:
        model: CPU-resident eval-mode model, reserved for this call.
        env_id: MiniHack environment ID.
        seed: RNG seed for the episode (shared by model and oracle).
        cfg: Config namespace.

    Returns:
        Stats dict with ``"env_id"``, ``"seed"``, ``"model_won"``,
        ``"model_steps"``, ``"oracle_steps"`` (999 when the oracle returns
        no trajectory) and ``"oracle_result"``; or ``None`` on failure.
    """
    try:
        model_result = run_model_episode(
            model, env_id, cfg, "cpu", seed,
        )
        oracle_result = collect_oracle_trajectory(env_id, seed, cfg)
        oracle_steps = (
            len(oracle_result["actions"]) if oracle_result else 999
        )
        return {
            "env_id": env_id,
            "seed": seed,
            "model_won": model_result["won"],
            "model_steps": model_result["steps"],
            "oracle_steps": oracle_steps,
            "oracle_result": oracle_result,
        }
    except Exception:
        # Worker boundary: a single failed episode must not abort the
        # whole batch, so log it with its traceback and report None.
        logger.exception(
            "Thread worker failed for %s seed=%s", env_id, seed,
        )
        return None
|
|
|
|
class DataCollector:
    """DAgger-style data collector.

    Each iteration: sample an environment from the curriculum, run the
    model, run the oracle on the same seed, apply the efficiency filter,
    and optionally add the oracle trajectory to the buffer.

    Three collection paths share the same filtering and bookkeeping
    (``_record_episode``):

    * ``collect_one_iteration``: one episode on ``device``.
    * ``collect_batch_parallel``: episodes on a thread pool of CPU model
      copies (requires ``cfg.num_collection_workers > 0``).
    * ``collect_batch_gpu``: lockstep episodes with batched model
      inference, then oracle rollouts on threads.

    Uses a live reference to the ``ModelEMA`` object so the collector
    always uses the latest EMA weights (synced before each rollout).
    Call ``close()`` (or use the collector as a context manager) to shut
    down the worker thread pool.

    Args:
        ema: EMA tracker holding shadow weights.
        model: Training model (architecture template for the EMA snapshot
            and for the per-thread CPU copies).
        buffer: Replay buffer to populate.
        curriculum: Dynamic environment curriculum.
        cfg: Config namespace.
        device: Torch device.
    """

    # Oracle length reported when the oracle returns no trajectory
    # (the same placeholder value the module's thread worker reports).
    _NO_ORACLE_STEPS = 999

    def __init__(
        self,
        ema: "ModelEMA",
        model: torch.nn.Module,
        buffer: ReplayBuffer,
        curriculum: DynamicCurriculum,
        cfg: SimpleNamespace,
        device: torch.device | str,
    ) -> None:
        self._ema = ema
        self._model_template = model
        # Eval-mode snapshot that receives the EMA shadow weights.
        self.ema_model = ema.make_eval_model(model)
        self.buffer = buffer
        self.curriculum = curriculum
        self.cfg = cfg
        self.device = device
        self._num_workers = getattr(cfg, "num_collection_workers", 0)
        self._last_profile: dict[str, float] = {}
        self._thread_pool: ThreadPoolExecutor | None = None
        self._thread_models: list[torch.nn.Module] = []
        if self._num_workers > 0:
            n = min(self._num_workers, os.cpu_count() or 4)
            self._thread_pool = ThreadPoolExecutor(max_workers=n)
            # One private CPU copy per worker thread; the weights are
            # overwritten from the EMA model in collect_batch_parallel.
            for _ in range(n):
                m = copy.deepcopy(model).cpu()
                m.eval()
                self._thread_models.append(m)

    def close(self) -> None:
        """Shut down the worker thread pool, if any. Safe to call twice."""
        pool, self._thread_pool = self._thread_pool, None
        if pool is not None:
            pool.shutdown(wait=True)

    def __enter__(self) -> DataCollector:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _sync_ema(self) -> None:
        """Copy latest EMA shadow weights into the eval model."""
        self._ema.apply_to(self.ema_model)
        self.ema_model.eval()

    def _record_episode(
        self,
        env_id: str,
        model_won: bool,
        model_steps: int,
        oracle_result: dict | None,
    ) -> dict:
        """Filter one finished episode, update buffer/curriculum, build stats.

        Applies the efficiency filter, adds the oracle trajectory to the
        buffer when accepted, and reports the model outcome to the
        curriculum, in that order.

        Args:
            env_id: Environment the episode ran in.
            model_won: Whether the model won the episode.
            model_steps: Number of model actions taken.
            oracle_result: Oracle trajectory dict, or ``None``.

        Returns:
            Stats dict with ``"env_id"``, ``"model_won"``,
            ``"model_steps"``, ``"oracle_steps"``, ``"added_to_buffer"``.
        """
        oracle_steps = (
            len(oracle_result["actions"]) if oracle_result
            else self._NO_ORACLE_STEPS
        )
        add = efficiency_filter(
            model_won,
            model_steps,
            oracle_steps,
            self.cfg.efficiency_multiplier,
        )
        if add and oracle_result is not None:
            self.buffer.add(oracle_result)
        self.curriculum.update(env_id, model_won)
        return {
            "env_id": env_id,
            "model_won": model_won,
            "model_steps": model_steps,
            "oracle_steps": oracle_steps,
            "added_to_buffer": add and oracle_result is not None,
        }

    def collect_one_iteration(self) -> dict:
        """Run one DAgger collection iteration (single episode).

        Returns:
            Stats dict with ``"env_id"``, ``"model_won"``,
            ``"model_steps"``, ``"oracle_steps"``,
            ``"added_to_buffer"`` keys.
        """
        self._sync_ema()
        env_id = self.curriculum.sample_env()
        seed = random.randint(0, 2**31 - 1)

        model_result = run_model_episode(
            self.ema_model, env_id, self.cfg, self.device, seed,
        )
        # The oracle replays the same seed so both face the same level.
        oracle_result = collect_oracle_trajectory(env_id, seed, self.cfg)

        return self._record_episode(
            env_id, model_result["won"], model_result["steps"], oracle_result,
        )

    def collect_batch_parallel(
        self, n_episodes: int,
    ) -> list[dict]:
        """Collect multiple episodes in parallel using threads.

        Both NLE env calls and PyTorch CPU inference release the GIL,
        enabling true parallelism. Weights are synced from EMA once per
        call. Each in-flight episode checks a CPU model copy out of a free
        list, so no two threads run the same module concurrently.

        Args:
            n_episodes: Number of episodes to collect.

        Returns:
            List of per-episode stats dicts (failed episodes are skipped).

        Raises:
            RuntimeError: If the collector has no worker pool
                (``num_collection_workers <= 0`` or after ``close()``).
        """
        import queue

        if self._thread_pool is None:
            raise RuntimeError(
                "collect_batch_parallel requires num_collection_workers > 0"
            )
        self._sync_ema()

        # Broadcast the (possibly GPU-resident) EMA weights to every CPU copy.
        cpu_sd = {k: v.cpu() for k, v in self.ema_model.state_dict().items()}
        for tm in self._thread_models:
            tm.load_state_dict(cpu_sd)
            tm.eval()

        tasks = [
            (self.curriculum.sample_env(), random.randint(0, 2**31 - 1))
            for _ in range(n_episodes)
        ]

        free_models = queue.SimpleQueue()
        for tm in self._thread_models:
            free_models.put(tm)
        cfg = self.cfg

        def run_with_private_model(env_id: str, seed: int) -> dict | None:
            # Blocks only if every copy is in use; pool size == copy count.
            model = free_models.get()
            try:
                return _collect_episode_thread(model, env_id, seed, cfg)
            finally:
                free_models.put(model)

        futures = [
            self._thread_pool.submit(run_with_private_model, env_id, seed)
            for env_id, seed in tasks
        ]
        results = [f.result() for f in futures]

        stats_list: list[dict] = []
        for res in results:
            if res is None:
                continue
            stats_list.append(self._record_episode(
                res["env_id"],
                res["model_won"],
                res["model_steps"],
                res["oracle_result"],
            ))
        return stats_list

    def collect_batch_gpu(self, n_episodes: int) -> list[dict]:
        """Collect episodes with GPU-batched model inference.

        Runs all model episodes with batched GPU forward passes
        (B=n_episodes instead of B=1), then runs oracle rollouts
        in parallel threads for efficiency filtering.

        Args:
            n_episodes: Number of episodes to collect; ``<= 0`` returns
                an empty list without touching any environment.

        Returns:
            List of per-episode stats dicts.
        """
        if n_episodes <= 0:
            # ThreadPoolExecutor rejects max_workers=0, so bail out early.
            self._last_profile = {}
            return []

        self._sync_ema()
        cfg = self.cfg
        self._last_profile = {}

        tasks = [
            (self.curriculum.sample_env(), random.randint(0, 2**31 - 1))
            for _ in range(n_episodes)
        ]

        t0 = time.perf_counter()
        model_results = self._run_model_episodes_batched(tasks)
        model_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        n_workers = min(len(tasks), os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=n_workers) as pool:
            oracle_futures = [
                pool.submit(collect_oracle_trajectory, env_id, seed, cfg)
                for env_id, seed in tasks
            ]
            oracle_results = [f.result() for f in oracle_futures]
        oracle_time = time.perf_counter() - t0

        stats_list: list[dict] = []
        for (env_id, _seed), m_res, o_res in zip(
            tasks, model_results, oracle_results,
        ):
            stats_list.append(self._record_episode(
                env_id, m_res["won"], m_res["steps"], o_res,
            ))

        self._last_profile["model_rollout_sec"] = model_time
        self._last_profile["oracle_rollout_sec"] = oracle_time
        return stats_list

    @torch.no_grad()
    def _run_model_episodes_batched(
        self,
        tasks: list[tuple[str, int]],
        max_steps: int = 500,
    ) -> list[dict]:
        """Run model episodes with batched GPU forward passes.

        Creates one env per episode, steps them in lockstep, and
        batches all replanning into single GPU forward passes
        (B = number of active envs needing a replan). Every env created
        here is closed, even if creation/reset of a later env fails.

        Args:
            tasks: List of ``(env_id, seed)`` pairs.
            max_steps: Maximum episode length per env.

        Returns:
            List of trajectory dicts matching
            ``run_model_episode`` output format.
        """
        cfg = self.cfg
        device = self.device
        model = self.ema_model
        model.eval()
        n = len(tasks)
        # getattr's default argument is evaluated eagerly, so only read
        # diffusion_steps_eval when the collect-specific value is absent.
        if hasattr(cfg, "diffusion_steps_collect"):
            num_steps = cfg.diffusion_steps_collect
        else:
            num_steps = cfg.diffusion_steps_eval
        cs = cfg.crop_size
        # Never index past the end of a plan, even if replan_every > seq_len.
        replan_at = min(cfg.replan_every, cfg.seq_len)

        cur_local = np.zeros((n, cs, cs), dtype=np.int16)
        cur_global = np.zeros((n, cfg.map_h, cfg.map_w), dtype=np.int16)

        # Per-env trajectory buffers; slot t holds the observation before
        # action t (slot 0 is the reset observation).
        obs_local = np.zeros((n, max_steps + 1, cs, cs), dtype=np.int16)
        obs_global = np.zeros(
            (n, max_steps + 1, cfg.map_h, cfg.map_w), dtype=np.int16,
        )
        act_buf = np.zeros((n, max_steps), dtype=np.int64)

        plans = np.zeros((n, cfg.seq_len), dtype=np.int64)
        step_in_plan = np.zeros(n, dtype=np.int32)
        need_replan = np.ones(n, dtype=bool)
        done = np.zeros(n, dtype=bool)
        won = np.zeros(n, dtype=bool)
        total_reward = np.zeros(n, dtype=np.float64)
        n_steps = np.zeros(n, dtype=np.int32)

        envs: list = []
        reset_time = 0.0
        inference_time = 0.0
        env_step_time = 0.0

        try:
            t_reset = time.perf_counter()
            for i, (env_id, seed) in enumerate(tasks):
                env = make_env(env_id, None, cfg)
                # Register before reset so a failing reset still closes it.
                envs.append(env)
                (local, glb), _ = env.reset(seed=seed)
                cur_local[i] = local
                cur_global[i] = glb
            reset_time = time.perf_counter() - t_reset

            obs_local[:, 0] = cur_local
            obs_global[:, 0] = cur_global

            for _ in range(max_steps):
                # One batched forward pass for every live env due a replan.
                replan_idx = np.where(need_replan & ~done)[0]
                if len(replan_idx) > 0:
                    t0 = time.perf_counter()
                    local_t = torch.from_numpy(
                        cur_local[replan_idx],
                    ).long().to(device)
                    glb_t = torch.from_numpy(
                        cur_global[replan_idx],
                    ).long().to(device)
                    batch_plans = greedy_sample(
                        model, local_t, glb_t, cfg, device,
                        num_steps=num_steps,
                    ).cpu().numpy()
                    plans[replan_idx] = batch_plans
                    step_in_plan[replan_idx] = 0
                    need_replan[replan_idx] = False
                    inference_time += time.perf_counter() - t0

                t0 = time.perf_counter()
                any_active = False
                for i in range(n):
                    if done[i]:
                        continue
                    any_active = True

                    # Clamp the sampled token into the valid action range.
                    action = int(plans[i, step_in_plan[i]])
                    action = max(0, min(action, cfg.action_dim - 1))
                    act_buf[i, n_steps[i]] = action
                    step_in_plan[i] += 1
                    n_steps[i] += 1

                    if step_in_plan[i] >= replan_at:
                        need_replan[i] = True

                    obs, reward, term, trunc, info = envs[i].step(action)
                    local, glb = obs
                    total_reward[i] += reward
                    cur_local[i] = local
                    cur_global[i] = glb
                    obs_local[i, n_steps[i]] = local
                    obs_global[i, n_steps[i]] = glb

                    if info.get("won", False):
                        won[i] = True
                    if term or trunc:
                        done[i] = True
                env_step_time += time.perf_counter() - t0

                if not any_active:
                    break
        finally:
            for env in envs:
                env.close()

        results: list[dict] = []
        for i in range(n):
            T = int(n_steps[i])
            results.append({
                "local": obs_local[i, :T].copy(),
                "global": obs_global[i, :T].copy(),
                "actions": act_buf[i, :T].copy(),
                "won": bool(won[i]),
                "steps": T,
                "total_reward": float(total_reward[i]),
                "seed": tasks[i][1],
            })

        self._last_profile.update({
            "env_reset_sec": reset_time,
            "gpu_inference_sec": inference_time,
            "env_step_sec": env_step_time,
        })
        return results
|
|