if __name__ == "__main__":
    # when run directly, put the repo root on sys.path and chdir into it so
    # that diffusion_policy imports and relative paths resolve
    import sys
    import os
    import pathlib

    ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent)
    sys.path.append(ROOT_DIR)
    os.chdir(ROOT_DIR)

import os
import hydra
import torch
from omegaconf import OmegaConf
import pathlib
from torch.utils.data import DataLoader
import copy
import random
import wandb
import tqdm
import numpy as np
import shutil
from diffusion_policy.workspace.base_workspace import BaseWorkspace
from diffusion_policy.policy.robomimic_image_policy import RobomimicImagePolicy
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.env_runner.base_image_runner import BaseImageRunner
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager
from diffusion_policy.common.json_logger import JsonLogger
from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to

# allow ${eval:...} expressions inside config interpolations
OmegaConf.register_new_resolver("eval", eval, replace=True)


class TrainRobomimicImageWorkspace(BaseWorkspace):
    # plain attributes to persist in checkpoints alongside model state
    include_keys = ['global_step', 'epoch']

    def __init__(self, cfg: OmegaConf, output_dir=None):
        super().__init__(cfg, output_dir=output_dir)

        # set seed
        seed = cfg.training.seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # configure model
        self.model: RobomimicImagePolicy = hydra.utils.instantiate(cfg.policy)

        # configure training state
        self.global_step = 0
        self.epoch = 0
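
    # run() drives the full training procedure: optionally resume, build the
    # dataset/dataloaders, env runner, and loggers, then loop over epochs of
    # training, rollout, validation, action sampling, and checkpointing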
    def run(self):
        cfg = copy.deepcopy(self.cfg)

        # resume training
        if cfg.training.resume:
            latest_ckpt_path = self.get_checkpoint_path()
            if latest_ckpt_path.is_file():
                print(f"Resuming from checkpoint {latest_ckpt_path}")
                self.load_checkpoint(path=latest_ckpt_path)
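
        # configure dataset: the task dataset supplies the train split and a
        # normalizer fit on its data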
        dataset: BaseImageDataset
        dataset = hydra.utils.instantiate(cfg.task.dataset)
        assert isinstance(dataset, BaseImageDataset)
        train_dataloader = DataLoader(dataset, **cfg.dataloader)
        normalizer = dataset.get_normalizer()

        # configure validation dataset
        val_dataset = dataset.get_validation_dataset()
        val_dataloader = DataLoader(val_dataset, **cfg.val_dataloader)

        self.model.set_normalizer(normalizer)
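
        # configure env runner for periodic policy rollouts in simulation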
        env_runner: BaseImageRunner
        env_runner = hydra.utils.instantiate(
            cfg.task.env_runner,
            output_dir=self.output_dir)
        assert isinstance(env_runner, BaseImageRunner)

        # configure logging
        wandb_run = wandb.init(
            dir=str(self.output_dir),
            config=OmegaConf.to_container(cfg, resolve=True),
            **cfg.logging
        )
        wandb.config.update(
            {
                "output_dir": self.output_dir,
            }
        )
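
        # configure checkpointing: keep only the top-k checkpoints ranked by
        # the metric named in cfg.checkpoint.topk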
        topk_manager = TopKCheckpointManager(
            save_dir=os.path.join(self.output_dir, 'checkpoints'),
            **cfg.checkpoint.topk
        )

        # device transfer
        device = torch.device(cfg.training.device)
        self.model.to(device)

        # cache the first training batch for the action-sampling diagnostic
        train_sampling_batch = None

        # debug mode: shrink every interval so the whole loop runs in seconds
        if cfg.training.debug:
            cfg.training.num_epochs = 2
            cfg.training.max_train_steps = 3
            cfg.training.max_val_steps = 3
            cfg.training.rollout_every = 1
            cfg.training.checkpoint_every = 1
            cfg.training.val_every = 1
            cfg.training.sample_every = 1
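
        # training loop: JsonLogger mirrors every step_log entry to
        # logs.json.txt alongside the wandb stream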
        log_path = os.path.join(self.output_dir, 'logs.json.txt')
        with JsonLogger(log_path) as json_logger:
            for local_epoch_idx in range(cfg.training.num_epochs):
                step_log = dict()
                # ========= train for this epoch ==========
                train_losses = list()
                with tqdm.tqdm(train_dataloader, desc=f"Training epoch {self.epoch}",
                        leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                    for batch_idx, batch in enumerate(tepoch):
                        # device transfer
                        batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                        if train_sampling_batch is None:
                            train_sampling_batch = batch

                        info = self.model.train_on_batch(batch, epoch=self.epoch)

                        # logging
                        loss_cpu = info['losses']['action_loss'].item()
                        tepoch.set_postfix(loss=loss_cpu, refresh=False)
                        train_losses.append(loss_cpu)
                        step_log = {
                            'train_loss': loss_cpu,
                            'global_step': self.global_step,
                            'epoch': self.epoch
                        }

                        is_last_batch = (batch_idx == (len(train_dataloader)-1))
                        if not is_last_batch:
                            # log of last step is combined with validation and rollout
                            wandb_run.log(step_log, step=self.global_step)
                            json_logger.log(step_log)
                            self.global_step += 1

                        if (cfg.training.max_train_steps is not None) \
                                and batch_idx >= (cfg.training.max_train_steps-1):
                            break

                # at the end of each epoch, replace train_loss with the epoch average
                train_loss = np.mean(train_losses)
                step_log['train_loss'] = train_loss

                # ========= eval for this epoch ==========
                self.model.eval()

                # run rollout
                if (self.epoch % cfg.training.rollout_every) == 0:
                    runner_log = env_runner.run(self.model)
                    # merge rollout metrics into step_log
                    step_log.update(runner_log)
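
                # run validation: same train_on_batch path with validate=True,
                # so no gradient updates are applied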
                if (self.epoch % cfg.training.val_every) == 0:
                    with torch.no_grad():
                        val_losses = list()
                        with tqdm.tqdm(val_dataloader, desc=f"Validation epoch {self.epoch}",
                                leave=False, mininterval=cfg.training.tqdm_interval_sec) as tepoch:
                            for batch_idx, batch in enumerate(tepoch):
                                batch = dict_apply(batch, lambda x: x.to(device, non_blocking=True))
                                info = self.model.train_on_batch(batch, epoch=self.epoch, validate=True)
                                loss = info['losses']['action_loss']
                                val_losses.append(loss)
                                if (cfg.training.max_val_steps is not None) \
                                        and batch_idx >= (cfg.training.max_val_steps-1):
                                    break
                        if len(val_losses) > 0:
                            val_loss = torch.mean(torch.tensor(val_losses)).item()
                            # log epoch average validation loss
                            step_log['val_loss'] = val_loss
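
                # action-MSE diagnostic: replay the cached training batch one
                # timestep at a time and compare predicted vs. ground-truth actions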
                if (self.epoch % cfg.training.sample_every) == 0:
                    with torch.no_grad():
                        batch = dict_apply(train_sampling_batch, lambda x: x.to(device, non_blocking=True))
                        obs_dict = batch['obs']
                        gt_action = batch['action']
                        T = gt_action.shape[1]

                        pred_actions = list()
                        self.model.reset()
                        for i in range(T):
                            result = self.model.predict_action(
                                dict_apply(obs_dict, lambda x: x[:,[i]])
                            )
                            pred_actions.append(result['action'])
                        pred_actions = torch.cat(pred_actions, dim=1)
                        mse = torch.nn.functional.mse_loss(pred_actions, gt_action)
                        step_log['train_action_mse_error'] = mse.item()
                        # free the sampling tensors before checkpointing
                        del batch
                        del obs_dict
                        del gt_action
                        del result
                        del pred_actions
                        del mse
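
                # checkpointing: save the rolling "last" checkpoint/snapshot,
                # then let the top-k manager decide whether this epoch's
                # metrics earn a ranked checkpoint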
                if (self.epoch % cfg.training.checkpoint_every) == 0:
                    if cfg.checkpoint.save_last_ckpt:
                        self.save_checkpoint()
                    if cfg.checkpoint.save_last_snapshot:
                        self.save_snapshot()

                    # sanitize metric names for use in checkpoint filenames
                    metric_dict = dict()
                    for key, value in step_log.items():
                        new_key = key.replace('/', '_')
                        metric_dict[new_key] = value

                    # we can't copy the last checkpoint here because
                    # save_checkpoint uses threads; the file may still be empty
                    topk_ckpt_path = topk_manager.get_ckpt_path(metric_dict)

                    if topk_ckpt_path is not None:
                        self.save_checkpoint(path=topk_ckpt_path)
                # ========= eval end for this epoch ==========
                self.model.train()

                # end of epoch: emit the combined train/rollout/val log once
                wandb_run.log(step_log, step=self.global_step)
                json_logger.log(step_log)
                self.global_step += 1
                self.epoch += 1


@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")),
    config_name=pathlib.Path(__file__).stem)
def main(cfg):
    workspace = TrainRobomimicImageWorkspace(cfg)
    workspace.run()


if __name__ == "__main__":
    main()
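
# Example launch (override names are illustrative; the actual task configs
# live under diffusion_policy/config):
#   python train_robomimic_image_workspace.py task=<task_name> \
#       training.seed=42 training.device=cuda:0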