import numpy as np
import torch

from models import register


@register('fm_euler_sampler')
class FMEulerSampler:
    """Fixed-step Euler sampler integrating on a uniform time grid from t=1 to t=0.

    The update direction at each step comes from ``diffusion.get_prediction``;
    this class only owns the time grid and the Euler update rule.
    """

    def __init__(self, diffusion):
        # Object exposing ``get_prediction(net, x, t, net_kwargs=...,
        # uncond_net_kwargs=..., guidance=...)``; its concrete type is defined
        # elsewhere in the project.
        self.diffusion = diffusion

    def sample(
        self,
        net,
        shape,
        n_steps,
        net_kwargs=None,
        uncond_net_kwargs=None,
        guidance=1.0,
        noise=None,
    ):
        """Run ``n_steps`` Euler steps from t=1 to t=0 and return the final state.

        Args:
            net: torch module queried through ``diffusion.get_prediction``.
                The sampling device is taken from its first parameter.
            shape: shape passed to ``torch.randn`` when ``noise`` is None.
            n_steps: number of uniform steps on the grid from 1 to 0.
            net_kwargs, uncond_net_kwargs, guidance: forwarded unchanged to
                ``get_prediction``. These are presumably conditional inputs,
                unconditional inputs and a classifier-free guidance scale
                (NOTE(review): confirm against the diffusion class).
            noise: optional starting tensor. It is used as-is, neither copied
                nor moved to ``net``'s device.

        Returns:
            The tensor reached at t=0. When ``n_steps == 0`` this is the
            starting tensor itself.
        """
        device = next(net.parameters()).device
        if noise is None:
            x = torch.randn(shape, device=device)
        else:
            x = noise

        # Uniform grid 1 -> 0 with n_steps + 1 knots; consecutive pairs are steps.
        grid = torch.linspace(1, 0, n_steps + 1, device=device)
        with torch.no_grad():
            for t_cur, t_next in zip(grid[:-1], grid[1:]):
                # Broadcast the scalar time to one value per batch element.
                t_batch = t_cur.repeat(x.shape[0])
                # Presumably the negated velocity (per its name), so it is
                # added while t decreases.
                # NOTE(review): confirm the sign convention in get_prediction.
                neg_v = self.diffusion.get_prediction(
                    net,
                    x,
                    t_batch,
                    net_kwargs=net_kwargs,
                    uncond_net_kwargs=uncond_net_kwargs,
                    guidance=guidance,
                )
                # t decreases along the grid, so the step size t_cur - t_next
                # is positive. The update is out-of-place so a caller-supplied
                # `noise` is never mutated.
                x = x + neg_v * (t_cur - t_next)
        return x