primepake commited on
Commit
62d19d0
·
1 Parent(s): 0a7b8fc

apply immiscible random noise

Browse files
Files changed (1) hide show
  1. flowae/models/diffusion/fm.py +35 -3
flowae/models/diffusion/fm.py CHANGED
@@ -6,11 +6,13 @@ from models import register
6
  @register('fm')
7
  class FM:
8
 
9
- def __init__(self, sigma_min=1e-5, timescale=1.0):
10
  self.sigma_min = sigma_min
11
  self.prediction_type = None
12
  self.timescale = timescale
13
-
 
 
14
  def alpha(self, t):
15
  return 1.0 - t
16
 
@@ -23,6 +25,32 @@ class FM:
23
  def B(self, t):
24
  return -(1.0 - self.sigma_min)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def _get_reduction_dims(self, x):
27
  """Get appropriate dimensions for loss reduction based on tensor shape"""
28
  if x.dim() == 4:
@@ -42,7 +70,11 @@ class FM:
42
  return torch.zeros(n_timesteps) # Not VP and not supported
43
 
44
  def add_noise(self, x, t, noise=None):
45
- noise = torch.randn_like(x) if noise is None else noise
 
 
 
 
46
  s = [x.shape[0]] + [1] * (x.dim() - 1)
47
  x_t = self.alpha(t).view(*s) * x + self.sigma(t).view(*s) * noise
48
  return x_t, noise
 
6
  @register('fm')
7
  class FM:
8
 
9
+ def __init__(self, sigma_min=1e-5, timescale=1.0, use_immiscible=True, k_candidates=4):
10
  self.sigma_min = sigma_min
11
  self.prediction_type = None
12
  self.timescale = timescale
13
+ self.use_immiscible = use_immiscible
14
+ self.k_candidates = k_candidates
15
+
16
  def alpha(self, t):
17
  return 1.0 - t
18
 
 
25
  def B(self, t):
26
  return -(1.0 - self.sigma_min)
27
 
28
+ def get_immiscible_noise(self, x, k=4):
29
+ """Generate noise using k-NN immiscible assignment"""
30
+ batch_size = x.shape[0]
31
+
32
+ # Generate k noise candidates
33
+ noise_candidates = torch.randn(batch_size, k, *x.shape[1:], device=x.device)
34
+
35
+ # Flatten for distance computation (use fp16 for efficiency)
36
+ x_flat = x.reshape(batch_size, -1).half()
37
+ noise_flat = noise_candidates.reshape(batch_size, k, -1).half()
38
+
39
+ # Compute distances
40
+ distances = torch.norm(x_flat.unsqueeze(1) - noise_flat, dim=2)
41
+
42
+ # Select closest noise
43
+ min_indices = distances.argmin(dim=1)
44
+
45
+ # Gather selected noise
46
+ noise = torch.gather(
47
+ noise_candidates,
48
+ 1,
49
+ min_indices.view(batch_size, 1, *([1] * (x.dim() - 1))).expand(-1, 1, *x.shape[1:])
50
+ ).squeeze(1)
51
+
52
+ return noise
53
+
54
  def _get_reduction_dims(self, x):
55
  """Get appropriate dimensions for loss reduction based on tensor shape"""
56
  if x.dim() == 4:
 
70
  return torch.zeros(n_timesteps) # Not VP and not supported
71
 
72
  def add_noise(self, x, t, noise=None):
73
+ if noise is None:
74
+ if self.use_immiscible:
75
+ noise = self.get_immiscible_noise(x, self.k_candidates)
76
+ else:
77
+ noise = torch.randn_like(x)
78
  s = [x.shape[0]] + [1] * (x.dim() - 1)
79
  x_t = self.alpha(t).view(*s) * x + self.sigma(t).view(*s) * noise
80
  return x_t, noise