| if __name__ == "__main__": |
| import sys |
| import os |
| import pathlib |
|
|
| ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) |
| sys.path.append(ROOT_DIR) |
| os.chdir(ROOT_DIR) |
|
|
| import os |
| import hydra |
| import torch |
| from omegaconf import OmegaConf |
| import pathlib |
| from torch.utils.data import DataLoader |
| import copy |
| import random |
| import wandb |
| import tqdm |
| import numpy as np |
| import shutil |
| from diffusion_policy.workspace.base_workspace import BaseWorkspace |
| from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy |
| from diffusion_policy.dataset.base_dataset import BaseImageDataset |
| from diffusion_policy.env_runner.base_image_runner import BaseImageRunner |
| from diffusion_policy.common.checkpoint_util import TopKCheckpointManager |
| from diffusion_policy.common.json_logger import JsonLogger |
| from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to |
| from diffusion_policy.model.diffusion.ema_model import EMAModel |
| from diffusion_policy.model.common.lr_scheduler import get_scheduler |
|
|
| OmegaConf.register_new_resolver("eval", eval, replace=True) |
|
|
| class TrainDiffusionUnetImageWorkspace(BaseWorkspace): |
| include_keys = ['global_step', 'epoch'] |
|
|
| def __init__(self, cfg: OmegaConf, output_dir=None): |
| super().__init__(cfg, output_dir=output_dir) |
|
|
| |
| seed = cfg.training.seed |
| torch.manual_seed(seed) |
| np.random.seed(seed) |
| random.seed(seed) |
|
|
| |
| self.model: DiffusionUnetImagePolicy = hydra.utils.instantiate(cfg.policy) |
|
|
| self.ema_model: DiffusionUnetImagePolicy = None |
| if cfg.training.use_ema: |
| self.ema_model = copy.deepcopy(self.model) |
|
|
| |
| self.optimizer = hydra.utils.instantiate( |
| cfg.optimizer, params=self.model.parameters()) |
|
|
| |
| self.global_step = 0 |
| self.epoch = 0 |
|
|
| def run(self): |
| cfg = copy.deepcopy(self.cfg) |
|
|
| |
| if cfg.training.resume: |
| lastest_ckpt_path = self.get_checkpoint_path() |
| if lastest_ckpt_path.is_file(): |
| print(f"Resuming from checkpoint {lastest_ckpt_path}") |
| self.load_checkpoint(path=lastest_ckpt_path) |
|
|
| |
| dataset: BaseImageDataset |
| dataset = hydra.utils.instantiate(cfg.task.dataset) |
| assert isinstance(dataset, BaseImageDataset) |
| train_dataloader = DataLoader(dataset, **cfg.dataloader) |
| normalizer = dataset.get_normalizer() |
|
|
| |
| val_dataset = dataset.get_validation_dataset() |
| val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader) |
|
|
| self.model.set_normalizer(normalizer) |
| if cfg.training.use_ema: |
| self.ema_model.set_normalizer(normalizer) |
|
|
| |
| lr_scheduler = get_scheduler( |
| cfg.training.lr_scheduler, |
| optimizer=self.optimizer, |
| num_warmup_steps=cfg.training.lr_warmup_steps, |
| num_training_steps=( |
| len(train_dataloader) * cfg.training.num_epochs) \ |
| // cfg.training.gradient_accumulate_every, |
| |
| |
| last_epoch=self.global_step-1 |
| ) |
|
|
| |
| ema: EMAModel = None |
| if cfg.training.use_ema: |
| ema = hydra.utils.instantiate( |
| cfg.ema, |
| model=self.ema_model) |
|
|
| |
| env_runner: BaseImageRunner |
| env_runner = hydra.utils.instantiate( |
| cfg.task.env_runner, |
| output_dir=self.output_dir) |
| assert isinstance(env_runner, BaseImageRunner) |
|
|
| |
| wandb_run = wandb.init( |
| dir=str(self.output_dir), |
| config=OmegaConf.to_container(cfg, resolve=True), |
| **cfg.logging |
| ) |
| wandb.config.update( |
| { |
| "output_dir": self.output_dir, |
| } |
| ) |
|
|
| |
| topk_manager = TopKCheckpointManager( |
| save_dir=os.path.join(self.output_dir, 'checkpoints'), |
| **cfg.checkpoint.topk |
| ) |
|
|
| |
| device = torch.device(cfg.training.device) |
| self.model.to(device) |
| if self.ema_model is not None: |
| self.ema_model.to(device) |
| optimizer_to(self.optimizer, device) |
|
|
| |
| train_sampling_batch = None |
|
|
| if cfg.training.debug: |
| cfg.training.num_epochs = 2 |
| cfg.training.max_train_steps = 3 |
| cfg.training.max_val_steps = 3 |
| cfg.training.rollout_every = 1 |
| cfg.training.checkpoint_every = 1 |
| cfg.training.val_every = 1 |
| cfg.training.sample_every = 1 |
|
|
| |
| log_path = os.path.join(self.output_dir, 'logs.json.txt') |
| with JsonLogger(log_path) as json_logger: |
| for local_epoch_idx in range(cfg.training.num_epochs): |
| step_log = dict() |
| |
| if cfg.training.freeze_encoder: |
| self.model.obs_encoder.eval() |
| self.model.obs_encoder.requires_grad_(False) |
|
|
| train_losses = list() |
| with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}", |
| leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: |
| for batch_idx, batch in enumerate(tepoch): |
| |
| batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) |
| if train_sampling_batch is None: |
| train_sampling_batch = batch |
|
|
| |
| raw_loss = self.model.compute_loss(batch) |
| loss = raw_loss / cfg.training.gradient_accumulate_every |
| loss.backward() |
|
|
| |
| if self.global_step % cfg.training.gradient_accumulate_every == 0: |
| self.optimizer.step() |
| self.optimizer.zero_grad() |
| lr_scheduler.step() |
| |
| |
| if cfg.training.use_ema: |
| ema.step(self.model) |
|
|
| |
| raw_loss_cpu = raw_loss.item() |
| tepoch.set_postfix(loss=raw_loss_cpu, refresh=False) |
| train_losses.append(raw_loss_cpu) |
| step_log = { |
| 'train_loss': raw_loss_cpu, |
| 'global_step': self.global_step, |
| 'epoch': self.epoch, |
| 'lr': lr_scheduler.get_last_lr()[0] |
| } |
|
|
| is_last_batch = (batch_idx == (len(train_dataloader)-1)) |
| if not is_last_batch: |
| |
| wandb_run.log(step_log, step=self.global_step) |
| json_logger.log(step_log) |
| self.global_step += 1 |
|
|
| if (cfg.training.max_train_steps is not None) \ |
| and batch_idx >= (cfg.training.max_train_steps-1): |
| break |
|
|
| |
| |
| train_loss = np.mean(train_losses) |
| step_log['train_loss'] = train_loss |
|
|
| |
| policy = self.model |
| if cfg.training.use_ema: |
| policy = self.ema_model |
| policy.eval() |
|
|
| |
| if (self.epoch % cfg.training.rollout_every) == 0: |
| runner_log = env_runner.run(policy) |
| |
| step_log.update(runner_log) |
|
|
| |
| if (self.epoch % cfg.training.val_every) == 0: |
| with torch.no_grad(): |
| val_losses = list() |
| with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}", |
| leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch: |
| for batch_idx, batch in enumerate(tepoch): |
| batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True)) |
| loss = self.model.compute_loss(batch) |
| val_losses.append(loss) |
| if (cfg.training.max_val_steps is not None) \ |
| and batch_idx >= (cfg.training.max_val_steps-1): |
| break |
| if len(val_losses) > 0: |
| val_loss = torch.mean(torch.tensor(val_losses)).item() |
| |
| step_log['val_loss'] = val_loss |
|
|
| |
| if (self.epoch % cfg.training.sample_every) == 0: |
| with torch.no_grad(): |
| |
| batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True)) |
| obs_dict = batch['obs'] |
| gt_action = batch['action'] |
| |
| result = policy.predict_action(obs_dict) |
| pred_action = result['action_pred'] |
| mse = torch.nn.functional.mse_loss(pred_action, gt_action) |
| step_log['train_action_mse_error'] = mse.item() |
| del batch |
| del obs_dict |
| del gt_action |
| del result |
| del pred_action |
| del mse |
| |
| |
| if (self.epoch % cfg.training.checkpoint_every) == 0: |
| |
| if cfg.checkpoint.save_last_ckpt: |
| self.save_checkpoint() |
| if cfg.checkpoint.save_last_snapshot: |
| self.save_snapshot() |
|
|
| |
| metric_dict = dict() |
| for key, value in step_log.items(): |
| new_key = key.replace('/', '_') |
| metric_dict[new_key] = value |
| |
| |
| |
| |
| topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict) |
|
|
| if topk_ckpt_path is not None: |
| self.save_checkpoint(path=topk_ckpt_path) |
| |
| policy.train() |
|
|
| |
| |
| wandb_run.log(step_log, step=self.global_step) |
| json_logger.log(step_log) |
| self.global_step += 1 |
| self.epoch += 1 |
|
|
| @hydra.main( |
| version_base=None, |
| config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")), |
| config_name=pathlib.Path(__file__).stem) |
| def main(cfg): |
| workspace = TrainDiffusionUnetImageWorkspace(cfg) |
| workspace.run() |
|
|
| if __name__ == "__main__": |
| main() |
|
|