Spaces:
Sleeping
Sleeping
primepake
committed on
Commit
·
62d19d0
1
Parent(s):
0a7b8fc
apply immiscible random noise
Browse files
flowae/models/diffusion/fm.py
CHANGED
|
@@ -6,11 +6,13 @@ from models import register
|
|
| 6 |
@register('fm')
|
| 7 |
class FM:
|
| 8 |
|
| 9 |
-
def __init__(self, sigma_min=1e-5, timescale=1.0):
|
| 10 |
self.sigma_min = sigma_min
|
| 11 |
self.prediction_type = None
|
| 12 |
self.timescale = timescale
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
def alpha(self, t):
    """Return ``1.0 - t``; ``add_noise`` uses it as the coefficient on ``x``."""
    data_weight = 1.0 - t
    return data_weight
|
| 16 |
|
|
@@ -23,6 +25,32 @@ class FM:
|
|
| 23 |
def B(self, t):
    """Return the constant ``-(1.0 - sigma_min)``; the argument ``t`` is not used."""
    span = 1.0 - self.sigma_min
    return -span
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def _get_reduction_dims(self, x):
|
| 27 |
"""Get appropriate dimensions for loss reduction based on tensor shape"""
|
| 28 |
if x.dim() == 4:
|
|
@@ -42,7 +70,11 @@ class FM:
|
|
| 42 |
return torch.zeros(n_timesteps) # Not VP and not supported
|
| 43 |
|
| 44 |
def add_noise(self, x, t, noise=None):
    """Blend ``x`` with noise at time ``t``.

    Computes ``alpha(t) * x + sigma(t) * noise`` with the per-sample
    coefficients broadcast over every non-batch dimension of ``x``.

    Args:
        x: Tensor whose first dimension is the batch.
        t: Tensor that ``alpha``/``sigma`` map to one value per sample
            (the results are reshaped to ``(batch, 1, ..., 1)``).
        noise: Optional noise tensor to use. When omitted, standard
            Gaussian noise shaped like ``x`` is sampled.

    Returns:
        Tuple ``(x_t, noise)``: the noised tensor and the noise that was used.
    """
    # The signature advertises ``noise=None``; without sampling here the
    # default path would fail when multiplying ``sigma(t)`` by ``None``.
    if noise is None:
        noise = torch.randn_like(x)
    # One coefficient per sample, broadcast across the remaining dimensions.
    s = [x.shape[0]] + [1] * (x.dim() - 1)
    x_t = self.alpha(t).view(*s) * x + self.sigma(t).view(*s) * noise
    return x_t, noise
|
|
|
|
| 6 |
@register('fm')
|
| 7 |
class FM:
|
| 8 |
|
| 9 |
+
def __init__(self, sigma_min=1e-5, timescale=1.0, use_immiscible=True, k_candidates=4):
|
| 10 |
self.sigma_min = sigma_min
|
| 11 |
self.prediction_type = None
|
| 12 |
self.timescale = timescale
|
| 13 |
+
self.use_immiscible = use_immiscible
|
| 14 |
+
self.k_candidates = k_candidates
|
| 15 |
+
|
| 16 |
def alpha(self, t):
    """Return ``1.0 - t``; ``add_noise`` uses it as the coefficient on ``x``."""
    data_weight = 1.0 - t
    return data_weight
|
| 18 |
|
|
|
|
| 25 |
def B(self, t):
    """Return the constant ``-(1.0 - sigma_min)``; the argument ``t`` is not used."""
    span = 1.0 - self.sigma_min
    return -span
|
| 27 |
|
| 28 |
+
def get_immiscible_noise(self, x, k=4):
    """Sample noise for ``x``, keeping per sample the nearest of ``k`` candidates.

    For every batch element, ``k`` standard-Gaussian tensors shaped like that
    element are drawn, and the one with the smallest Euclidean distance to
    the element (measured over all of its flattened entries) is returned.

    Args:
        x: Tensor whose first dimension is the batch.
        k: Number of candidates drawn per sample; must be at least 1.

    Returns:
        Tensor with the same shape and device as ``x``. Its dtype matches
        ``x`` when ``x`` is floating point, otherwise it is the default
        floating dtype.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    batch_size = x.shape[0]

    # Draw in x's own floating dtype so the result is interchangeable with
    # torch.randn_like(x); integer inputs fall back to the default dtype.
    noise_dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    noise_candidates = torch.randn(
        batch_size, k, *x.shape[1:], dtype=noise_dtype, device=x.device
    )

    # The choice of candidate is not differentiable, so keep autograd out of
    # the distance computation. Distances are evaluated in at least float32:
    # half precision is too coarse to rank candidates whose distances to a
    # high-dimensional sample differ only slightly.
    with torch.no_grad():
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_flat = x.reshape(batch_size, 1, -1).to(compute_dtype)
        noise_flat = noise_candidates.reshape(batch_size, k, -1).to(compute_dtype)
        distances = torch.linalg.vector_norm(x_flat - noise_flat, dim=2)
        min_indices = distances.argmin(dim=1)

    # Pick, for each batch row, the candidate at its winning index.
    batch_index = torch.arange(batch_size, device=x.device)
    return noise_candidates[batch_index, min_indices]
|
| 53 |
+
|
| 54 |
def _get_reduction_dims(self, x):
|
| 55 |
"""Get appropriate dimensions for loss reduction based on tensor shape"""
|
| 56 |
if x.dim() == 4:
|
|
|
|
| 70 |
return torch.zeros(n_timesteps) # Not VP and not supported
|
| 71 |
|
| 72 |
def add_noise(self, x, t, noise=None):
    """Blend ``x`` with noise at time ``t``.

    Computes ``alpha(t) * x + sigma(t) * noise`` with the per-sample
    coefficients broadcast over every non-batch dimension of ``x``.

    Args:
        x: Tensor whose first dimension is the batch.
        t: Tensor that ``alpha``/``sigma`` map to one value per sample
            (the results are reshaped to ``(batch, 1, ..., 1)``).
        noise: Optional noise tensor to use. When omitted, noise comes from
            ``get_immiscible_noise(x, self.k_candidates)`` if
            ``self.use_immiscible`` is set, else from ``torch.randn_like(x)``.

    Returns:
        Tuple ``(x_t, noise)``: the noised tensor and the noise that was used.
    """
    if noise is None:
        noise = (
            self.get_immiscible_noise(x, self.k_candidates)
            if self.use_immiscible
            else torch.randn_like(x)
        )

    # One coefficient per sample, broadcast across the remaining dimensions.
    bshape = (x.shape[0],) + (1,) * (x.dim() - 1)
    signal_scale = self.alpha(t).view(bshape)
    noise_scale = self.sigma(t).view(bshape)
    return signal_scale * x + noise_scale * noise, noise
|