| if __name__ == "__main__": |
| import sys |
| import os |
| import pathlib |
|
|
| ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) |
| sys.path.append(ROOT_DIR) |
| os.chdir(ROOT_DIR) |
|
|
| import os |
| import hydra |
| import torch |
| from omegaconf import OmegaConf |
| import pathlib |
| from torch.utils.data import DataLoader |
| import copy |
| import random |
| import wandb |
| import tqdm |
| import numpy as np |
|
|
| from diffusion_policy.workspace.base_workspace import BaseWorkspace |
| from diffusion_policy.policy.diffusion_unet_video_policy import DiffusionUnetVideoPolicy |
| from diffusion_policy.dataset.base_dataset import BaseImageDataset |
| from diffusion_policy.env_runner.base_image_runner import BaseImageRunner |
| from diffusion_policy.common.checkpoint_util import TopKCheckpointManager |
| from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to |
| from diffusion_policy.model.diffusion.ema_model import EMAModel |
| from diffusion_policy.model.common.lr_scheduler import get_scheduler |
|
|
# Allow `${eval:...}` expressions inside config files (e.g. computed
# dimensions); `replace=True` makes re-registration on re-import safe.
OmegaConf.register_new_resolver("eval", eval, replace=True)
|
|
class TrainDiffusionUnetVideoWorkspace(BaseWorkspace):
    """Training workspace for ``DiffusionUnetVideoPolicy``.

    Handles RNG seeding, model/EMA construction, optimization,
    periodic environment-rollout evaluation, top-k checkpointing and
    wandb logging. ``global_step`` and ``epoch`` are persisted across
    checkpoint resume via ``include_keys``.
    """

    # extra attributes saved/restored alongside the model state on checkpoint
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf, output_dir=None):
        """Build model, optional EMA copy and optimizer from ``cfg``.

        Args:
            cfg: resolved hydra/omegaconf training configuration.
            output_dir: optional override for the workspace output directory.
        """
        super().__init__(cfg, output_dir=output_dir)

        # seed all RNG sources for reproducibility
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # configure model
        self.model: DiffusionUnetVideoPolicy = hydra.utils.instantiate(cfg.policy)

        # optional exponential-moving-average copy used for eval/rollouts
        self.ema_model: DiffusionUnetVideoPolicy = None
        if cfg.training.use_ema:
            self.ema_model = copy.deepcopy(self.model)

        # optimizer always tracks the raw (non-EMA) model parameters
        self.optimizer = hydra.utils.instantiate(
            cfg.optimizer, params=self.model.parameters())

        # training-state counters (restored on resume via include_keys)
        self.global_step = 0
        self.epoch = 0

    def run(self):
        """Run the full training loop described by ``self.cfg``."""
        # operate on a copy so any mutation here never leaks into self.cfg
        cfg = copy.deepcopy(self.cfg)

        # resume training from the latest checkpoint if requested
        if cfg.training.resume:
            latest_ckpt_path = self.get_checkpoint_path()
            if latest_ckpt_path.is_file():
                print(f"Resuming from checkpoint {latest_ckpt_path}")
                self.load_checkpoint(path=latest_ckpt_path)

        # configure dataset and dataloader
        dataset: BaseImageDataset
        dataset = hydra.utils.instantiate(cfg.task.dataset)
        assert isinstance(dataset, BaseImageDataset)
        train_dataloader = DataLoader(dataset, **cfg.dataloader)
        normalizer = dataset.get_normalizer()

        self.model.set_normalizer(normalizer)
        if cfg.training.use_ema:
            self.ema_model.set_normalizer(normalizer)

        # configure lr scheduler; last_epoch realigns the schedule on resume
        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=(
                len(train_dataloader) * cfg.training.num_epochs) \
                    // cfg.training.gradient_accumulate_every,
            last_epoch=self.global_step - 1
        )

        # configure EMA tracker over the EMA copy of the model
        ema: EMAModel = None
        if cfg.training.use_ema:
            ema = hydra.utils.instantiate(
                cfg.ema,
                model=self.ema_model)

        # configure env runner used for periodic rollout evaluation
        env_runner: BaseImageRunner
        env_runner = hydra.utils.instantiate(
            cfg.task.env_runner,
            output_dir=self.output_dir)
        assert isinstance(env_runner, BaseImageRunner)

        # configure experiment logging
        wandb_run = wandb.init(
            dir=str(self.output_dir),
            config=OmegaConf.to_container(cfg, resolve=True),
            **cfg.logging
        )
        wandb.config.update(
            {
                "output_dir": self.output_dir,
            }
        )

        # keep only the top-k checkpoints ranked by rollout score
        topk_manager = TopKCheckpointManager(
            save_dir=os.path.join(self.output_dir, 'checkpoints'),
            **cfg.checkpoint.topk
        )

        # device transfer
        device = torch.device(cfg.training.device)
        self.model.to(device)
        if self.ema_model is not None:
            self.ema_model.to(device)
        optimizer_to(self.optimizer, device)

        # a fixed training batch reused for the periodic action-MSE check
        # NOTE(review): drawn from the *training* loader, so the logged
        # 'train_action_mse_error' is a sampling sanity check, not validation
        val_batch = next(iter(train_dataloader))

        # training loop
        for _ in range(cfg.training.num_epochs):
            with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
                    leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                for batch in tepoch:
                    # device transfer
                    batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))

                    # compute loss, scaled for gradient accumulation
                    raw_loss = self.model.compute_loss(batch)
                    loss = raw_loss / cfg.training.gradient_accumulate_every
                    loss.backward()

                    # step optimizer once every gradient_accumulate_every batches
                    if self.global_step % cfg.training.gradient_accumulate_every == 0:
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        lr_scheduler.step()

                    # update EMA weights from the live model
                    if cfg.training.use_ema:
                        ema.step(self.model)

                    # per-step logging
                    raw_loss_cpu = raw_loss.item()
                    step_log = {
                        'train_loss': raw_loss_cpu,
                        'global_step': self.global_step,
                        'epoch': self.epoch,
                        'lr': lr_scheduler.get_last_lr()[0]
                    }

                    # rollout evaluation (uses EMA weights when enabled)
                    should_eval = self.global_step % cfg.training.eval_every == 0
                    if (not cfg.training.eval_first) and (self.global_step == 0):
                        # skip the step-0 evaluation unless explicitly requested
                        should_eval = False
                    if should_eval:
                        policy = self.model
                        if cfg.training.use_ema:
                            policy = self.ema_model
                        policy.eval()
                        runner_log = env_runner.run(policy)
                        policy.train()
                        step_log.update(runner_log)

                        # checkpointing lives inside the eval branch: the
                        # top-k metric comes from this step's runner_log, and
                        # referencing it on a non-eval step would raise
                        # NameError on the very first non-eval iteration
                        if cfg.checkpoint.save_last_ckpt:
                            self.save_checkpoint()
                        if cfg.checkpoint.save_last_snapshot:
                            self.save_snapshot()

                        save_path = topk_manager.get_ckpt_path({
                            'epoch': self.epoch,
                            'test_score': runner_log['test/mean_score']
                        })
                        if save_path is not None:
                            self.save_checkpoint(path=save_path)

                    # diffusion-sampling sanity check on the held-out batch
                    should_val = self.global_step % cfg.training.val_every == 0
                    if should_val:
                        policy = self.model
                        if cfg.training.use_ema:
                            policy = self.ema_model
                        policy.eval()
                        with torch.no_grad():
                            # sample a trajectory and compare to ground truth
                            batch = dict_apply(val_batch, lambda x: x.to(device, non_blocking=True))
                            obs_dict = batch['obs']
                            gt_action = batch['action']

                            result = policy.predict_action(obs_dict)
                            pred_action = result['action_pred']
                            mse = torch.nn.functional.mse_loss(pred_action, gt_action)
                            step_log['train_action_mse_error'] = mse.item()
                            # release GPU memory from the sampling pass promptly
                            del batch
                            del obs_dict
                            del gt_action
                            del result
                            del pred_action
                            del mse
                        policy.train()

                    # end-of-step logging
                    wandb_run.log(step_log, step=self.global_step)

                    self.global_step += 1
            # epoch counter advances once per pass over the dataloader
            self.epoch += 1
|
|
@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
    config_name=pathlib.Path(__file__).stem)
def main(cfg):
    """Hydra entry point: build the workspace from ``cfg`` and train."""
    TrainDiffusionUnetVideoWorkspace(cfg).run()
|
|
| if __name__ == "__main__": |
| main() |
|
|