| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import torch |
| from torch.nn.utils import clip_grad_norm_ |
| from tqdm import trange |
|
|
| from minidreamer.config import ensure_run_dirs, load_config, merge_dicts, save_config |
| from minidreamer.data.collect_random import collect_bootstrap_dataset |
| from minidreamer.data.replay_buffer import ReplayBuffer |
| from minidreamer.evaluation import evaluate_random_policy, evaluate_world_model |
| from minidreamer.envs.make_env import make_env_from_config |
| from minidreamer.models.world_model import WorldModel |
| from minidreamer.planning.cem import DiscreteCEMPlanner |
| from minidreamer.planning.evaluate_planner import evaluate_planner |
| from minidreamer.serialization import save_world_model_checkpoint |
| from minidreamer.utils.common import get_device, seed_everything, write_json, write_jsonl |
|
|
|
|
def train_world_model_updates(
    model: WorldModel,
    replay: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    config: dict[str, Any],
    num_updates: int,
    device: torch.device,
) -> list[dict[str, float]]:
    """Run ``num_updates`` gradient steps on sequences sampled from the train split.

    Each step samples a batch with ``replay.sample_sequences(split="train")``,
    computes ``model.compute_losses`` and backpropagates the ``"loss"`` entry.
    Gradients are clipped to ``config["training"]["grad_clip_norm"]`` (default
    100.0) before the optimizer step.

    Returns:
        One dict per update holding the scalar values of ``loss``,
        ``reward_loss``, ``done_loss``, ``kl_loss`` and ``recon_loss``.
        The list is empty when ``num_updates <= 0``.
    """
    if num_updates <= 0:
        return []

    # Loss entries mirrored into the returned logs and the progress-bar postfix.
    logged_terms = ("loss", "reward_loss", "done_loss", "kl_loss", "recon_loss")

    model.train()
    history: list[dict[str, float]] = []
    bar = trange(num_updates, desc="world-model-updates", leave=False)
    for _step in bar:
        sampled = replay.sample_sequences(split="train")
        batch = ReplayBuffer.batch_to_torch(sampled, device=device)
        loss_terms = model.compute_losses(batch, config)

        optimizer.zero_grad(set_to_none=True)
        total_loss = loss_terms["loss"]
        total_loss.backward()
        clip_grad_norm_(model.parameters(), float(config["training"].get("grad_clip_norm", 100.0)))
        optimizer.step()

        row = {term: float(loss_terms[term].detach().cpu()) for term in logged_terms}
        history.append(row)
        bar.set_postfix({term: f"{amount:.3f}" for term, amount in row.items()})
    return history
|
|
|
|
def optimizer_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every tensor stored in ``optimizer.state`` onto ``device``, in place.

    Non-tensor state entries are left untouched.
    """
    for per_param_state in optimizer.state.values():
        tensor_names = [name for name, item in per_param_state.items() if torch.is_tensor(item)]
        for name in tensor_names:
            per_param_state[name] = per_param_state[name].to(device)
|
|
|
|
def load_training_state(
    checkpoint_path: str | Path,
    config: dict[str, Any],
    action_dim: int,
    device: torch.device,
) -> tuple[dict[str, Any], WorldModel, torch.optim.Optimizer, dict[str, Any]]:
    """Rebuild config, model, optimizer and metadata from a training checkpoint.

    The checkpoint's stored config is merged with ``config`` via ``merge_dicts``
    to form the resolved config. The model is built from the resolved config and
    its weights are restored. An Adam optimizer is created and, when present,
    its saved state is loaded.

    Args:
        checkpoint_path: Path to a checkpoint payload containing ``"config"`` and
            ``"model_state"``, plus optional ``"optimizer_state"`` and ``"metadata"``.
        config: Config passed for this run, merged on top of the checkpoint config.
        action_dim: Number of discrete actions used to build the model.
        device: Device for the model and optimizer state.

    Returns:
        ``(resolved_config, model, optimizer, metadata)``. ``metadata`` is ``{}``
        when the checkpoint has none.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects; only resume
    # from checkpoints you trust (e.g. ones this script wrote).
    payload = torch.load(checkpoint_path, map_location=device, weights_only=False)
    resolved_config = merge_dicts(payload["config"], config)
    model = WorldModel.from_config(resolved_config, action_dim=action_dim).to(device)
    model.load_state_dict(payload["model_state"])

    lr = float(resolved_config["training"]["lr"])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer_state = payload.get("optimizer_state")
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)
        # load_state_dict restores the saved param_groups, including their lr.
        # Re-apply the resolved lr so training matches resolved_config.yaml.
        for group in optimizer.param_groups:
            group["lr"] = lr
    optimizer_to_device(optimizer, device)
    return resolved_config, model, optimizer, payload.get("metadata") or {}
|
|
|
|
def find_existing_run_artifacts(base_dir: str | Path) -> list[Path]:
    """Return paths showing that ``base_dir`` already holds a previous run.

    Well-known artifact files (run summary, train/eval metrics, latest
    checkpoint, replay metadata) are checked first; if any exist, exactly those
    are returned. Otherwise at most one entry from each of the ``checkpoints``,
    ``metrics`` and ``replay`` subdirectories is returned as a sample.

    Returns:
        An empty list when ``base_dir`` does not exist or contains none of the
        above.
    """
    base = Path(base_dir)
    if not base.exists():
        return []

    artifact_files = [
        base / "metrics" / "run_summary.json",
        base / "metrics" / "train_metrics.jsonl",
        base / "metrics" / "eval_metrics.jsonl",
        base / "checkpoints" / "world_model_latest.pt",
        base / "replay" / "metadata.json",
    ]
    found = [path for path in artifact_files if path.exists()]
    if found:
        return found

    # Fall back to sampling one entry per standard subdirectory. ``is_dir`` (not
    # ``exists``) prevents NotADirectoryError from ``iterdir`` when a stray file
    # occupies one of these names.
    for subdir_name in ("checkpoints", "metrics", "replay"):
        subdir = base / subdir_name
        if subdir.is_dir():
            first_child = next(iter(subdir.iterdir()), None)
            if first_child is not None:
                found.append(first_child)
    return found
|
|
|
|
def collect_planner_steps(
    env,
    replay: ReplayBuffer,
    model: WorldModel,
    planner: DiscreteCEMPlanner,
    num_steps: int,
    random_action_fraction: float,
    rng: np.random.Generator,
) -> dict[str, int]:
    """Roll out whole episodes with the planner and append them to ``replay``.

    Episodes are collected until at least ``num_steps`` environment steps have
    been taken. Each episode always runs to termination or truncation, so the
    reported ``env_steps`` may exceed ``num_steps``. On every step, with
    probability ``random_action_fraction`` (drawn from ``rng``), a random action
    from ``env.action_space.sample()`` replaces the planner's action. The
    model's posterior state is updated after each non-final step, under
    ``torch.no_grad()`` with ``sample=False``.

    Returns:
        ``{"env_steps", "episodes", "success_episodes"}``. An episode counts as
        a success when it ended by termination with a positive total reward.
    """
    steps_taken = 0
    episode_count = 0
    successes = 0
    model.eval()
    while steps_taken < num_steps:
        first_obs, _ = env.reset()
        obs_seq = [first_obs]
        act_seq: list[int] = []
        rew_seq: list[float] = []
        term_seq: list[float] = []
        trunc_seq: list[float] = []
        done_seq: list[float] = []

        with torch.no_grad():
            belief = model.posterior_step(model.initial_state(1), None, first_obs, sample=False)
            # Runs at least once; the only exit is the episode-end break below.
            while True:
                explore = rng.random() < random_action_fraction
                if explore:
                    action = int(env.action_space.sample())
                else:
                    action = planner.plan(belief).action
                next_obs, reward, terminated, truncated, _ = env.step(action)
                episode_over = terminated or truncated

                act_seq.append(action)
                rew_seq.append(float(reward))
                term_seq.append(float(terminated))
                trunc_seq.append(float(truncated))
                done_seq.append(float(episode_over))
                obs_seq.append(next_obs)
                steps_taken += 1

                if episode_over:
                    break
                belief = model.posterior_step(belief, action, next_obs, sample=False)

        replay.add_episode(
            obs=np.asarray(obs_seq, dtype=np.float32),
            actions=np.asarray(act_seq, dtype=np.int64),
            rewards=np.asarray(rew_seq, dtype=np.float32),
            terminated=np.asarray(term_seq, dtype=np.float32),
            truncated=np.asarray(trunc_seq, dtype=np.float32),
            done=np.asarray(done_seq, dtype=np.float32),
        )
        episode_count += 1
        successes += int(bool(terminated and np.sum(rew_seq) > 0.0))

    return {
        "env_steps": steps_taken,
        "episodes": episode_count,
        "success_episodes": successes,
    }
|
|
|
|
def run_training(
    config: dict[str, Any],
    output_dir: str | Path,
    replay_dir: str | Path | None = None,
    resume_checkpoint: str | Path | None = None,
    allow_overwrite_existing_output: bool = False,
) -> dict[str, Any]:
    """Train the world model, alternating planner-driven collection with updates.

    Flow:
    1. Refuse to clobber an existing run directory unless resuming or
       overwriting explicitly.
    2. Load a replay from ``replay_dir`` or collect a bootstrap dataset.
    3. Build (or resume) the model and optimizer and run an initial or
       catch-up round of updates.
    4. Loop: collect episodes with the CEM planner, then train on the growing
       replay. The loop stops once the largest ``comparison.env_steps`` budget
       is reached or ``training.train_steps`` updates have been done.
       Evaluation rounds also save a per-env-step checkpoint.

    Args:
        config: Parsed run configuration. On resume it is replaced by the
            checkpoint config merged with this one.
        output_dir: Run directory for checkpoints, metrics and replay data.
        replay_dir: Existing replay directory to load instead of collecting a
            bootstrap dataset. Ignored if the path does not exist.
        resume_checkpoint: Checkpoint to restore model/optimizer state from.
        allow_overwrite_existing_output: Permit reusing a non-empty
            ``output_dir`` without resuming.

    Returns:
        Summary dict (replay summary, total updates, device), also written to
        ``metrics/run_summary.json``.

    Raises:
        FileExistsError: ``output_dir`` already contains run artifacts and the
            call is neither a resume nor an explicit overwrite.
    """
    seed = config.get("project", {}).get("seed", 0)
    seed_everything(seed)
    existing_artifacts = find_existing_run_artifacts(output_dir)
    if existing_artifacts and resume_checkpoint is None and not allow_overwrite_existing_output:
        preview = ", ".join(str(path) for path in existing_artifacts[:3])
        raise FileExistsError(
            f"Refusing to overwrite existing run directory '{output_dir}'. "
            f"Found existing artifacts: {preview}. "
            "Choose a new --output-dir, resume with --resume-checkpoint, "
            "or pass --allow-overwrite-existing-output to overwrite intentionally."
        )
    run_dirs = ensure_run_dirs(output_dir)
    device = get_device(config.get("training", {}).get("device"))

    # Short-lived env used only to read the discrete action count.
    env = make_env_from_config(config, seed=seed)
    action_dim = env.action_space.n
    env.close()

    if replay_dir is not None and Path(replay_dir).exists():
        replay = ReplayBuffer.load(replay_dir)
        collection_summary = {"replay_loaded": replay.summary()}
    else:
        replay, collection_summary = collect_bootstrap_dataset(config, output_dir=run_dirs["replay"], seed=seed)

    resume_metadata: dict[str, Any] = {}
    if resume_checkpoint is not None:
        config, model, optimizer, resume_metadata = load_training_state(
            checkpoint_path=resume_checkpoint,
            config=config,
            action_dim=action_dim,
            device=device,
        )
    else:
        model = WorldModel.from_config(config, action_dim=action_dim).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=float(config["training"]["lr"]))

    save_config(config, run_dirs["base"] / "resolved_config.yaml")
    # Logs cover only this invocation. NOTE(review): confirm whether write_jsonl
    # overwrites metrics from an earlier session when resuming into the same dir.
    training_logs: list[dict[str, float]] = []
    evaluation_logs: list[dict[str, Any]] = []

    train_collect_ratio = float(config["collection"].get("train_collect_ratio", 1.0))
    total_updates_budget = int(config["training"]["train_steps"])
    # Clamp to >= 1: a zero value would collect nothing per iteration and the
    # collection loop below would never terminate.
    collect_steps_per_iteration = max(1, int(config["collection"].get("collect_steps_per_iteration", 1)))

    if resume_checkpoint is not None:
        updates_done = int(resume_metadata.get("updates_done", 0))
        checkpoint_env_steps = int(resume_metadata.get("env_steps", 0))
        if replay.env_steps > checkpoint_env_steps and updates_done < total_updates_budget:
            # The replay holds more env steps than the checkpoint saw. Approximate
            # the updates those extra collection iterations would have received.
            per_iteration_updates = int(
                config["collection"].get(
                    "gradient_updates_per_iteration",
                    round(collect_steps_per_iteration * train_collect_ratio),
                )
            )
            missed_iterations = max(0, round((replay.env_steps - checkpoint_env_steps) / collect_steps_per_iteration))
            catch_up_updates = min(total_updates_budget - updates_done, per_iteration_updates * missed_iterations)
            catch_up_logs = train_world_model_updates(model, replay, optimizer, config, catch_up_updates, device)
            training_logs.extend(catch_up_logs)
            updates_done += len(catch_up_logs)
    else:
        initial_updates = min(total_updates_budget, max(1, int(round(replay.env_steps * train_collect_ratio))))
        training_logs.extend(train_world_model_updates(model, replay, optimizer, config, initial_updates, device))
        updates_done = len(training_logs)

    comparison_budgets = config.get("comparison", {}).get("env_steps", [replay.env_steps])
    target_env_steps = int(max(comparison_budgets))
    random_action_fraction = float(config["collection"].get("random_action_fraction_after_planner", 0.0))
    rng = np.random.default_rng(seed)
    env = make_env_from_config(config, seed=seed)
    try:
        planner = DiscreteCEMPlanner.from_config(model, env.action_space.n, config)
        eval_every_steps = int(config["evaluation"].get("eval_every_env_steps", target_env_steps))
        next_eval_step = replay.env_steps

        while replay.env_steps < target_env_steps and updates_done < total_updates_budget:
            collect_steps = min(collect_steps_per_iteration, target_env_steps - replay.env_steps)
            collection_row = collect_planner_steps(
                env,
                replay,
                model,
                planner,
                num_steps=collect_steps,
                random_action_fraction=random_action_fraction,
                rng=rng,
            )
            updates = int(
                config["collection"].get(
                    "gradient_updates_per_iteration",
                    round(collection_row["env_steps"] * train_collect_ratio),
                )
            )
            updates = min(updates, total_updates_budget - updates_done)
            iteration_logs = train_world_model_updates(model, replay, optimizer, config, updates, device)
            training_logs.extend(iteration_logs)
            # Accumulate rather than recount: on resume, training_logs holds only
            # this session's updates, while updates_done starts from the checkpoint.
            updates_done += len(iteration_logs)
            replay.save(run_dirs["replay"])

            if replay.env_steps >= next_eval_step:
                world_model_metrics = evaluate_world_model(config, model, replay, split="val", max_episodes=10)
                planner_metrics = evaluate_planner(config, model, episodes=min(10, config["evaluation"]["episodes"]), seed=seed)
                random_metrics = evaluate_random_policy(config, episodes=min(10, config["evaluation"]["episodes"]), seed=seed)
                eval_row = {
                    "env_steps": replay.env_steps,
                    "updates_done": updates_done,
                    **{f"world_model/{key}": value for key, value in world_model_metrics.items()},
                    **{f"planner/{key}": value for key, value in planner_metrics.items()},
                    **{f"random/{key}": value for key, value in random_metrics.items()},
                }
                evaluation_logs.append(eval_row)
                next_eval_step += eval_every_steps
                save_world_model_checkpoint(
                    run_dirs["checkpoints"] / f"world_model_env_steps_{replay.env_steps}.pt",
                    model,
                    config,
                    optimizer=optimizer,
                    metadata={"env_steps": replay.env_steps, "updates_done": updates_done},
                )
    finally:
        # Release the collection env even if collection, training or evaluation raises.
        env.close()

    save_world_model_checkpoint(
        run_dirs["checkpoints"] / "world_model_latest.pt",
        model,
        config,
        optimizer=optimizer,
        metadata={"env_steps": replay.env_steps, "updates_done": updates_done},
    )
    write_json(run_dirs["metrics"] / "collection_summary.json", collection_summary)
    write_jsonl(run_dirs["metrics"] / "train_metrics.jsonl", training_logs)
    write_jsonl(run_dirs["metrics"] / "eval_metrics.jsonl", evaluation_logs)
    summary = {
        "replay": replay.summary(),
        "updates_done": updates_done,
        "device": str(device),
    }
    write_json(run_dirs["metrics"] / "run_summary.json", summary)
    return summary
|
|
|
|
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the world-model training script.

    ``--config`` and ``--output-dir`` are required paths. ``--replay-dir`` and
    ``--resume-checkpoint`` are optional paths defaulting to ``None``.
    ``--allow-overwrite-existing-output`` is a boolean flag.
    """
    parser = argparse.ArgumentParser(description="Train the MiniDreamer world model.")
    for required_flag in ("--config", "--output-dir"):
        parser.add_argument(required_flag, type=Path, required=True)

    optional_path_flags = (
        ("--replay-dir", "Optional existing replay directory."),
        ("--resume-checkpoint", "Optional checkpoint to resume from."),
    )
    for flag, help_text in optional_path_flags:
        parser.add_argument(flag, type=Path, default=None, help=help_text)

    parser.add_argument(
        "--allow-overwrite-existing-output",
        action="store_true",
        help="Allow overwriting an existing run directory when not resuming.",
    )
    return parser
|
|
|
|
def main() -> None:
    """CLI entry point: parse arguments, load the config, train, and print the summary."""
    cli_args = build_arg_parser().parse_args()
    print(
        run_training(
            load_config(cli_args.config),
            cli_args.output_dir,
            replay_dir=cli_args.replay_dir,
            resume_checkpoint=cli_args.resume_checkpoint,
            allow_overwrite_existing_output=cli_args.allow_overwrite_existing_output,
        )
    )
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()
|
|