# remdm-minihack / main.py
# MathisW78's picture
# Demo notebook payload (source + checkpoint + assets)
# f748552 verified
from __future__ import annotations
import argparse
import logging
import random
from pathlib import Path
from typing import Any
import numpy as np
import torch
from src.config import load_config
from src.planners.baselines import ALL_BASELINE_ALGOS, run_baselines
from src.planners.logging import Logger
from src.planners.offline import run_offline
from src.planners.online import run_dagger
from src.planners.inference import run_inference
from src.planners.collect_oracle import run_collect
from src.planners.smoke import run_smoke
# =============================================================================
# Logging
# =============================================================================
# Configure the root logger once, at import time: INFO level, with the
# timestamp, level and logger name prefixed to every message.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-level logger for messages emitted from this file.
logger = logging.getLogger(__name__)
# =============================================================================
# Utils
# =============================================================================
def _parse_overrides(extras: list[str]) -> dict[str, Any]:
    """Convert leftover ``key=value`` CLI tokens into an override mapping.

    Each token is split on its first ``=``; any leading dashes are removed
    from the key, and the value is kept as the raw string (it may itself
    contain ``=``). When a key appears more than once the last value wins.

    Tokens without ``=`` cannot be interpreted as overrides. They are left
    out of the result, and a single warning listing them is logged so that
    a mistyped flag does not vanish silently.

    Args:
        extras: Tokens the argument parser did not recognise.

    Returns:
        Mapping from override key to its string value.
    """
    overrides: dict[str, Any] = {}
    ignored: list[str] = []
    for token in extras:
        key, sep, value = token.partition("=")
        if not sep:
            ignored.append(token)
            continue
        overrides[key.lstrip("-")] = value
    if ignored:
        # Same logger object as the module-level one (looked up by name).
        logging.getLogger(__name__).warning(
            "Ignoring unrecognised CLI arguments (expected key=value): %s",
            ignored,
        )
    return overrides
def _set_seed(seed: int | None) -> int:
    """Seed the Python, NumPy and PyTorch random generators.

    Args:
        seed: Seed to apply. When ``None``, one is drawn with
            ``random.randint`` from ``[0, 2**31 - 1]``.

    Returns:
        The seed that was actually applied.
    """
    chosen = random.randint(0, 2**31 - 1) if seed is None else seed
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(chosen)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(chosen)
    return chosen
# =============================================================================
# CLI
# =============================================================================
def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Define the command-line interface and parse the process arguments.

    Returns:
        A ``(namespace, extras)`` pair as produced by
        ``ArgumentParser.parse_known_args``: the recognised options, and the
        tokens the parser did not recognise.
    """
    cli = argparse.ArgumentParser(
        description="ReMDM-MiniHack: Masked Diffusion Planner"
    )
    add = cli.add_argument

    # Run selection.
    modes = ["smoke", "offline", "dagger", "inference", "collect", "baselines"]
    add("--mode", required=True, choices=modes)
    add("--config", default="configs/defaults.yaml")

    # Baseline-specific options.
    add(
        "--algo",
        default=None,
        choices=list(ALL_BASELINE_ALGOS),
        help="Baseline algorithm (required for --mode baselines)",
    )
    add(
        "--seeds",
        type=int,
        nargs="+",
        default=None,
        help="Explicit list of seeds for --mode baselines (e.g. --seeds 0 1 2).",
    )
    add(
        "--n-seeds",
        type=int,
        default=None,
        help=(
            "Number of seeds starting from 0 (alternative to --seeds; only "
            "used by --mode baselines)."
        ),
    )

    # Data and checkpoint sources.
    add("--data", default=None)
    add("--checkpoint", default=None)
    add(
        "--wandb-artifact",
        default=None,
        help=(
            "W&B artifact reference to download as checkpoint, "
            "e.g. 'entity/project/checkpoint-iter1000:latest'"
        ),
    )
    add("--no-warm-start", action="store_true")
    add("--no-ema", action="store_true")

    # Evaluation options.
    add("--envs", nargs="+", default=None)
    add(
        "--des",
        nargs="+",
        default=None,
        help="Paths to .des scenario files for custom environment evaluation",
    )
    add("--episodes", type=int, default=50)
    add("--output", default=None)
    add(
        "--blind-global",
        action="store_true",
        help="Zero out global map observations (local-only ablation)",
    )

    return cli.parse_known_args()
# =============================================================================
# Config
# =============================================================================
def build_config(args, extras):
    """Load the run configuration, apply CLI overrides and seed the RNGs.

    In smoke mode, ``configs/smoke.yaml`` replaces the config path when
    ``--config`` was left at its default value.

    Args:
        args: Parsed CLI namespace (reads ``mode`` and ``config``).
        extras: Unrecognised CLI tokens, parsed into config overrides.

    Returns:
        The object returned by ``load_config``.
    """
    left_at_default = args.config == "configs/defaults.yaml"
    if args.mode == "smoke" and left_at_default:
        chosen_path = "configs/smoke.yaml"
    else:
        chosen_path = args.config

    overrides = _parse_overrides(extras)
    cfg = load_config(chosen_path, overrides)

    # NOTE: the effective seed is only logged here; it is not written back
    # onto cfg.
    effective_seed = _set_seed(cfg.seed)
    logger.info(f"Seed: {effective_seed}")
    return cfg
# =============================================================================
# Validation
# =============================================================================
def validate(args) -> None:
    """Reject argument combinations that cannot run.

    Args:
        args: Parsed CLI namespace.

    Raises:
        ValueError: If inference mode has neither ``--checkpoint`` nor
            ``--wandb-artifact``, or baselines mode has no ``--algo``.
    """
    mode = args.mode

    if mode == "inference":
        checkpoint_source = args.checkpoint or args.wandb_artifact
        if not checkpoint_source:
            raise ValueError(
                "--checkpoint or --wandb-artifact required for inference mode"
            )

    if mode == "baselines":
        if args.algo is None:
            raise ValueError(
                "--algo is required for --mode baselines "
                f"(choose one of {list(ALL_BASELINE_ALGOS)})"
            )
def _resolve_seeds(args, cfg) -> list[int]:
    """Build the seed list for --mode baselines.

    Precedence: an explicit ``--seeds`` list, then ``--n-seeds`` (expanded
    to ``0 .. n-1``), then the single config seed (``0`` when the config
    seed is ``None``).

    Args:
        args: Parsed CLI namespace (reads ``seeds`` and ``n_seeds``).
        cfg: Loaded configuration (reads ``seed``).

    Returns:
        A non-empty list of integer seeds.

    Raises:
        ValueError: If ``--n-seeds`` is smaller than 1, which would
            otherwise produce an empty seed list.
    """
    if args.seeds is not None:
        return list(args.seeds)
    if args.n_seeds is not None:
        count = int(args.n_seeds)
        if count < 1:
            raise ValueError(
                f"--n-seeds must be a positive integer, got {count}"
            )
        return list(range(count))
    return [cfg.seed if cfg.seed is not None else 0]
# =============================================================================
# Path / checkpoint resolution and mode dispatch
# =============================================================================
def _resolve_path(p: str | None) -> str | None:
    """Return *p* as an absolute path string, passing ``None`` through."""
    return None if p is None else str(Path(p).resolve())
def _resolve_checkpoint(args, cfg) -> str | None:
    """Pick a local checkpoint path from --checkpoint or --wandb-artifact.

    ``--checkpoint`` takes priority and is made absolute. Otherwise, when
    ``--wandb-artifact`` is set, the artifact is downloaded and the path
    reported by ``download_artifact`` is returned.

    Args:
        args: Parsed CLI namespace (reads ``checkpoint``, ``wandb_artifact``).
        cfg: Loaded configuration; not used by this function.

    Returns:
        A checkpoint path, or ``None`` when neither option was supplied.

    Raises:
        RuntimeError: If ``download_artifact`` returns ``None``.
    """
    local_ckpt = args.checkpoint
    if local_ckpt:
        return _resolve_path(local_ckpt)

    ref = args.wandb_artifact
    if not ref:
        return None

    # Imported only when an artifact is actually requested.
    from src.planners.logging import download_artifact

    downloaded = download_artifact(ref)
    if downloaded is None:
        raise RuntimeError(f"Failed to download W&B artifact: {ref}")
    return downloaded
def run_mode(mode: str, cfg, args) -> None:
    """Dispatch to the runner that implements *mode*.

    User-supplied paths are made absolute only in the branch that uses
    them, so a mode never touches options it does not need.

    Args:
        mode: One of ``smoke``, ``offline``, ``dagger``, ``collect``,
            ``baselines`` or ``inference``.
        cfg: Loaded configuration, passed through to the runner.
        args: Parsed CLI namespace holding the mode-specific options.

    Raises:
        ValueError: If *mode* is not a known mode, or inference mode has no
            checkpoint source.
    """
    if mode == "smoke":
        run_smoke(cfg)
    elif mode == "offline":
        data_path = _resolve_path(args.data)
        ckpt = _resolve_checkpoint(args, cfg)
        run_offline(cfg, data_path, checkpoint_path=ckpt)
    elif mode == "dagger":
        ckpt = _resolve_checkpoint(args, cfg)
        run_dagger(cfg, ckpt, args.no_warm_start)
    elif mode == "collect":
        run_collect(cfg)
    elif mode == "baselines":
        run_baselines(
            cfg,
            algo=args.algo,
            seeds=_resolve_seeds(args, cfg),
            output_path=_resolve_path(args.output),
        )
    elif mode == "inference":
        ckpt = _resolve_checkpoint(args, cfg)
        if ckpt is None:
            raise ValueError(
                "--checkpoint or --wandb-artifact required for inference"
            )
        output_path = _resolve_path(args.output)
        des_files = (
            [_resolve_path(d) for d in args.des] if args.des else None
        )
        log = Logger(cfg)
        run_inference(
            cfg,
            ckpt,
            args.envs,
            args.episodes,
            output_path,
            not args.no_ema,  # EMA is used unless --no-ema was given
            log=log,
            des_files=des_files,
            blind_global=args.blind_global,
        )
        log.finish()
    else:
        # Previously an unknown mode fell through and did nothing at all.
        raise ValueError(f"Unknown mode: {mode!r}")
# =============================================================================
# Entry point
# =============================================================================
def main() -> None:
    """Run the CLI: parse arguments, validate, build the config, dispatch."""
    cli_args, override_tokens = parse_args()
    validate(cli_args)
    config = build_config(cli_args, override_tokens)

    # The float32 matmul precision is only changed when CUDA is available.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.set_float32_matmul_precision("high")

    run_mode(cli_args.mode, config, cli_args)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()