from functools import partial
from pathlib import Path
import shutil
import time
from typing import List, Optional, Tuple

from hydra.utils import instantiate
import numpy as np
from omegaconf import DictConfig, OmegaConf
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
import wandb

from agent import Agent
from coroutines.collector import make_collector, NumToCollect
from data import BatchSampler, collate_segments_to_batch, Dataset, DatasetTraverser, CSGOHdf5Dataset
from envs import make_atari_env, WorldModelEnv
from utils import (
    broadcast_if_needed,
    build_ddp_wrapper,
    CommonTools,
    configure_opt,
    count_parameters,
    get_lr_sched,
    keep_agent_copies_every,
    Logs,
    move_opt_to,
    process_confusion_matrices_if_any_and_compute_classification_metrics,
    save_info_for_import_script,
    save_with_backup,
    set_seed,
    StateDictMixin,
    try_until_no_except,
    wandb_log,
)
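
# Glue code for a full training run: data collection, world-model and
# actor-critic training, evaluation, wandb logging, and checkpointing.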
class Trainer(StateDictMixin):
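    """Orchestrates a complete run: (optional) environment collection,
    per-component training and evaluation, logging, and checkpointing."""
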
    def __init__(self, cfg: DictConfig, root_dir: Path) -> None:
        torch.backends.cuda.matmul.allow_tf32 = True
        OmegaConf.resolve(cfg)
        self._cfg = cfg
        self._rank = dist.get_rank() if dist.is_initialized() else 0
        self._world_size = dist.get_world_size() if dist.is_initialized() else 1

        # Pick a random seed
        set_seed(torch.seed() % 10 ** 9)

        # Device
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu", self._rank)
        print(f"Starting on {self._device}")
        self._use_cuda = self._device.type == "cuda"
        if self._use_cuda:
            torch.cuda.set_device(self._rank)  # fix compilation error on multi-gpu nodes

        # Init wandb
        if self._rank == 0:
            try_until_no_except(
                partial(wandb.init, config=OmegaConf.to_container(cfg, resolve=True), reinit=True, resume=True, **cfg.wandb)
            )
        # Flags
        self._is_static_dataset = cfg.static_dataset.path is not None
        self._is_model_free = cfg.training.model_free

        # Checkpointing
        self._path_ckpt_dir = Path("checkpoints")
        self._path_state_ckpt = self._path_ckpt_dir / "state.pt"
        self._keep_agent_copies = partial(
            keep_agent_copies_every,
            every=cfg.checkpointing.save_agent_every,
            path_ckpt_dir=self._path_ckpt_dir,
            num_to_keep=cfg.checkpointing.num_to_keep,
        )
        self._save_info_for_import_script = partial(
            save_info_for_import_script, run_name=cfg.wandb.name, path_ckpt_dir=self._path_ckpt_dir
        )
        # First time, init files hierarchy
        if not cfg.common.resume and self._rank == 0:
            self._path_ckpt_dir.mkdir(exist_ok=False, parents=False)
            path_config = Path("config") / "trainer.yaml"
            path_config.parent.mkdir(exist_ok=False, parents=False)
            shutil.move(".hydra/config.yaml", path_config)
            wandb.save(str(path_config))
            shutil.copytree(src=root_dir / "src", dst="./src")
            shutil.copytree(src=root_dir / "scripts", dst="./scripts")
        if cfg.env.train.id == "csgo":
            assert cfg.env.path_data_low_res is not None and cfg.env.path_data_full_res is not None, (
                "Make sure to download CSGO data and set the relevant paths in cfg.env"
            )
            assert self._is_static_dataset
            num_actions = cfg.env.num_actions
            dataset_full_res = CSGOHdf5Dataset(Path(cfg.env.path_data_full_res))

        # Envs (atari only)
        else:
            if self._rank == 0:
                train_env = make_atari_env(num_envs=cfg.collection.train.num_envs, device=self._device, **cfg.env.train)
                test_env = make_atari_env(num_envs=cfg.collection.test.num_envs, device=self._device, **cfg.env.test)
                num_actions = int(test_env.num_actions)
            else:
                num_actions = None
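            # num_actions is decided on rank 0 (which owns the envs), then shared
            # with the other ranks.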
            num_actions, = broadcast_if_needed(num_actions)
            dataset_full_res = None
        num_workers = cfg.training.num_workers_data_loaders
        use_manager = cfg.training.cache_in_ram and (num_workers > 0)
        p = Path(cfg.static_dataset.path) if self._is_static_dataset else Path("dataset")
        self.train_dataset = Dataset(p / "train", dataset_full_res, "train_dataset", cfg.training.cache_in_ram, use_manager)
        self.test_dataset = Dataset(p / "test", dataset_full_res, "test_dataset", cache_in_ram=True)
        self.train_dataset.load_from_default_path()
        self.test_dataset.load_from_default_path()
        # Create models
        self.agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(self._device)
        self._agent = build_ddp_wrapper(**self.agent._modules) if dist.is_initialized() else self.agent

        if cfg.initialization.path_to_ckpt is not None:
            self.agent.load(**cfg.initialization)

        # Collectors
        if not self._is_static_dataset and self._rank == 0:
            self._train_collector = make_collector(
                train_env, self.agent.actor_critic, self.train_dataset, cfg.collection.train.epsilon
            )
            self._test_collector = make_collector(
                test_env, self.agent.actor_critic, self.test_dataset, cfg.collection.test.epsilon, reset_every_collect=True
            )
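            # The collectors are coroutines: run() drives them with
            # .send(NumToCollect(...)), and each send returns logs to report.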
        ######################################################

        # Optimizers and LR schedulers

        def build_opt(name: str) -> torch.optim.AdamW:
            return configure_opt(getattr(self.agent, name), **getattr(cfg, name).optimizer)

        def build_lr_sched(name: str) -> torch.optim.lr_scheduler.LambdaLR:
            return get_lr_sched(self.opt.get(name), getattr(cfg, name).training.lr_warmup_steps)
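        # build_lr_sched reads self.opt, so the optimizers (just below) must be
        # created before the schedulers.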
        model_names = ["denoiser", "upsampler", "rew_end_model", "actor_critic"]
        self._model_names = ["actor_critic"] if self._is_model_free else [name for name in model_names if getattr(self.agent, name) is not None]

        self.opt = CommonTools(**{name: build_opt(name) for name in self._model_names})
        self.lr_sched = CommonTools(**{name: build_lr_sched(name) for name in self._model_names})
        # Data loaders
        make_data_loader = partial(
            DataLoader,
            dataset=self.train_dataset,
            collate_fn=collate_segments_to_batch,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0),
            pin_memory=self._use_cuda,
            pin_memory_device=str(self._device) if self._use_cuda else "",
        )

        make_batch_sampler = partial(BatchSampler, self.train_dataset, self._rank, self._world_size)

        def get_sample_weights(sample_weights: List[float]) -> Optional[List[float]]:
            return None if (self._is_static_dataset and cfg.static_dataset.ignore_sample_weights) else sample_weights
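        # A training segment must cover the conditioning frames, the frame being
        # predicted, and any extra autoregressive steps:
        # seq_length = num_steps_conditioning + 1 + num_autoregressive_steps.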
        c = cfg.denoiser.training
        seq_length = cfg.agent.denoiser.inner_model.num_steps_conditioning + 1 + c.num_autoregressive_steps
        bs = make_batch_sampler(c.batch_size, seq_length, get_sample_weights(c.sample_weights))
        dl_denoiser_train = make_data_loader(batch_sampler=bs)
        dl_denoiser_test = DatasetTraverser(self.test_dataset, c.batch_size, seq_length)

        if self.agent.upsampler is not None:
            c = cfg.upsampler.training
            seq_length = cfg.agent.upsampler.inner_model.num_steps_conditioning + 1 + c.num_autoregressive_steps
            bs = make_batch_sampler(c.batch_size, seq_length, get_sample_weights(c.sample_weights))
            dl_upsampler_train = make_data_loader(batch_sampler=bs)
            dl_upsampler_test = DatasetTraverser(self.test_dataset, c.batch_size, seq_length)
        else:
            dl_upsampler_train = dl_upsampler_test = None

        if self.agent.rew_end_model is not None:
            c = cfg.rew_end_model.training
            bs = make_batch_sampler(c.batch_size, c.seq_length, get_sample_weights(c.sample_weights), can_sample_beyond_end=True)
            dl_rew_end_model_train = make_data_loader(batch_sampler=bs)
            dl_rew_end_model_test = DatasetTraverser(self.test_dataset, c.batch_size, c.seq_length)
        else:
            dl_rew_end_model_train = dl_rew_end_model_test = None

        self._data_loader_train = CommonTools(dl_denoiser_train, dl_upsampler_train, dl_rew_end_model_train, None)
        self._data_loader_test = CommonTools(dl_denoiser_test, dl_upsampler_test, dl_rew_end_model_test, None)
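        # The trailing None slots belong to the actor_critic, which gets its
        # data from an environment (below) rather than from these loaders.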
        # RL env
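        # Model-free: the actor-critic trains in the real environment.
        # Model-based: it trains inside a WorldModelEnv, seeded with real
        # segments drawn from dl_actor_critic.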
        if self.agent.actor_critic is not None:
            actor_critic_loss_cfg = instantiate(cfg.actor_critic.actor_critic_loss)
            if self._is_model_free:
                assert self.agent.actor_critic is not None
                rl_env = make_atari_env(num_envs=cfg.actor_critic.training.batch_size, device=self._device, **cfg.env.train)
            else:
                c = cfg.actor_critic.training
                sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
                if self.agent.upsampler is not None:
                    sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
                bs = make_batch_sampler(c.batch_size, sl, get_sample_weights(c.sample_weights))
                dl_actor_critic = make_data_loader(batch_sampler=bs)
                wm_env_cfg = instantiate(cfg.world_model_env)
                rl_env = WorldModelEnv(self.agent.denoiser, self.agent.upsampler, self.agent.rew_end_model, dl_actor_critic, wm_env_cfg)
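                # "reduce-overhead" mode compiles with CUDA graphs, which suits
                # the many small forward passes of imagined rollouts.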
                if cfg.training.compile_wm:
                    rl_env.predict_next_obs = torch.compile(rl_env.predict_next_obs, mode="reduce-overhead")
                    rl_env.predict_rew_end = torch.compile(rl_env.predict_rew_end, mode="reduce-overhead")
        else:
            actor_critic_loss_cfg = None
            rl_env = None
        # Setup training
        sigma_distribution_cfg = instantiate(cfg.denoiser.sigma_distribution)
        sigma_distribution_cfg_upsampler = instantiate(cfg.upsampler.sigma_distribution) if self.agent.upsampler is not None else None
        self.agent.setup_training(sigma_distribution_cfg, sigma_distribution_cfg_upsampler, actor_critic_loss_cfg, rl_env)

        # Training state (things to be saved/restored)
        self.epoch = 0
        self.num_epochs_collect = None
        self.num_episodes_test = 0
        self.num_batch_train = CommonTools(0, 0, 0, 0)
        self.num_batch_test = CommonTools(0, 0, 0, 0)
        if cfg.common.resume:
            self.load_state_checkpoint()
        else:
            self.save_checkpoint()

        if self._rank == 0:
            for name in self._model_names:
                print(f"{count_parameters(getattr(self.agent, name))} parameters in {name}")
            print(self.train_dataset)
            print(self.test_dataset)
    def run(self) -> None:
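        """Main loop: optional initial collect, then one epoch at a time:
        collect, train, evaluate, log, checkpoint."""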
        to_log = []

        if self.epoch == 0:
            if self._is_model_free or self._is_static_dataset:
                self.num_epochs_collect = 0
            else:
                if self._rank == 0:
                    self.num_epochs_collect, to_log_ = self.collect_initial_dataset()
                    to_log += to_log_
                self.num_epochs_collect, sd_train_dataset = broadcast_if_needed(self.num_epochs_collect, self.train_dataset.state_dict())
                self.train_dataset.load_state_dict(sd_train_dataset)

        num_epochs = self.num_epochs_collect + self._cfg.training.num_final_epochs

        while self.epoch < num_epochs:
            self.epoch += 1
            start_time = time.time()

            if self._rank == 0:
                print(f"\nEpoch {self.epoch} / {num_epochs}\n")

            # Training
            should_collect_train = (
                self._rank == 0 and not self._is_model_free and not self._is_static_dataset and self.epoch <= self.num_epochs_collect
            )

            if should_collect_train:
                c = self._cfg.collection.train
                to_log += self._train_collector.send(NumToCollect(steps=c.steps_per_epoch))

            sd_train_dataset, = broadcast_if_needed(self.train_dataset.state_dict())  # update dataset for ranks > 0
            self.train_dataset.load_state_dict(sd_train_dataset)

            if self._cfg.training.should:
                to_log += self.train_agent()

            # Evaluation
            should_test = self._rank == 0 and self._cfg.evaluation.should and (self.epoch % self._cfg.evaluation.every == 0)
            should_collect_test = should_test and not self._is_static_dataset

            if should_collect_test:
                to_log += self.collect_test()

            if should_test and not self._is_model_free:
                to_log += self.test_agent()

            # Logging
            to_log.append({"duration": (time.time() - start_time) / 3600})
            if self._rank == 0:
                wandb_log(to_log, self.epoch)
            to_log = []

            # Checkpointing
            self.save_checkpoint()

            if dist.is_initialized():
                dist.barrier()

        # Last collect
        if self._rank == 0 and not self._is_static_dataset:
            wandb_log(self.collect_test(final=True), self.epoch)
    def collect_initial_dataset(self) -> Tuple[int, Logs]:
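        """Collect until the combined count of minority (non-majority) rewards
        reaches `threshold_rew`, or `max_steps` is hit; return the number of
        collection epochs left in the total step budget, plus the logs."""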
| print("\nInitial collect\n") | |
| to_log = [] | |
| c = self._cfg.collection.train | |
| min_steps = c.first_epoch.min | |
| steps_per_epoch = c.steps_per_epoch | |
| max_steps = c.first_epoch.max | |
| threshold_rew = c.first_epoch.threshold_rew | |
| assert min_steps % steps_per_epoch == 0 | |
| steps = min_steps | |
| while True: | |
| to_log += self._train_collector.send(NumToCollect(steps=steps)) | |
| num_steps = self.train_dataset.num_steps | |
| total_minority_rew = sum(sorted(self.train_dataset.counts_rew)[:-1]) | |
| if total_minority_rew >= threshold_rew: | |
| break | |
| if (max_steps is not None) and num_steps >= max_steps: | |
| print("Reached the specified maximum for initial collect") | |
| break | |
| print(f"Minority reward: {total_minority_rew}/{threshold_rew} -> Keep collecting\n") | |
| steps = steps_per_epoch | |
| print("\nSummary of initial collect:") | |
| print(f"Num steps: {num_steps} / {c.num_steps_total}") | |
| print(f"Reward counts: {dict(self.train_dataset.counter_rew)}") | |
| remaining_steps = c.num_steps_total - num_steps | |
| assert remaining_steps % c.steps_per_epoch == 0 | |
| num_epochs_collect = remaining_steps // c.steps_per_epoch | |
| return num_epochs_collect, to_log | |
    def collect_test(self, final: bool = False) -> Logs:
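        """Collect a fixed number of test episodes and return the logs, with
        episode ids offset so they stay unique across successive calls."""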
        c = self._cfg.collection.test
        episodes = c.num_final_episodes if final else c.num_episodes
        td = self.test_dataset
        td.clear()
        to_log = self._test_collector.send(NumToCollect(episodes=episodes))

        key_ep_id = f"{td.name}/episode_id"
        to_log = [{k: v + self.num_episodes_test if k == key_ep_id else v for k, v in x.items()} for x in to_log]

        print(f"\nSummary of {'final' if final else 'test'} collect: {td.num_episodes} episodes ({td.num_steps} steps)")
        keys = [key_ep_id, "return", "length"]
        to_log_episodes = [x for x in to_log if set(x.keys()) == set(keys)]
        episode_ids, returns, lengths = [[d[k] for d in to_log_episodes] for k in keys]
        for i, (ep_id, ret, length) in enumerate(zip(episode_ids, returns, lengths)):
            print(f"  Episode {ep_id}: return = {ret} length = {length}\n", end="\n" if i == episodes - 1 else "")

        self.num_episodes_test += episodes

        if final:
            to_log.append({"final_return_mean": np.mean(returns), "final_return_std": np.std(returns)})
            print(to_log[-1])

        return to_log
    def train_agent(self) -> Logs:
        self.agent.train()
        self.agent.zero_grad()
        to_log = []
        for name in self._model_names:
            cfg = getattr(self._cfg, name).training
            if self.epoch > cfg.start_after_epochs:
                steps = cfg.steps_first_epoch if self.epoch == 1 else cfg.steps_per_epoch
                to_log += self.train_component(name, steps)
        return to_log

    def test_agent(self) -> Logs:
        self.agent.eval()
        to_log = []
        for name in self._model_names:
            if name == "actor_critic":
                continue
            cfg = getattr(self._cfg, name).training
            if self.epoch > cfg.start_after_epochs:
                to_log += self.test_component(name)
        return to_log
    def train_component(self, name: str, steps: int) -> Logs:
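        """Run `steps` optimizer updates for one component, each accumulated
        over `grad_acc_steps` micro-batches; the model and optimizer live on
        the device only for the duration of this phase."""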
        cfg = getattr(self._cfg, name).training
        model = getattr(self._agent, name)
        opt = self.opt.get(name)
        lr_sched = self.lr_sched.get(name)
        data_loader = self._data_loader_train.get(name)

        torch.cuda.empty_cache()
        model.to(self._device)
        move_opt_to(opt, self._device)

        model.train()
        opt.zero_grad()
        data_iterator = iter(data_loader) if data_loader is not None else None
        to_log = []

        num_steps = cfg.grad_acc_steps * steps

        for i in trange(num_steps, desc=f"Training {name}", disable=self._rank > 0):
            batch = next(data_iterator).to(self._device) if data_iterator is not None else None
            loss, metrics = model(batch) if batch is not None else model()
            loss.backward()

            num_batch = self.num_batch_train.get(name)
            metrics[f"num_batch_train_{name}"] = num_batch
            self.num_batch_train.set(name, num_batch + 1)

            if (i + 1) % cfg.grad_acc_steps == 0:
                if cfg.max_grad_norm is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm).item()
                    metrics["grad_norm_before_clip"] = grad_norm
                opt.step()
                opt.zero_grad()
                if lr_sched is not None:
                    metrics["lr"] = lr_sched.get_last_lr()[0]
                    lr_sched.step()

            to_log.append(metrics)

        process_confusion_matrices_if_any_and_compute_classification_metrics(to_log)
        to_log = [{f"{name}/train/{k}": v for k, v in d.items()} for d in to_log]

        model.to("cpu")
        move_opt_to(opt, "cpu")

        return to_log
    @torch.no_grad()
    def test_component(self, name: str) -> Logs:
        model = getattr(self.agent, name)
        data_loader = self._data_loader_test.get(name)
        model.eval()
        model.to(self._device)
        to_log = []

        for batch in tqdm(data_loader, desc=f"Evaluating {name}"):
            batch = batch.to(self._device)
            _, metrics = model(batch)
            num_batch = self.num_batch_test.get(name)
            metrics[f"num_batch_test_{name}"] = num_batch
            self.num_batch_test.set(name, num_batch + 1)
            to_log.append(metrics)

        process_confusion_matrices_if_any_and_compute_classification_metrics(to_log)
        to_log = [{f"{name}/test/{k}": v for k, v in d.items()} for d in to_log]

        model.to("cpu")

        return to_log
    def load_state_checkpoint(self) -> None:
        self.load_state_dict(torch.load(self._path_state_ckpt, map_location=self._device))

    def save_checkpoint(self) -> None:
        if self._rank == 0:
            save_with_backup(self.state_dict(), self._path_state_ckpt)
            self.train_dataset.save_to_default_path()
            self.test_dataset.save_to_default_path()
            self._keep_agent_copies(self.agent.state_dict(), self.epoch)
            self._save_info_for_import_script(self.epoch)