File size: 19,749 Bytes
4f46baa a733be1 4f46baa 193fbf7 4f46baa 193fbf7 4f46baa 193fbf7 4f46baa 193fbf7 4f46baa a733be1 4f46baa 193fbf7 4f46baa a733be1 4f46baa 193fbf7 4f46baa a733be1 4f46baa a733be1 4f46baa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 
481 482 483 484 485 486 487 488 489 490 491 | """
LiquidGen: A Novel Liquid Neural Network Image Generation Model
Architecture Overview:
- Frozen VAE encoder/decoder (SDXL VAE, 4ch latent, 8x compression, no login needed)
- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)
- Flow matching training objective (velocity prediction)
Key Innovation: Replaces attention with Liquid Neural Network dynamics:
- CfC-inspired closed-form update: x_new = α·x + (1-α)·h(x)
- Per-channel learnable decay rates (liquid time constants)
- Depthwise + pointwise convolutions for spatial context (no attention needed)
- Zigzag spatial scanning for global receptive field
- Gated stimulus with biologically-inspired sign constraints
- U-Net style long skip connections from shallow to deep blocks
Math Foundation (from Hasani et al., CfC paper):
x_{t+1} = exp(-Δt/τ_t) · x_t + (1 - exp(-Δt/τ_t)) · h(x_t, u_t)
Our parallelizable adaptation (inspired by LiquidTAD):
α = exp(-softplus(ρ)) [per-channel learnable decay]
h = gate · stimulus [gated depthwise conv output]
out = α · x + (1 - α) · h [liquid relaxation blend]
This removes the input-dependent τ (which requires sequential computation)
and replaces it with a per-channel learned decay — making it fully parallel
while preserving the liquid dynamics' ability to blend old state with new input.
Design for 16GB VRAM (Colab free tier):
- VAE frozen: ~1GB
- Backbone: ~55-280M params (~100-550MB in fp16)
- Training overhead (grads + optimizer): ~3-8GB
- Batch of latents: ~1-2GB
- Total: fits comfortably in 16GB
References:
- Hasani et al., "Liquid Time-constant Networks" (NeurIPS 2020)
- Hasani et al., "Closed-form Continuous-depth Models" (Nature Machine Intelligence 2022)
- Lechner et al., "Neural Circuit Policies" (Nature Machine Intelligence 2020)
- LiquidTAD (2025) - Parallelized liquid dynamics
- ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion
- DiMSUM (NeurIPS 2024) - Attention-free diffusion
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, Tuple
# =============================================================================
# Building Blocks
# =============================================================================
class LiquidTimeConstant(nn.Module):
    """
    Per-channel exponential relaxation between a state and a stimulus.

    The module computes

        out = α · x + (1 - α) · stimulus,    α = exp(-(softplus(ρ) + 1e-5))

    where ρ is a single learnable scalar per channel. Because softplus is
    positive, α always lies strictly between 0 and 1, so the output is a
    convex combination of the two inputs. No recurrence or ODE solver is
    involved; the blend is a pure elementwise operation.
    """

    def __init__(self, channels: int):
        super().__init__()
        # ρ = 0 → softplus(ρ) = ln 2 ≈ 0.693 → α ≈ 0.5, i.e. the block
        # starts out weighting state and stimulus about equally.
        self.rho = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:
        """
        Blend ``x`` with ``stimulus`` using the learned per-channel decay.

        Args:
            x: current state; the decay is broadcast over dim 1, so the
                channel axis is expected there (e.g. [B, C, H, W]).
            stimulus: target tensor, broadcast-compatible with ``x``.

        Returns:
            Tensor with the broadcast shape of the inputs.
        """
        # The small epsilon keeps the rate strictly positive even when
        # softplus underflows for very negative ρ.
        rate = F.softplus(self.rho) + 1e-5
        keep = torch.exp(-rate)[None, :, None, None]
        return keep * x + (1.0 - keep) * stimulus
class GatedDepthwiseStimulusConv(nn.Module):
    """
    Gated depthwise-separable convolution producing the spatial stimulus.

    Two branches read the same input:

    * stimulus branch: depthwise conv → 1x1 conv (expand) → GELU → 1x1 conv
      (project back to ``channels``)
    * gate branch: depthwise conv → 1x1 conv → sigmoid

    The result is the elementwise product ``stimulus * gate``.
    """

    def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):
        super().__init__()
        pad = kernel_size // 2
        expanded = int(channels * expand_ratio)

        # Stimulus branch (bias-free throughout).
        self.stim_dw = nn.Conv2d(
            channels, channels, kernel_size,
            padding=pad, groups=channels, bias=False,
        )
        self.stim_pw = nn.Conv2d(channels, expanded, 1, bias=False)
        self.stim_act = nn.GELU()
        self.stim_proj = nn.Conv2d(expanded, channels, 1, bias=False)

        # Gate branch; only the pointwise layer carries a bias.
        self.gate_dw = nn.Conv2d(
            channels, channels, kernel_size,
            padding=pad, groups=channels, bias=False,
        )
        self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated stimulus for input ``x`` ([B, C, H, W])."""
        stimulus = self.stim_dw(x)
        stimulus = self.stim_pw(stimulus)
        stimulus = self.stim_act(stimulus)
        stimulus = self.stim_proj(stimulus)

        gate_logits = self.gate_pw(self.gate_dw(x))
        return stimulus * torch.sigmoid(gate_logits)
class ChannelMixMLP(nn.Module):
    """
    Pointwise (1x1 conv) two-layer MLP that mixes channels at each pixel.

    Layout: 1x1 conv (expand by ``expand_ratio``) → GELU → 1x1 conv (back to
    ``channels``). Spatial positions are processed independently.
    """

    def __init__(self, channels: int, expand_ratio: float = 4.0):
        super().__init__()
        width = int(channels * expand_ratio)
        self.fc1 = nn.Conv2d(channels, width, 1, bias=True)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(width, channels, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand → GELU → project to ``x`` ([B, C, H, W])."""
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)
class AdaptiveGroupNorm(nn.Module):
    """
    Group normalization whose scale and shift come from a conditioning vector.

    Computes ``(1 + scale) * GroupNorm(x) + shift`` where ``scale`` and
    ``shift`` are the two halves of a linear projection of ``cond``. The
    GroupNorm itself has no affine parameters. The projection is zeroed in
    this constructor, so a freshly built module is a plain GroupNorm.
    """

    def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        # Zero init → scale = shift = 0 → identity modulation at start.
        for tensor in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and modulate it with parameters derived from ``cond``.

        Args:
            x: feature map, [B, C, H, W].
            cond: conditioning vector; its last dim must equal ``cond_dim``.

        Returns:
            Modulated tensor with the same shape as ``x``.
        """
        normed = self.norm(x)
        scale, shift = self.proj(cond).chunk(2, dim=-1)
        # Append two singleton axes so [B, C] broadcasts over H and W.
        scale = scale[..., None, None]
        shift = shift[..., None, None]
        return normed * (1.0 + scale) + shift
class ZigzagScan1D(nn.Module):
    """
    Long-range mixing via a 1D depthwise conv over a zigzag pixel ordering.

    The [H, W] grid is flattened row by row, with every odd row reversed, so
    consecutive sequence positions are always spatially adjacent. A depthwise
    Conv1d followed by GELU and a pointwise Conv1d is applied along that
    sequence, and the result is scattered back to the original pixel layout.
    """

    def __init__(self, channels: int, kernel_size: int = 31):
        super().__init__()
        self.conv1d = nn.Conv1d(channels, channels, kernel_size,
                                padding=kernel_size // 2, groups=channels, bias=False)
        self.pw = nn.Conv1d(channels, channels, 1, bias=True)
        self.act = nn.GELU()
        # Plain dict (not a buffer): entries are keyed by device, so nothing
        # here needs to follow the module through .to().
        self._idx_cache = {}

    def _get_indices(self, H: int, W: int, device: torch.device):
        """
        Return ``(forward, inverse)`` index tensors for an H x W zigzag scan.

        ``forward[k]`` is the row-major pixel index visited at scan step
        ``k``; ``inverse`` undoes that permutation. Results are memoized per
        ``(H, W, device)``.
        """
        key = (H, W, device)
        cached = self._idx_cache.get(key)
        if cached is None:
            # Build the permutation with tensor ops instead of a Python loop
            # over H*W integers: start row-major, then mirror the odd rows.
            grid = torch.arange(H * W, device=device, dtype=torch.long).view(H, W)
            grid[1::2] = grid[1::2].flip(-1)  # flip() copies, so no aliasing
            fwd = grid.reshape(-1)
            inv = torch.empty_like(fwd)
            inv[fwd] = torch.arange(H * W, device=device)
            cached = (fwd, inv)
            self._idx_cache[key] = cached
        return cached

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Mix ``x`` ([B, C, H, W]) along the zigzag sequence.

        Returns:
            Tensor of the same shape as ``x``.
        """
        B, C, H, W = x.shape
        zz_idx, inv_idx = self._get_indices(H, W, x.device)
        x_zz = x.reshape(B, C, H * W)[:, :, zz_idx]
        x_mixed = self.pw(self.act(self.conv1d(x_zz)))
        # Undo the zigzag permutation before restoring the 2D layout.
        return x_mixed[:, :, inv_idx].reshape(B, C, H, W)
# =============================================================================
# Liquid Block: The core building block
# =============================================================================
class LiquidBlock(nn.Module):
    """
    One conditioned liquid block: a spatial step followed by a channel step.

    Spatial step:
        1. AdaGN-normalize the input using ``cond``.
        2. Compute a gated depthwise-conv stimulus from the normalized input.
        3. If ``use_zigzag``, add the zigzag-scan output scaled by
           ``sigmoid(zigzag_gate)`` (a learnable scalar, 0.5 at init).
        4. Apply dropout and blend the stimulus into the un-normalized input
           with a ``LiquidTimeConstant``.

    Channel step:
        5. AdaGN-normalize again, run the channel-mixing MLP, apply dropout,
           and blend with a second ``LiquidTimeConstant``.

    There is no plain additive residual; both updates go through the liquid
    blend, which keeps a learned fraction of the incoming state.
    """

    def __init__(
        self, channels: int, cond_dim: int, spatial_kernel: int = 7,
        scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,
        drop_rate: float = 0.0, use_zigzag: bool = True,
    ):
        super().__init__()
        self.norm1 = AdaptiveGroupNorm(channels, cond_dim)
        self.norm2 = AdaptiveGroupNorm(channels, cond_dim)

        self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)

        self.use_zigzag = use_zigzag
        if use_zigzag:
            self.zigzag = ZigzagScan1D(channels, scan_kernel)
            self.zigzag_gate = nn.Parameter(torch.zeros(1))

        self.liquid = LiquidTimeConstant(channels)
        self.channel_mix = ChannelMixMLP(channels, mlp_ratio)
        self.liquid2 = LiquidTimeConstant(channels)

        # Channel-wise dropout only when requested; otherwise a no-op.
        if drop_rate > 0:
            self.drop = nn.Dropout2d(drop_rate)
        else:
            self.drop = nn.Identity()

    def _spatial_step(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Spatial stimulus (local conv plus optional zigzag) blended into ``x``."""
        normed = self.norm1(x, cond)
        stimulus = self.spatial_stim(normed)
        if self.use_zigzag:
            scanned = self.zigzag(normed)
            stimulus = stimulus + torch.sigmoid(self.zigzag_gate) * scanned
        return self.liquid(x, self.drop(stimulus))

    def _channel_step(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Channel-mixing MLP output blended into ``x``."""
        normed = self.norm2(x, cond)
        mixed = self.channel_mix(normed)
        return self.liquid2(x, self.drop(mixed))

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """
        Run the block.

        Args:
            x: feature map, [B, C, H, W].
            cond: conditioning vector consumed by both AdaGN layers.

        Returns:
            Updated feature map with the same shape as ``x``.
        """
        x = self._spatial_step(x, cond)
        return self._channel_step(x, cond)
# =============================================================================
# Timestep and Class Embeddings
# =============================================================================
class TimestepEmbedding(nn.Module):
    """
    Sinusoidal timestep features followed by a two-layer MLP.

    ``freq_dim // 2`` frequencies are spaced geometrically from 1 down to
    (just above) 1/10000; the features are ``[cos(t·f), sin(t·f)]``. If
    ``freq_dim`` is odd, one zero column is appended so the feature width
    always matches the first linear layer.
    """

    def __init__(self, dim: int, freq_dim: int = 256):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Embed timesteps.

        Args:
            t: timesteps of shape [B]; any real dtype is accepted.

        Returns:
            Tensor of shape [B, dim] in the MLP's parameter dtype (outside of
            autocast).
        """
        half = self.freq_dim // 2
        # Always build the sinusoid in float32: using t.dtype made float64 or
        # float16 timesteps hit a dtype mismatch in the MLP (and float16 would
        # also lose precision in the frequency table).
        t = t.to(torch.float32)
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=t.device, dtype=torch.float32)
            / half
        )
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.freq_dim % 2:
            # Odd freq_dim: pad one zero feature so width == freq_dim.
            emb = torch.cat([emb, torch.zeros_like(emb[..., :1])], dim=-1)
        # Match the MLP's parameter dtype (e.g. after model.half()).
        return self.mlp(emb.to(self.mlp[0].weight.dtype))
class ClassEmbedding(nn.Module):
    """
    Class-label embedding with a learned "null" vector for label dropout.

    In training mode, each sample's embedding is independently replaced by
    ``null_embed`` with probability ``drop_prob`` (the usual setup for
    classifier-free guidance). In eval mode, or with ``drop_prob <= 0``, the
    lookup is returned unchanged.
    """

    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_classes, dim)
        self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)

    def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:
        """
        Look up embeddings for ``labels`` ([B]) with optional label dropout.

        Returns:
            Tensor of shape [B, dim].
        """
        looked_up = self.embed(labels)
        if not (self.training and drop_prob > 0):
            return looked_up

        # One Bernoulli draw per sample; the [B, 1] mask broadcasts over dim.
        batch = labels.shape[0]
        dropped = torch.rand(batch, 1, device=labels.device) < drop_prob
        null_rows = self.null_embed.unsqueeze(0).expand_as(looked_up)
        return torch.where(dropped, null_rows, looked_up)
# =============================================================================
# LiquidGen: Full Model
# =============================================================================
class LiquidGen(nn.Module):
    """
    LiquidGen: attention-free velocity-prediction backbone for latent images.

    Pipeline:
        1. Timestep (and optional class) embedding → conditioning vector.
        2. Strided-conv patch embedding + learned positional map
           (bilinearly resized when the patch grid is not 32x32).
        3. Depthwise-separable input projection.
        4. ``depth`` LiquidBlocks with U-Net style long skips: the inputs of
           the first ``depth // 2`` blocks are stacked, and later blocks pop
           them (last-in, first-out) and add them to their input.
        5. GroupNorm → 3x3 conv + GELU → transposed-conv un-patchify.

    The un-patchify layer and every AdaGN modulation projection are zero at
    initialization, so a fresh model outputs exactly zero velocity and its
    blocks start with un-modulated normalization.

    Preset factories (``liquidgen_small`` / ``_base`` / ``_large``) provide
    the intended configurations.
    """

    def __init__(
        self,
        in_channels: int = 4,  # 4 for SDXL VAE
        patch_size: int = 2,
        embed_dim: int = 512,
        depth: int = 16,
        spatial_kernel: int = 7,
        scan_kernel: int = 31,
        expand_ratio: float = 2.0,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        num_classes: int = 0,
        class_drop_prob: float = 0.1,
        use_zigzag: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_classes = num_classes
        self.class_drop_prob = class_drop_prob

        cond_dim = embed_dim
        self.time_embed = TimestepEmbedding(cond_dim)
        # Class conditioning is only built when num_classes > 0.
        self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None

        self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)

        # Learned positional map at a 32x32 reference grid; other grid sizes
        # are served by interpolation in _interpolate_pos_embed.
        self.pos_embed_size = 32
        self.pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02
        )

        self.input_proj = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),
            nn.Conv2d(embed_dim, embed_dim, 1, bias=True),
            nn.GELU(),
        )

        self.blocks = nn.ModuleList([
            LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,
                        expand_ratio, mlp_ratio, drop_rate, use_zigzag)
            for _ in range(depth)
        ])

        self.final_norm = nn.GroupNorm(32, embed_dim)
        self.final_proj = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),
            nn.GELU(),
        )
        self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)

        self.apply(self._init_weights)
        self._zero_init_special_layers()

        self._gradient_checkpointing = False

    def _zero_init_special_layers(self):
        """Zero the layers that must start as identity / zero output.

        Must run AFTER ``self.apply(self._init_weights)``: that pass
        re-initializes every ``nn.Linear`` with Xavier weights, which
        otherwise silently overwrites the zero init that AdaptiveGroupNorm
        applies to its modulation projection in its own constructor.
        """
        for module in self.modules():
            if isinstance(module, AdaptiveGroupNorm):
                nn.init.zeros_(module.proj.weight)
                nn.init.zeros_(module.proj.bias)
        # Zero output layer → the untrained model predicts zero velocity.
        nn.init.zeros_(self.unpatch.weight)
        nn.init.zeros_(self.unpatch.bias)

    def enable_gradient_checkpointing(self):
        """Recompute block activations in backward instead of storing them.

        Trades extra compute for lower activation memory; only takes effect
        while the model is in training mode.
        """
        self._gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        """Turn gradient checkpointing back off."""
        self._gradient_checkpointing = False

    def _init_weights(self, m):
        """Per-module initializer used with ``nn.Module.apply``.

        Conv2d: Kaiming-normal (fan_out) weights, zero bias.
        Linear: Xavier-uniform weights, zero bias.
        Embedding: normal(std=0.02).
        Other module types (e.g. Conv1d, ConvTranspose2d) are left untouched.
        """
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:
        """Return the positional map resized (bilinear) to an H x W grid."""
        if H == self.pos_embed_size and W == self.pos_embed_size:
            return self.pos_embed
        return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Predict the velocity field for flow matching.

        Args:
            x: [B, C, H, W] noisy latent (C = ``in_channels``).
            t: [B] timestep in [0, 1].
            class_labels: [B] optional class labels. On a class-conditional
                model, ``None`` selects the learned null embedding, i.e. the
                same unconditional signal the model sees under label dropout
                during training.

        Returns:
            v: [B, C, H, W] predicted velocity.
        """
        cond = self.time_embed(t)
        if self.class_embed is not None:
            if class_labels is not None:
                drop_p = self.class_drop_prob if self.training else 0.0
                cond = cond + self.class_embed(class_labels, drop_prob=drop_p)
            else:
                # Unconditional pass: use the null embedding so inference
                # matches what training's label dropout produced, instead of
                # omitting the class term (a conditioning the model never saw).
                cond = cond + self.class_embed.null_embed.unsqueeze(0)

        h = self.patch_embed(x)
        _, _, H_p, W_p = h.shape
        h = h + self._interpolate_pos_embed(H_p, W_p)
        h = self.input_proj(h)

        # U-Net style long skip connections: shallow blocks push their input,
        # deep blocks pop the most recent one and add it before running.
        skip_connections = []
        mid = self.depth // 2
        use_checkpoint = self._gradient_checkpointing and self.training
        for i, block in enumerate(self.blocks):
            if i < mid:
                skip_connections.append(h)
            elif skip_connections:
                h = h + skip_connections.pop()

            if use_checkpoint:
                h = checkpoint(block, h, cond, use_reentrant=False)
            else:
                h = block(h, cond)

        h = self.final_norm(h)
        h = self.final_proj(h)
        return self.unpatch(h)

    def count_params(self) -> int:
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# =============================================================================
# Model Presets
# =============================================================================
def liquidgen_small(**kwargs) -> LiquidGen:
    """Build the small preset (embed_dim=512, depth=12, mlp_ratio=3.0).

    Nominally ~55M parameters per the original author's note; intended for
    256px, fast training/testing. Any keyword argument overrides the preset
    value or is forwarded to ``LiquidGen`` as-is.
    """
    preset = {
        "embed_dim": 512,
        "depth": 12,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": 2.0,
        "mlp_ratio": 3.0,
        "use_zigzag": True,
    }
    # Caller-supplied values win over preset values.
    return LiquidGen(**{**preset, **kwargs})
def liquidgen_base(**kwargs) -> LiquidGen:
    """Build the base preset (embed_dim=640, depth=18, mlp_ratio=4.0).

    Nominally ~140M parameters per the original author's note; intended for
    256/512px as the balanced option. Any keyword argument overrides the
    preset value or is forwarded to ``LiquidGen`` as-is.
    """
    preset = {
        "embed_dim": 640,
        "depth": 18,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": 2.0,
        "mlp_ratio": 4.0,
        "use_zigzag": True,
    }
    # Caller-supplied values win over preset values.
    return LiquidGen(**{**preset, **kwargs})
def liquidgen_large(**kwargs) -> LiquidGen:
    """Build the large preset (embed_dim=768, depth=24, expand_ratio=2.5).

    Nominally ~280M parameters per the original author's note; intended for
    512px, high quality. Any keyword argument overrides the preset value or
    is forwarded to ``LiquidGen`` as-is.
    """
    preset = {
        "embed_dim": 768,
        "depth": 24,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": 2.5,
        "mlp_ratio": 4.0,
        "use_zigzag": True,
    }
    # Caller-supplied values win over preset values.
    return LiquidGen(**{**preset, **kwargs})
if __name__ == "__main__":
    # Smoke test: build each preset on CPU, report its size, and check that
    # the output shape matches the input at two latent resolutions.
    device = "cpu"
    presets = (
        ("Small", liquidgen_small),
        ("Base", liquidgen_base),
        ("Large", liquidgen_large),
    )
    for name, factory in presets:
        model = factory(num_classes=27).to(device)
        n_params_m = model.count_params() / 1e6
        print(f"LiquidGen-{name}: {n_params_m:.1f}M params")

        # 256px image → 32x32 latent with 4 channels (8x VAE compression).
        latents_256 = torch.randn(2, 4, 32, 32, device=device)
        t = torch.rand(2, device=device)
        labels = torch.randint(0, 27, (2,), device=device)
        out_256 = model(latents_256, t, labels)
        assert out_256.shape == latents_256.shape

        # 512px image → 64x64 latent; exercises positional-map interpolation.
        latents_512 = torch.randn(1, 4, 64, 64, device=device)
        out_512 = model(latents_512, t[:1], labels[:1])
        assert out_512.shape == latents_512.shape

        print(f" 256px ✅ 512px ✅")
        del model  # free the model before building the next, larger preset

    print("\n✅ All tests passed!")
|