Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from models import register | |
class FMEulerSampler:
    """Fixed-step Euler sampler driven by a diffusion object's predictions.

    The sampler walks a uniform time grid from ``t = 1`` down to ``t = 0``
    and, at every grid point, updates the current sample with the value
    returned by ``diffusion.get_prediction``.
    """

    def __init__(self, diffusion):
        """Store the object that supplies ``get_prediction``.

        Args:
            diffusion: Object exposing ``get_prediction(net, x_t, t,
                net_kwargs=..., uncond_net_kwargs=..., guidance=...)``.
        """
        self.diffusion = diffusion

    def sample(
        self,
        net,
        shape,
        n_steps,
        net_kwargs=None,
        uncond_net_kwargs=None,
        guidance=1.0,
        noise=None,
    ):
        """Integrate from ``t = 1`` to ``t = 0`` in ``n_steps`` Euler steps.

        Args:
            net: Network whose first parameter determines the device used
                for sampling; forwarded to ``diffusion.get_prediction``.
            shape: Shape of the initial sample drawn with ``torch.randn``
                when ``noise`` is not given.
            n_steps: Number of Euler steps. With ``0`` the starting sample
                is returned unchanged.
            net_kwargs: Forwarded unchanged to ``get_prediction``.
            uncond_net_kwargs: Forwarded unchanged to ``get_prediction``.
            guidance: Forwarded unchanged to ``get_prediction``.
            noise: Optional starting tensor used instead of fresh Gaussian
                noise. It is moved to the network's device if necessary.

        Returns:
            The sample tensor after the final step.
        """
        device = next(net.parameters()).device
        if noise is None:
            x_t = torch.randn(shape, device=device)
        else:
            # The time grid lives on the network's device; a starting tensor
            # on another device would fail in the update below. ``.to`` is a
            # no-op when the tensor is already on that device.
            x_t = noise.to(device)

        # n_steps + 1 grid points delimit n_steps intervals from 1 down to 0.
        t_steps = torch.linspace(1, 0, n_steps + 1, device=device)

        with torch.no_grad():
            for i in range(n_steps):
                # One time value per batch element (first dim of x_t).
                t = t_steps[i].repeat(x_t.shape[0])
                # NOTE(review): presumably the negated velocity, judging by
                # the variable name -- confirm against get_prediction.
                neg_v = self.diffusion.get_prediction(
                    net,
                    x_t,
                    t,
                    net_kwargs=net_kwargs,
                    uncond_net_kwargs=uncond_net_kwargs,
                    guidance=guidance,
                )
                # Time decreases along the grid, so this step size is positive.
                x_t = x_t + neg_v * (t_steps[i] - t_steps[i + 1])
        return x_t