from enum import Enum

import torch

from ppd.utils.timesteps import Timesteps
from ppd.utils.schedule import LinearSchedule


class EulerSampler:
    """
    The Euler method is the simplest ODE solver.

    Each step converts the model prediction into estimates of both schedule
    endpoints (``x_0`` and ``x_T``) and re-evaluates the schedule at the
    target timestep.
    """

    def __init__(
        self,
        schedule: LinearSchedule,
        timesteps: Timesteps,
        prediction_type: str = "velocity",
    ):
        """
        Args:
            schedule: Schedule used to convert predictions
                (``convert_from_pred``) and to evaluate ``x_s``
                (``forward``); its ``T`` attribute is the upper time bound.
            timesteps: Sampling timesteps; must support ``len()``,
                ``index(t)`` and tensor indexing.
            prediction_type: Tag passed verbatim to
                ``schedule.convert_from_pred``. Defaults to ``"velocity"``.
        """
        self.schedule = schedule
        self.timesteps = timesteps
        self.prediction_type = prediction_type

    @staticmethod
    def _expand_to_ndim(x: torch.Tensor, ndim: int) -> torch.Tensor:
        """Append trailing singleton dims to ``x`` until it has ``ndim`` dims.

        Tensors that already have ``ndim`` or more dims are returned as-is.
        """
        if x.ndim >= ndim:
            return x
        return x[(...,) + (None,) * (ndim - x.ndim)]

    def step(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Step to the next timestep.

        Looks up the timestep following ``t`` with ``get_next_timestep`` and
        delegates to ``step_to``; extra keyword arguments are forwarded.
        """
        return self.step_to(pred, x_t, t, self.get_next_timestep(t), **kwargs)

    def step_to(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
        s: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Steps from x_t at timestep t to x_s at timestep s. Returns x_s.

        Args:
            pred: Model prediction, interpreted according to
                ``self.prediction_type``.
            x_t: Current sample at timestep ``t``.
            t: Current timestep(s); trailing singleton dims are appended so
                it broadcasts against ``x_t``.
            s: Target timestep(s); broadcast the same way as ``t``. Values
                outside ``[0, T]`` select an endpoint estimate (see below).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The estimate of ``x_s``. Where ``s < 0`` the predicted ``x_0`` is
            returned, and where ``s > T`` the predicted ``x_T`` is returned.
        """
        t = self._expand_to_ndim(t, x_t.ndim)
        s = self._expand_to_ndim(s, x_t.ndim)
        T = self.schedule.T

        # Step from x_t to x_s: recover both endpoint estimates, then
        # evaluate the schedule at s (clamped so forward() only ever sees
        # in-range timesteps).
        pred_x_0, pred_x_T = self.schedule.convert_from_pred(
            pred, self.prediction_type, x_t, t
        )
        pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))

        # Clamp x_s to x_0 and x_T if s is out of bound.
        pred_x_s = pred_x_s.where(s >= 0, pred_x_0)
        pred_x_s = pred_x_s.where(s <= T, pred_x_T)
        return pred_x_s

    def get_next_timestep(
        self,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get the next sample timestep.

        Support multiple different timesteps t in a batch.
        If no more steps, return the out of bound value -1, which
        ``step_to`` maps to the predicted ``x_0``.

        NOTE(review): only the -1 sentinel is produced here; a schedule that
        is traversed towards T would need T+1 instead -- confirm that no
        caller relies on that direction.
        """
        steps = len(self.timesteps)
        next_idx = self.timesteps.index(t) + 1

        # Clamp the lookup index so indexing never goes out of range, then
        # overwrite the entries that ran past the last step with -1.
        s = self.timesteps[next_idx.clamp_max(steps - 1)]
        s = s.where(next_idx < steps, -1)
        return s