Spaces:
Sleeping
Sleeping
File size: 3,642 Bytes
37163a6 2279ae0 37163a6 2279ae0 37163a6 2279ae0 37163a6 2279ae0 37163a6 2279ae0 37163a6 2279ae0 37163a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import torch
from models import register
@register('fm')
class FM:
    """Flow-matching (rectified-flow) schedule, training loss, and sampling helpers.

    The forward interpolation is
        x_t = alpha(t) * x + sigma(t) * noise
    with alpha(t) = 1 - t and sigma(t) = sigma_min + t * (1 - sigma_min),
    and the network regresses the velocity-style target
        A(t) * x + B(t) * noise = x - (1 - sigma_min) * noise   # -dx_t/dt
    """

    def __init__(self, sigma_min=1e-5, timescale=1.0):
        # sigma_min: floor on the noise coefficient at t=0 (keeps sigma(t) > 0).
        # timescale: multiplier applied to t before it is fed to the network.
        self.sigma_min = sigma_min
        # Kept for interface parity with other schedules; FM has a fixed target.
        self.prediction_type = None
        self.timescale = timescale

    def alpha(self, t):
        """Signal coefficient at time t (linear decay 1 -> 0)."""
        return 1.0 - t

    def sigma(self, t):
        """Noise coefficient at time t (linear growth sigma_min -> 1)."""
        return self.sigma_min + t * (1.0 - self.sigma_min)

    def A(self, t):
        """Coefficient of x in the regression target (-dx_t/dt)."""
        return 1.0

    def B(self, t):
        """Coefficient of noise in the regression target (-dx_t/dt)."""
        return -(1.0 - self.sigma_min)

    def _get_reduction_dims(self, x):
        """Get appropriate dimensions for loss reduction based on tensor shape."""
        if x.dim() == 4:
            # Images: [batch, channels, height, width]
            return [1, 2, 3]
        elif x.dim() == 3:
            # Audio: [batch, channels, samples] or [batch, latent_dim, time_frames]
            return [1, 2]
        elif x.dim() == 2:
            # 1D signals: [batch, samples]
            return [1]
        else:
            # Fallback: reduce over all non-batch dimensions
            return list(range(1, x.dim()))

    def get_betas(self, n_timesteps):
        """FM is not a VP diffusion; a beta schedule is not defined (zeros returned)."""
        return torch.zeros(n_timesteps)  # Not VP and not supported

    def add_noise(self, x, t, noise=None):
        """Sample x_t = alpha(t) * x + sigma(t) * noise and return (x_t, noise).

        t is a per-sample tensor of shape [batch]; the coefficients are
        reshaped to [batch, 1, ..., 1] so they broadcast over x.
        """
        noise = torch.randn_like(x) if noise is None else noise
        coef_shape = [x.shape[0]] + [1] * (x.dim() - 1)
        x_t = self.alpha(t).view(*coef_shape) * x + self.sigma(t).view(*coef_shape) * noise
        return x_t, noise

    def loss(self, net, x, t=None, net_kwargs=None, return_loss_unreduced=False, return_all=False):
        """Flow-matching MSE loss between net(x_t, t) and the velocity target.

        Args:
            net: callable invoked as net(x_t, t=t * timescale, **net_kwargs).
            x: clean data batch; first dimension is batch.
            t: optional per-sample times in [0, 1); sampled uniformly if None.
            net_kwargs: extra keyword arguments forwarded to net.
            return_loss_unreduced: if True, return a per-sample loss vector
                (reduced over non-batch dims only) together with t.
            return_all: additionally return x_t and the raw prediction.
        """
        if net_kwargs is None:
            net_kwargs = {}
        if t is None:
            t = torch.rand(x.shape[0], device=x.device)
        x_t, noise = self.add_noise(x, t)
        pred = net(x_t, t=t * self.timescale, **net_kwargs)
        # A(t) and B(t) are scalar constants, so no per-sample reshape is needed.
        target = self.A(t) * x + self.B(t) * noise  # -dxt/dt
        if return_loss_unreduced:
            reduce_dims = self._get_reduction_dims(x)
            loss = ((pred.float() - target.float()) ** 2).mean(dim=reduce_dims)
            if return_all:
                return loss, t, x_t, pred
            else:
                return loss, t
        else:
            loss = ((pred.float() - target.float()) ** 2).mean()
            if return_all:
                return loss, x_t, pred
            else:
                return loss

    def get_prediction(
        self,
        net,
        x_t,
        t,
        net_kwargs=None,
        uncond_net_kwargs=None,
        guidance=1.0,
    ):
        """Run the network, optionally applying classifier-free guidance.

        With guidance != 1.0, a second (unconditional) forward pass is made
        and the outputs are combined as uncond + guidance * (cond - uncond).
        """
        if net_kwargs is None:
            net_kwargs = {}
        pred = net(x_t, t=t * self.timescale, **net_kwargs)
        if guidance != 1.0:
            assert uncond_net_kwargs is not None
            uncond_pred = net(x_t, t=t * self.timescale, **uncond_net_kwargs)
            pred = uncond_pred + guidance * (pred - uncond_pred)
        return pred

    def convert_sample_prediction(self, x_t, t, pred):
        """Recover the clean sample x from (x_t, velocity prediction).

        Inverts the 2x2 linear system [x_t; pred] = M @ [x; noise] and uses
        the first row of M^{-1}.

        NOTE(review): M is built from scalar coefficients, so this assumes t
        is a Python float or 0-d tensor; a batched t tensor will fail here —
        confirm against callers.
        """
        M = torch.tensor([
            [self.alpha(t), self.sigma(t)],
            [self.A(t), self.B(t)],
        ], dtype=torch.float64)
        M_inv = torch.linalg.inv(M)
        sample_pred = M_inv[0, 0].item() * x_t + M_inv[0, 1].item() * pred
        return sample_pred
|