Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,142 Bytes
709cfd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
from enum import Enum
from ppd.utils.timesteps import Timesteps
from ppd.utils.schedule import LinearSchedule
class EulerSampler:
    """
    The Euler method is the simplest ODE solver.

    Each step converts the model prediction into estimates of the two
    trajectory endpoints (``x_0`` and ``x_T``) via the schedule, then
    re-evaluates the schedule at the target timestep to obtain the next sample.
    """

    def __init__(
        self,
        schedule: "LinearSchedule",
        timesteps: "Timesteps",
        prediction_type: str = "velocity",
    ):
        """
        Args:
            schedule: Object providing ``T``, ``convert_from_pred`` and
                ``forward``; used to move between timesteps.
            timesteps: The sampling timestep sequence. Must support ``len``,
                indexing with an index tensor, and ``index(t)``.
            prediction_type: Forwarded to ``schedule.convert_from_pred`` to
                describe what the model output represents.
                Defaults to ``"velocity"``.
        """
        self.schedule = schedule
        self.timesteps = timesteps
        self.prediction_type = prediction_type

    def step(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Step to the next timestep.

        Equivalent to ``step_to(pred, x_t, t, get_next_timestep(t), **kwargs)``.
        """
        return self.step_to(pred, x_t, t, self.get_next_timestep(t), **kwargs)

    def step_to(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
        s: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Steps from x_t at timestep t to x_s at timestep s. Returns x_s.

        Args:
            pred: Model output at ``(x_t, t)``, interpreted according to
                ``self.prediction_type``.
            x_t: Current sample.
            t: Current timestep(s). If it has fewer dims than ``x_t``,
                trailing singleton dims are appended so it broadcasts.
            s: Target timestep(s), broadcast the same way as ``t``. Entries
                below 0 yield the predicted ``x_0``; entries above
                ``schedule.T`` yield the predicted ``x_T``.
            **kwargs: Ignored.
        """
        # Append trailing singleton dims, e.g. (B,) -> (B, 1, 1, 1) for 4-D x_t.
        if t.ndim < x_t.ndim:
            t = t[(...,) + (None,) * (x_t.ndim - t.ndim)]
        if s.ndim < x_t.ndim:
            s = s[(...,) + (None,) * (x_t.ndim - s.ndim)]
        T = self.schedule.T

        # Recover both endpoint estimates, then re-evaluate the schedule at s.
        pred_x_0, pred_x_T = self.schedule.convert_from_pred(
            pred, self.prediction_type, x_t, t
        )
        pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))

        # Out-of-range targets snap to the corresponding endpoint estimate.
        pred_x_s = pred_x_s.where(s >= 0, pred_x_0)
        pred_x_s = pred_x_s.where(s <= T, pred_x_T)
        return pred_x_s

    def get_next_timestep(
        self,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get the next sample timestep.

        Supports a different timestep per batch element. Where ``t`` is the
        last entry of ``self.timesteps`` there is no next step, and the
        out-of-bound value -1 is returned instead (``step_to`` maps a
        negative target timestep to the predicted ``x_0``).
        """
        steps = len(self.timesteps)
        next_idx = self.timesteps.index(t) + 1
        # Clamp so the lookup stays in range; entries past the end are
        # overwritten with -1 right after.
        s = self.timesteps[next_idx.clamp_max(steps - 1)]
        return s.where(next_idx < steps, -1)