primepake
add training flowvae
4f877a2
import numpy as np
import torch
from models import register
@register('fm_euler_sampler')
class FMEulerSampler:
    """Fixed-step Euler integrator registered as ``'fm_euler_sampler'``.

    Starting from a Gaussian (or caller-supplied) tensor at ``t = 1``, it takes
    ``n_steps`` uniform steps down to ``t = 0``. At each step it uses the
    prediction returned by ``diffusion.get_prediction`` as the update direction.
    """

    def __init__(self, diffusion):
        # The sampler only ever calls ``get_prediction(net, x_t, t,
        # net_kwargs=..., uncond_net_kwargs=..., guidance=...)`` on this object.
        self.diffusion = diffusion

    def sample(
        self,
        net,
        shape,
        n_steps,
        net_kwargs=None,
        uncond_net_kwargs=None,
        guidance=1.0,
        noise=None,
    ):
        """Integrate from t=1 to t=0 with ``n_steps`` Euler steps.

        Args:
            net: Model forwarded to ``diffusion.get_prediction``. The device
                of its first parameter is where the time grid (and any freshly
                drawn noise) is created. It must therefore expose at least one
                parameter, otherwise ``next`` raises ``StopIteration``.
            shape: Shape of the starting sample drawn with ``torch.randn``
                when ``noise`` is not given.
            n_steps: Number of Euler steps. The time grid has ``n_steps + 1``
                evenly spaced points from 1 to 0.
            net_kwargs: Forwarded unchanged to ``get_prediction``.
            uncond_net_kwargs: Forwarded unchanged to ``get_prediction``.
            guidance: Forwarded unchanged to ``get_prediction``.
            noise: Optional starting tensor used instead of fresh noise.

        Returns:
            The state after the last step (t=0). With ``n_steps == 0`` this is
            the starting tensor itself.
        """
        target_device = next(net.parameters()).device
        if noise is None:
            state = torch.randn(shape, device=target_device)
        else:
            state = noise

        grid = torch.linspace(1, 0, n_steps + 1, device=target_device)
        with torch.no_grad():
            # Walk consecutive (t_cur, t_next) pairs of the descending grid.
            for t_cur, t_next in zip(grid[:-1], grid[1:]):
                batch_t = t_cur.repeat(state.shape[0])
                direction = self.diffusion.get_prediction(
                    net,
                    state,
                    batch_t,
                    net_kwargs=net_kwargs,
                    uncond_net_kwargs=uncond_net_kwargs,
                    guidance=guidance,
                )
                # The step size (t_cur - t_next) is positive while t decreases,
                # and the prediction is added. So the prediction presumably
                # is the negated velocity dx/dt; confirm against
                # get_prediction.
                state = state + direction * (t_cur - t_next)
        return state