| """ |
| This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research |
| template [repo](https://github.com/buoyancy99/research-template). |
| By its MIT license, you must keep the above sentence in `README.md` |
| and the `LICENSE` file to credit the author. |
| """ |
|
|
| from typing import Optional |
| from tqdm import tqdm |
| from omegaconf import DictConfig |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from typing import Any |
| from einops import rearrange |
|
|
| from lightning.pytorch.utilities.types import STEP_OUTPUT |
|
|
| from algorithms.common.base_pytorch_algo import BasePytorchAlgo |
| from .models.diffusion import Diffusion |
|
|
|
|
| class DiffusionForcingBase(BasePytorchAlgo): |
| def __init__(self, cfg: DictConfig): |
| self.cfg = cfg |
| self.x_shape = cfg.x_shape |
| self.frame_stack = cfg.frame_stack |
| self.x_stacked_shape = list(self.x_shape) |
| self.x_stacked_shape[0] *= cfg.frame_stack |
| self.guidance_scale = cfg.guidance_scale |
| self.context_frames = cfg.context_frames |
| self.chunk_size = cfg.chunk_size |
| self.action_cond_dim = cfg.action_cond_dim |
| self.causal = cfg.causal |
|
|
|
|
| self.uncertainty_scale = cfg.uncertainty_scale |
| self.timesteps = cfg.diffusion.timesteps |
| self.sampling_timesteps = cfg.diffusion.sampling_timesteps |
| self.clip_noise = cfg.diffusion.clip_noise |
|
|
| self.cfg.diffusion.cum_snr_decay = self.cfg.diffusion.cum_snr_decay ** (self.frame_stack * cfg.frame_skip) |
|
|
| self.validation_step_outputs = [] |
| super().__init__(cfg) |
|
|
| def _build_model(self): |
| self.diffusion_model = Diffusion( |
| x_shape=self.x_stacked_shape, |
| action_cond_dim=self.action_cond_dim, |
| is_causal=self.causal, |
| cfg=self.cfg.diffusion, |
| ) |
| self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std) |
|
|
| def configure_optimizers(self): |
| params = tuple(self.diffusion_model.parameters()) |
| optimizer_dynamics = torch.optim.AdamW( |
| params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay, betas=self.cfg.optimizer_beta |
| ) |
| return optimizer_dynamics |
|
|
| def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): |
| |
| optimizer.step(closure=optimizer_closure) |
|
|
| |
| if self.trainer.global_step < self.cfg.warmup_steps: |
| lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.cfg.warmup_steps) |
| for pg in optimizer.param_groups: |
| pg["lr"] = lr_scale * self.cfg.lr |
|
|
| def training_step(self, batch, batch_idx) -> STEP_OUTPUT: |
| xs, conditions, masks = self._preprocess_batch(batch) |
|
|
| rand_length = torch.randint(3,xs.shape[0]-2, (1,))[0].item() |
| xs = torch.cat([xs[:rand_length], xs[rand_length-3:rand_length-1]]) |
| conditions = torch.cat([conditions[:rand_length], conditions[rand_length-3:rand_length-1]]) |
| masks = torch.cat([masks[:rand_length], masks[rand_length-3:rand_length-1]]) |
| noise_levels=self._generate_noise_levels(xs) |
| noise_levels[:rand_length] = 15 |
| noise_levels[rand_length+1:] = 15 |
|
|
| xs_pred, loss = self.diffusion_model(xs, conditions, noise_levels=noise_levels) |
| loss = self.reweight_loss(loss, masks) |
|
|
| |
| if batch_idx % 20 == 0: |
| self.log("training/loss", loss) |
|
|
| xs = self._unstack_and_unnormalize(xs) |
| xs_pred = self._unstack_and_unnormalize(xs_pred) |
|
|
| output_dict = { |
| "loss": loss, |
| "xs_pred": xs_pred, |
| "xs": xs, |
| } |
|
|
| return output_dict |
|
|
| @torch.no_grad() |
| def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT: |
| xs, conditions, masks = self._preprocess_batch(batch) |
| n_frames, batch_size, *_ = xs.shape |
| xs_pred = [] |
| curr_frame = 0 |
|
|
| |
| n_context_frames = self.context_frames // self.frame_stack |
| xs_pred = xs[:n_context_frames].clone() |
| curr_frame += n_context_frames |
|
|
| if self.condtion_similar_length: |
| n_frames -= self.condtion_similar_length |
|
|
| pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling") |
| while curr_frame < n_frames: |
| if self.chunk_size > 0: |
| horizon = min(n_frames - curr_frame, self.chunk_size) |
| else: |
| horizon = n_frames - curr_frame |
| assert horizon <= self.n_tokens, "horizon exceeds the number of tokens." |
| scheduling_matrix = self._generate_scheduling_matrix(horizon) |
|
|
| chunk = torch.randn((horizon, batch_size, *self.x_stacked_shape), device=self.device) |
| chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise) |
| xs_pred = torch.cat([xs_pred, chunk], 0) |
|
|
| |
| start_frame = max(0, curr_frame + horizon - self.n_tokens) |
|
|
| pbar.set_postfix( |
| { |
| "start": start_frame, |
| "end": curr_frame + horizon, |
| } |
| ) |
|
|
| if self.condtion_similar_length: |
| xs_pred = torch.cat([xs_pred, xs[curr_frame-self.condtion_similar_length:curr_frame].clone()], 0) |
|
|
| for m in range(scheduling_matrix.shape[0] - 1): |
|
|
| from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[ |
| :, None |
| ].repeat(batch_size, axis=1) |
| to_noise_levels = np.concatenate( |
| ( |
| np.zeros((curr_frame,), dtype=np.int64), |
| scheduling_matrix[m + 1], |
| ) |
| )[ |
| :, None |
| ].repeat(batch_size, axis=1) |
|
|
| if self.condtion_similar_length: |
| from_noise_levels = np.concatenate([from_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0) |
| to_noise_levels = np.concatenate([to_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0) |
|
|
| from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device) |
| to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device) |
|
|
| |
| |
|
|
| try: |
| input_condition = conditions[start_frame : curr_frame + horizon].clone() |
| except: |
| import pdb;pdb.set_trace() |
| if self.condtion_similar_length: |
| input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], conditions[-self.condtion_similar_length:]], dim=0) |
| xs_pred[start_frame:] = self.diffusion_model.sample_step( |
| xs_pred[start_frame:], |
| input_condition, |
| from_noise_levels[start_frame:], |
| to_noise_levels[start_frame:], |
| ) |
|
|
| if self.condtion_similar_length: |
| xs_pred = xs_pred[:-self.condtion_similar_length] |
|
|
| curr_frame += horizon |
| pbar.update(horizon) |
|
|
| if self.condtion_similar_length: |
| xs = xs[:-self.condtion_similar_length] |
| |
| loss = F.mse_loss(xs_pred, xs, reduction="none") |
| loss = self.reweight_loss(loss, masks) |
| self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu())) |
|
|
| return loss |
|
|
| def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: |
| return self.validation_step(*args, **kwargs, namespace="test") |
|
|
| def on_test_epoch_end(self) -> None: |
| self.on_validation_epoch_end(namespace="test") |
|
|
| def _generate_noise_levels(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None) -> torch.Tensor: |
| """ |
| Generate noise levels for training. |
| """ |
| num_frames, batch_size, *_ = xs.shape |
| match self.cfg.noise_level: |
| case "random_all": |
| noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device) |
| case "same": |
| noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device) |
| noise_levels[1:] = noise_levels[0] |
|
|
| if masks is not None: |
| |
| discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1) |
| noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels) |
|
|
| return noise_levels |
|
|
| def _generate_scheduling_matrix(self, horizon: int): |
| match self.cfg.scheduling_matrix: |
| case "pyramid": |
| return self._generate_pyramid_scheduling_matrix(horizon, self.uncertainty_scale) |
| case "full_sequence": |
| return np.arange(self.sampling_timesteps, -1, -1)[:, None].repeat(horizon, axis=1) |
| case "autoregressive": |
| return self._generate_pyramid_scheduling_matrix(horizon, self.sampling_timesteps) |
| case "trapezoid": |
| return self._generate_trapezoid_scheduling_matrix(horizon, self.uncertainty_scale) |
|
|
| def _generate_pyramid_scheduling_matrix(self, horizon: int, uncertainty_scale: float): |
| height = self.sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1 |
| scheduling_matrix = np.zeros((height, horizon), dtype=np.int64) |
| for m in range(height): |
| for t in range(horizon): |
| scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m |
|
|
| return np.clip(scheduling_matrix, 0, self.sampling_timesteps) |
|
|
| def _generate_trapezoid_scheduling_matrix(self, horizon: int, uncertainty_scale: float): |
| height = self.sampling_timesteps + int((horizon + 1) // 2 * uncertainty_scale) |
| scheduling_matrix = np.zeros((height, horizon), dtype=np.int64) |
| for m in range(height): |
| for t in range((horizon + 1) // 2): |
| scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m |
| scheduling_matrix[m, -t] = self.sampling_timesteps + int(t * uncertainty_scale) - m |
|
|
| return np.clip(scheduling_matrix, 0, self.sampling_timesteps) |
|
|
| def reweight_loss(self, loss, weight=None): |
| |
| loss = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack) |
| if weight is not None: |
| expand_dim = len(loss.shape) - len(weight.shape) - 1 |
| weight = rearrange( |
| weight, |
| "(t fs) b ... -> t b fs ..." + " 1" * expand_dim, |
| fs=self.frame_stack, |
| ) |
| loss = loss * weight |
|
|
| return loss.mean() |
|
|
| def _preprocess_batch(self, batch): |
| xs = batch[0] |
| batch_size, n_frames = xs.shape[:2] |
|
|
| if n_frames % self.frame_stack != 0: |
| raise ValueError("Number of frames must be divisible by frame stack size") |
| if self.context_frames % self.frame_stack != 0: |
| raise ValueError("Number of context frames must be divisible by frame stack size") |
|
|
| masks = torch.ones(n_frames, batch_size).to(xs.device) |
| n_frames = n_frames // self.frame_stack |
|
|
| if self.action_cond_dim: |
| conditions = batch[1] |
| conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1) |
| conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous() |
|
|
| |
| |
| |
| |
| |
| else: |
| conditions = [None for _ in range(n_frames)] |
|
|
| xs = self._normalize_x(xs) |
| xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous() |
|
|
| return xs, conditions, masks |
|
|
| def _normalize_x(self, xs): |
| shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape) |
| mean = self.data_mean.reshape(shape) |
| std = self.data_std.reshape(shape) |
| return (xs - mean) / std |
|
|
| def _unnormalize_x(self, xs): |
| shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape) |
| mean = self.data_mean.reshape(shape) |
| std = self.data_std.reshape(shape) |
| return xs * std + mean |
|
|
| def _unstack_and_unnormalize(self, xs): |
| xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack) |
| return self._unnormalize_x(xs) |
|
|