krystv committed on
Commit
95b50da
·
verified ·
1 Parent(s): f143208

Upload liquid_flow/generator.py

Browse files
Files changed (1) hide show
  1. liquid_flow/generator.py +53 -198
liquid_flow/generator.py CHANGED
@@ -1,40 +1,18 @@
1
  """
2
  LiquidFlow Generator — Main diffusion model.
3
-
4
- Combines:
5
- - LiquidFlowBackbone (CfC + Mamba-2 SSD) as the noise predictor
6
- - DDPM/DDIM diffusion process
7
- - Physics-informed regularization
8
-
9
- Supports:
10
- - Training on 128×128 and 512×512 images
11
- - TAESD VAE (lightweight, Colab/Kaggle compatible)
12
- - SD VAE (higher quality)
13
- - Both DDPM and DDIM sampling
14
-
15
- The model is designed to be:
16
- - Trainable on Google Colab free tier / Kaggle (T4 GPU, 15GB)
17
- - Exportable to ONNX/CoreML for mobile deployment
18
- - Pure PyTorch — no CUDA kernels needed (Mamba-2 SSD runs on CPU too)
19
  """
20
 
21
  import torch
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  import math
25
- import numpy as np
26
  from tqdm import tqdm
27
- from typing import Optional, Dict, Tuple
28
 
29
  from .liquid_flow_block import LiquidFlowBackbone
30
  from .physics_loss import PhysicsRegularizer, DDIMEstimator
31
 
32
 
33
- def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
34
- """Linear noise schedule (DDPM)."""
35
- return torch.linspace(beta_start, beta_end, timesteps)
36
-
37
-
38
  def cosine_beta_schedule(timesteps, s=0.008):
39
  """Cosine noise schedule (Improved DDPM)."""
40
  steps = timesteps + 1
@@ -45,26 +23,14 @@ def cosine_beta_schedule(timesteps, s=0.008):
45
  return torch.clip(betas, 0.0001, 0.9999)
46
 
47
 
 
 
 
 
 
48
  class LiquidFlowGenerator(nn.Module):
49
  """
50
  LiquidFlow Generator: Liquid Neural Network + Mamba-2 SSD Diffusion Model.
51
-
52
- Uses LiquidFlowBackbone as noise predictor in a DDPM/DDIM framework.
53
-
54
- Architecture:
55
- Noise Predictor = LiquidFlowBackbone (CfC + Mamba-2 SSD)
56
- Diffusion = DDPM (forward) + DDIM (sampling)
57
- Regularizer = Physics-Informed Losses (TV, spectral, conservation)
58
-
59
- Args:
60
- in_channels: Latent channels from VAE (default 4)
61
- hidden_dim: Hidden dimension in backbone
62
- num_stages: Number of LiquidFlow stages
63
- blocks_per_stage: Blocks per stage
64
- image_size: Target image size (for latent computation)
65
- beta_schedule: 'linear' or 'cosine'
66
- timesteps: Number of diffusion timesteps
67
- physics_weights: Weights for physics regularizers
68
  """
69
 
70
  def __init__(
@@ -81,10 +47,10 @@ class LiquidFlowGenerator(nn.Module):
81
  super().__init__()
82
  self.in_channels = in_channels
83
  self.hidden_dim = hidden_dim
84
- self.image_size = image_size # Latent space size = image_size / 8
85
  self.timesteps = timesteps
86
 
87
- # Noise predictor (backbone)
88
  self.backbone = LiquidFlowBackbone(
89
  in_channels=in_channels,
90
  hidden_dim=hidden_dim,
@@ -105,85 +71,58 @@ class LiquidFlowGenerator(nn.Module):
105
  self.register_buffer('alphas', 1.0 - betas)
106
  self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))
107
  self.register_buffer('alphas_cumprod_prev', F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0))
108
-
109
- # For DDIM sampling
110
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
111
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - self.alphas_cumprod))
112
 
113
- # Physics regularizer
114
  if physics_weights is None:
115
- physics_weights = {'tv': 0.01, 'cons': 0.001, 'spec': 0.01, 'grad': 0.001}
116
- self.physics = PhysicsRegularizer(**physics_weights)
 
 
 
 
 
 
117
  self.ddim_estimator = DDIMEstimator()
118
 
119
  def q_sample(self, x0, t, noise=None):
120
- """
121
- Forward diffusion: q(x_t | x_0).
122
-
123
- x_t = √(ᾱ_t) * x_0 + √(1 - ᾱ_t) * ε
124
- """
125
  if noise is None:
126
  noise = torch.randn_like(x0)
127
-
128
- sqrt_alpha_bar = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
129
- sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
130
-
131
- return sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise, noise
132
 
133
  def forward(self, x, t):
134
  """Predict noise from noisy input."""
135
  return self.backbone(x, t)
136
 
137
  def training_step(self, x0, optimizer, scaler=None, use_amp=False):
138
- """
139
- Single training step with physics regularization.
140
-
141
- Args:
142
- x0: Clean latents [B, C, H, W]
143
- optimizer: Optimizer
144
- scaler: Optional GradScaler for AMP
145
- use_amp: Whether to use automatic mixed precision
146
-
147
- Returns:
148
- loss_dict: Dictionary of losses
149
- """
150
  B = x0.shape[0]
151
  device = x0.device
152
 
153
- # Sample timesteps
154
  t = torch.randint(0, self.timesteps, (B,), device=device)
155
-
156
- # Forward diffusion
157
  noise = torch.randn_like(x0)
158
  xt, noise = self.q_sample(x0, t, noise)
159
 
 
160
  if use_amp and scaler is not None:
161
  with torch.cuda.amp.autocast():
162
- # Predict noise
163
  noise_pred = self.forward(xt, t)
164
-
165
- # Base diffusion loss (L2 or L1)
166
  diffusion_loss = F.mse_loss(noise_pred, noise)
167
-
168
- # Physics regularization on estimated x0
169
- x0_hat = self.ddim_estimator.estimate_x0(
170
- xt, noise_pred, self.alphas_cumprod[t]
171
- )
172
- phys_loss, phys_dict = self.physics(x0_hat, x0)
173
-
174
  total_loss = diffusion_loss + phys_loss
175
  else:
176
  noise_pred = self.forward(xt, t)
177
  diffusion_loss = F.mse_loss(noise_pred, noise)
178
-
179
- x0_hat = self.ddim_estimator.estimate_x0(
180
- xt, noise_pred, self.alphas_cumprod[t]
181
- )
182
- phys_loss, phys_dict = self.physics(x0_hat, x0)
183
-
184
  total_loss = diffusion_loss + phys_loss
185
 
186
- # Backward
187
  optimizer.zero_grad()
188
  if scaler is not None:
189
  scaler.scale(total_loss).backward()
@@ -199,29 +138,15 @@ class LiquidFlowGenerator(nn.Module):
199
  return {
200
  'total': total_loss.item(),
201
  'diffusion': diffusion_loss.item(),
202
- 'physics': phys_loss.item(),
203
- **{f'phys_{k}': v.item() for k, v in phys_dict.items()},
204
  }
205
 
206
  @torch.no_grad()
207
  def sample(self, batch_size=4, steps=50, ddim=True, eta=0.0, progress=True):
208
- """
209
- Generate images using DDPM or DDIM sampling.
210
-
211
- Args:
212
- batch_size: Number of images
213
- steps: Sampling steps (for DDIM: can be << timesteps)
214
- ddim: Use DDIM sampling (faster)
215
- eta: DDIM stochasticity (0 = deterministic)
216
- progress: Show progress bar
217
-
218
- Returns:
219
- Generated latents [B, C, H, W]
220
- """
221
  device = next(self.parameters()).device
222
  latent_size = self.image_size // 8
223
-
224
- # Start from pure noise
225
  x = torch.randn(batch_size, self.in_channels, latent_size, latent_size, device=device)
226
 
227
  if ddim:
@@ -231,133 +156,63 @@ class LiquidFlowGenerator(nn.Module):
231
 
232
  @torch.no_grad()
233
  def _ddpm_sample(self, x, progress=True):
234
- """DDPM sampling (full 1000 steps)."""
235
  device = x.device
236
-
237
- iterator = tqdm(
238
- reversed(range(0, self.timesteps)),
239
- desc='DDPM Sampling',
240
- total=self.timesteps,
241
- disable=not progress,
242
- )
243
-
244
- for t_idx in iterator:
245
  t = torch.full((x.shape[0],), t_idx, device=device, dtype=torch.long)
246
-
247
  noise_pred = self.forward(x, t)
248
-
249
  alpha = self.alphas[t_idx]
250
  alpha_bar = self.alphas_cumprod[t_idx]
251
- alpha_bar_prev = self.alphas_cumprod_prev[t_idx]
252
  beta = self.betas[t_idx]
253
-
254
- if t_idx > 0:
255
- noise = torch.randn_like(x)
256
- else:
257
- noise = 0
258
-
259
- # DDPM posterior
260
- x = (1 / torch.sqrt(alpha)) * (
261
- x - (beta / torch.sqrt(1 - alpha_bar)) * noise_pred
262
- ) + torch.sqrt(beta) * noise
263
-
264
  return x
265
 
266
  @torch.no_grad()
267
  def _ddim_sample(self, x, steps=50, eta=0.0, progress=True):
268
- """
269
- DDIM sampling with fewer steps.
270
-
271
- DDIM can produce good samples in 20-50 steps
272
- instead of 1000 DDPM steps.
273
- """
274
  device = x.device
275
-
276
- # Timestep spacing
277
  skip = self.timesteps // steps
278
  seq = list(range(0, self.timesteps, skip))
279
  seq_next = [-1] + seq[:-1]
280
 
281
- iterator = tqdm(
282
- zip(reversed(seq), reversed(seq_next)),
283
- desc='DDIM Sampling',
284
- total=len(seq),
285
- disable=not progress,
286
- )
287
-
288
- for i, j in iterator:
289
  t = torch.full((x.shape[0],), i, device=device, dtype=torch.long)
290
-
291
  noise_pred = self.forward(x, t)
292
 
293
- alpha_bar_i = self.alphas_cumprod[i]
294
- alpha_bar_j = self.alphas_cumprod[j] if j >= 0 else torch.tensor(1.0, device=device)
295
 
296
- # Predicted x0
297
- x0_pred = (x - torch.sqrt(1 - alpha_bar_i) * noise_pred) / torch.sqrt(alpha_bar_i)
298
- x0_pred = torch.clamp(x0_pred, -1, 1) # Prevent outliers
299
 
300
- # Direction pointing to x_t
301
- dir_xt = torch.sqrt(1 - alpha_bar_j - eta * eta * (
302
- (1 - alpha_bar_j) / (1 - alpha_bar_i)
303
- )) * noise_pred
304
 
305
- # Random noise
306
  if eta > 0:
307
- noise = torch.randn_like(x)
308
- sigma = eta * torch.sqrt((1 - alpha_bar_j) / (1 - alpha_bar_i) * (1 - alpha_bar_i / alpha_bar_j))
309
- x = torch.sqrt(alpha_bar_j) * x0_pred + dir_xt + sigma * noise
310
- else:
311
- noise = 0
312
- x = torch.sqrt(alpha_bar_j) * x0_pred + dir_xt
313
 
314
  return x
315
 
316
  def count_parameters(self):
317
- """Count trainable parameters."""
318
  return sum(p.numel() for p in self.parameters() if p.requires_grad)
319
 
320
 
321
- def create_liquidflow(
322
- variant='small',
323
- image_size=128,
324
- **kwargs,
325
- ):
326
  """
327
- Create a LiquidFlow model with preset configurations.
328
 
329
  Variants:
330
- - 'tiny': ~2M params, 2 stages, 2 blocks each, hidden_dim=128
331
- - 'small': ~8M params, 4 stages, 4 blocks each, hidden_dim=256
332
- - 'base': ~30M params, 6 stages, 6 blocks each, hidden_dim=384
333
-
334
- All designed to run on T4 (15GB) with batch_size >= 16 at 128×128.
335
  """
336
  configs = {
337
- 'tiny': {
338
- 'hidden_dim': 128,
339
- 'num_stages': 2,
340
- 'blocks_per_stage': 2,
341
- },
342
- 'small': {
343
- 'hidden_dim': 256,
344
- 'num_stages': 4,
345
- 'blocks_per_stage': 4,
346
- },
347
- 'base': {
348
- 'hidden_dim': 384,
349
- 'num_stages': 6,
350
- 'blocks_per_stage': 6,
351
- },
352
  }
353
-
354
  config = configs.get(variant, configs['small'])
355
  config.update(kwargs)
356
 
357
- model = LiquidFlowGenerator(
358
- in_channels=4, # VAE latent channels
359
- image_size=image_size,
360
- **config,
361
- )
362
-
363
- return model
 
1
  """
2
  LiquidFlow Generator — Main diffusion model.
3
+ CORRECTED: physics_weights parameter naming, proper kwarg passing.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  """
5
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
  import math
 
10
  from tqdm import tqdm
 
11
 
12
  from .liquid_flow_block import LiquidFlowBackbone
13
  from .physics_loss import PhysicsRegularizer, DDIMEstimator
14
 
15
 
 
 
 
 
 
16
  def cosine_beta_schedule(timesteps, s=0.008):
17
  """Cosine noise schedule (Improved DDPM)."""
18
  steps = timesteps + 1
 
23
  return torch.clip(betas, 0.0001, 0.9999)
24
 
25
 
26
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule (DDPM).

    Args:
        timesteps: Number of diffusion steps; must be at least 1.
        beta_start: Beta value at the first step.
        beta_end: Beta value at the last step.

    Returns:
        1-D tensor of ``timesteps`` betas evenly spaced from ``beta_start``
        to ``beta_end`` (both endpoints included).

    Raises:
        ValueError: If ``timesteps`` is smaller than 1.
    """
    # An empty schedule would only fail later, far from the actual mistake.
    if timesteps < 1:
        raise ValueError(f"timesteps must be >= 1, got {timesteps}")
    return torch.linspace(beta_start, beta_end, timesteps)
29
+
30
+
31
  class LiquidFlowGenerator(nn.Module):
32
  """
33
  LiquidFlow Generator: Liquid Neural Network + Mamba-2 SSD Diffusion Model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  """
35
 
36
  def __init__(
 
47
  super().__init__()
48
  self.in_channels = in_channels
49
  self.hidden_dim = hidden_dim
50
+ self.image_size = image_size
51
  self.timesteps = timesteps
52
 
53
+ # Noise predictor
54
  self.backbone = LiquidFlowBackbone(
55
  in_channels=in_channels,
56
  hidden_dim=hidden_dim,
 
71
  self.register_buffer('alphas', 1.0 - betas)
72
  self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))
73
  self.register_buffer('alphas_cumprod_prev', F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0))
 
 
74
  self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
75
  self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - self.alphas_cumprod))
76
 
77
+ # Physics regularizer — short keys (tv, cons, spec, grad) are mapped to the tv_weight, cons_weight, spec_weight, grad_weight keyword arguments
78
  if physics_weights is None:
79
+ physics_weights = {}
80
+ pw = {
81
+ 'tv_weight': physics_weights.get('tv', 0.01),
82
+ 'cons_weight': physics_weights.get('cons', 0.001),
83
+ 'spec_weight': physics_weights.get('spec', 0.01),
84
+ 'grad_weight': physics_weights.get('grad', 0.001),
85
+ }
86
+ self.physics = PhysicsRegularizer(**pw)
87
  self.ddim_estimator = DDIMEstimator()
88
 
89
  def q_sample(self, x0, t, noise=None):
90
+ """Forward diffusion: q(x_t | x_0)."""
 
 
 
 
91
  if noise is None:
92
  noise = torch.randn_like(x0)
93
+ sqrt_ab = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
94
+ sqrt_1_ab = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
95
+ return sqrt_ab * x0 + sqrt_1_ab * noise, noise
 
 
96
 
97
  def forward(self, x, t):
98
  """Predict noise from noisy input."""
99
  return self.backbone(x, t)
100
 
101
  def training_step(self, x0, optimizer, scaler=None, use_amp=False):
102
+ """Single training step with physics regularization."""
 
 
 
 
 
 
 
 
 
 
 
103
  B = x0.shape[0]
104
  device = x0.device
105
 
 
106
  t = torch.randint(0, self.timesteps, (B,), device=device)
 
 
107
  noise = torch.randn_like(x0)
108
  xt, noise = self.q_sample(x0, t, noise)
109
 
110
+ # Forward
111
  if use_amp and scaler is not None:
112
  with torch.cuda.amp.autocast():
 
113
  noise_pred = self.forward(xt, t)
 
 
114
  diffusion_loss = F.mse_loss(noise_pred, noise)
115
+ x0_hat = self.ddim_estimator.estimate_x0(xt, noise_pred, self.alphas_cumprod[t])
116
+ phys_loss, phys_dict = self.physics(x0_hat)
 
 
 
 
 
117
  total_loss = diffusion_loss + phys_loss
118
  else:
119
  noise_pred = self.forward(xt, t)
120
  diffusion_loss = F.mse_loss(noise_pred, noise)
121
+ x0_hat = self.ddim_estimator.estimate_x0(xt, noise_pred, self.alphas_cumprod[t])
122
+ phys_loss, phys_dict = self.physics(x0_hat)
 
 
 
 
123
  total_loss = diffusion_loss + phys_loss
124
 
125
+ # Backward + step
126
  optimizer.zero_grad()
127
  if scaler is not None:
128
  scaler.scale(total_loss).backward()
 
138
  return {
139
  'total': total_loss.item(),
140
  'diffusion': diffusion_loss.item(),
141
+ 'physics': phys_loss.item() if isinstance(phys_loss, torch.Tensor) else phys_loss,
142
+ **{f'phys_{k}': v.item() if isinstance(v, torch.Tensor) else v for k, v in phys_dict.items()},
143
  }
144
 
145
  @torch.no_grad()
146
  def sample(self, batch_size=4, steps=50, ddim=True, eta=0.0, progress=True):
147
+ """Generate images via DDIM or DDPM sampling."""
 
 
 
 
 
 
 
 
 
 
 
 
148
  device = next(self.parameters()).device
149
  latent_size = self.image_size // 8
 
 
150
  x = torch.randn(batch_size, self.in_channels, latent_size, latent_size, device=device)
151
 
152
  if ddim:
 
156
 
157
  @torch.no_grad()
158
  def _ddpm_sample(self, x, progress=True):
 
159
  device = x.device
160
+ for t_idx in tqdm(reversed(range(self.timesteps)), total=self.timesteps, disable=not progress):
 
 
 
 
 
 
 
 
161
  t = torch.full((x.shape[0],), t_idx, device=device, dtype=torch.long)
 
162
  noise_pred = self.forward(x, t)
 
163
  alpha = self.alphas[t_idx]
164
  alpha_bar = self.alphas_cumprod[t_idx]
 
165
  beta = self.betas[t_idx]
166
+ noise = torch.randn_like(x) if t_idx > 0 else 0
167
+ x = (1 / torch.sqrt(alpha)) * (x - (beta / torch.sqrt(1 - alpha_bar)) * noise_pred) + torch.sqrt(beta) * noise
 
 
 
 
 
 
 
 
 
168
  return x
169
 
170
  @torch.no_grad()
171
  def _ddim_sample(self, x, steps=50, eta=0.0, progress=True):
 
 
 
 
 
 
172
  device = x.device
 
 
173
  skip = self.timesteps // steps
174
  seq = list(range(0, self.timesteps, skip))
175
  seq_next = [-1] + seq[:-1]
176
 
177
+ for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), total=len(seq), disable=not progress):
 
 
 
 
 
 
 
178
  t = torch.full((x.shape[0],), i, device=device, dtype=torch.long)
 
179
  noise_pred = self.forward(x, t)
180
 
181
+ ab_i = self.alphas_cumprod[i]
182
+ ab_j = self.alphas_cumprod[j] if j >= 0 else torch.tensor(1.0, device=device)
183
 
184
+ x0_pred = (x - torch.sqrt(1 - ab_i) * noise_pred) / (torch.sqrt(ab_i) + 1e-8)
185
+ x0_pred = x0_pred.clamp(-3, 3)
 
186
 
187
+ # DDIM update
188
+ dir_xt = torch.sqrt(1 - ab_j) * noise_pred
189
+ x = torch.sqrt(ab_j) * x0_pred + dir_xt
 
190
 
 
191
  if eta > 0:
192
+ sigma = eta * torch.sqrt((1 - ab_j) / (1 - ab_i + 1e-8) * (1 - ab_i / (ab_j + 1e-8)))
193
+ x = x + sigma * torch.randn_like(x)
 
 
 
 
194
 
195
  return x
196
 
197
  def count_parameters(self):
 
198
  return sum(p.numel() for p in self.parameters() if p.requires_grad)
199
 
200
 
201
def create_liquidflow(variant='small', image_size=128, **kwargs):
    """
    Create LiquidFlow model.

    Variants:
    - 'tiny': ~2M params, 2 stages × 2 blocks, hidden_dim=128
    - 'small': ~8M params, 4 stages × 4 blocks, hidden_dim=256
    - 'base': ~30M params, 6 stages × 6 blocks, hidden_dim=384

    Args:
        variant: Preset name; an unrecognised name falls back to 'small'.
        image_size: Passed through to ``LiquidFlowGenerator``.
        **kwargs: Extra ``LiquidFlowGenerator`` keyword arguments. They
            override the preset values and the default ``in_channels=4``.

    Returns:
        A ``LiquidFlowGenerator`` built from the merged configuration.
    """
    configs = {
        'tiny': {'hidden_dim': 128, 'num_stages': 2, 'blocks_per_stage': 2},
        'small': {'hidden_dim': 256, 'num_stages': 4, 'blocks_per_stage': 4},
        'base': {'hidden_dim': 384, 'num_stages': 6, 'blocks_per_stage': 6},
    }
    preset = configs.get(variant, configs['small'])
    # in_channels is merged as a default rather than passed as a separate
    # keyword, so a caller-supplied in_channels overrides it instead of
    # raising "got multiple values for keyword argument".
    config = {'in_channels': 4, **preset, **kwargs}

    return LiquidFlowGenerator(image_size=image_size, **config)