# LiquidGen / model.py
# Commit 193fbf7 (verified): Add gradient checkpointing + zigzag index caching
"""
LiquidGen: A Novel Liquid Neural Network Image Generation Model
Architecture Overview:
- Frozen VAE encoder/decoder (SDXL VAE, 4ch latent, 8x compression, no login needed)
- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)
- Flow matching training objective (velocity prediction)
Key Innovation: Replaces attention with Liquid Neural Network dynamics:
- CfC-inspired closed-form update: x_new = α·x + (1-α)·h(x)
- Per-channel learnable decay rates (liquid time constants)
- Depthwise + pointwise convolutions for spatial context (no attention needed)
- Zigzag spatial scanning for global receptive field
- Gated stimulus with biologically-inspired sign constraints
- U-Net style long skip connections from shallow to deep blocks
Math Foundation (from Hasani et al., CfC paper):
x_{t+1} = exp(-Δt/τ_t) · x_t + (1 - exp(-Δt/τ_t)) · h(x_t, u_t)
Our parallelizable adaptation (inspired by LiquidTAD):
α = exp(-softplus(ρ)) [per-channel learnable decay]
h = gate · stimulus [gated depthwise conv output]
out = α · x + (1 - α) · h [liquid relaxation blend]
This removes the input-dependent τ (which requires sequential computation)
and replaces it with a per-channel learned decay — making it fully parallel
while preserving the liquid dynamics' ability to blend old state with new input.
Design for 16GB VRAM (Colab free tier):
- VAE frozen: ~1GB
- Backbone: ~55-280M params (~100-550MB in fp16)
- Training overhead (grads + optimizer): ~3-8GB
- Batch of latents: ~1-2GB
- Total: fits comfortably in 16GB
References:
- Hasani et al., "Liquid Time-constant Networks" (NeurIPS 2020)
- Hasani et al., "Closed-form Continuous-depth Models" (Nature Machine Intelligence 2022)
- Lechner et al., "Neural Circuit Policies" (Nature Machine Intelligence 2020)
- LiquidTAD (2025) - Parallelized liquid dynamics
- ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion
- DiMSUM (NeurIPS 2024) - Attention-free diffusion
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, Tuple
# =============================================================================
# Building Blocks
# =============================================================================
class LiquidTimeConstant(nn.Module):
    """
    Parallelizable CfC-style liquid time-constant blend.

    Computes out = α·x + (1-α)·stimulus with α = exp(-softplus(ρ)), a
    learnable per-channel decay derived from the liquid time constant
    τ = 1/softplus(ρ). This keeps the key Liquid-NN property — exponential
    relaxation of the state toward a target at a learned rate — without
    any sequential ODE solving, so it is fully parallel over pixels.

    Stability (LTC Theorem 1): the effective time constants stay within
    [τ/(1+τW), τ] and never explode.
    """

    def __init__(self, channels: int):
        super().__init__()
        # ρ = 0 → λ = softplus(0) ≈ 0.693 → α = exp(-λ) ≈ 0.5,
        # i.e. an equal blend of old state and new stimulus at init.
        self.rho = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:
        """Blend state toward stimulus, both [B, C, H, W]; returns [B, C, H, W]."""
        decay = F.softplus(self.rho) + 1e-5  # strictly positive rate λ
        mix = torch.exp(-decay).view(1, -1, 1, 1)  # per-channel α ∈ (0, 1)
        return mix * x + (1.0 - mix) * stimulus
class GatedDepthwiseStimulusConv(nn.Module):
    """
    Gated depthwise-separable stimulus computation (attention replacement).

    Two parallel depthwise→pointwise paths read the same input:
      1. Stimulus: DW-conv → PW expand → GELU → PW project back (all bias-free)
      2. Gate:     DW-conv → PW-conv → sigmoid
    The output is their elementwise product, mimicking synaptic gating in
    NCPs / GLU-style gating in SSMs while capturing local spatial context.
    """

    def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):
        super().__init__()
        pad = kernel_size // 2
        hidden = int(channels * expand_ratio)
        self.stim_dw = nn.Conv2d(channels, channels, kernel_size,
                                 padding=pad, groups=channels, bias=False)
        self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)
        self.stim_act = nn.GELU()
        self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)
        self.gate_dw = nn.Conv2d(channels, channels, kernel_size,
                                 padding=pad, groups=channels, bias=False)
        self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.stim_pw(self.stim_dw(x))
        stimulus = self.stim_proj(self.stim_act(features))
        gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))
        return stimulus * gate
class ChannelMixMLP(nn.Module):
    """Pointwise (1x1 conv) MLP mixing channels: expand → GELU → project.

    Plays the role of command-neuron processing in the NCP analogy.
    """

    def __init__(self, channels: int, expand_ratio: float = 4.0):
        super().__init__()
        width = int(channels * expand_ratio)
        self.fc1 = nn.Conv2d(channels, width, 1, bias=True)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(width, channels, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.act(self.fc1(x))
        return self.fc2(expanded)
class AdaptiveGroupNorm(nn.Module):
    """
    Group normalization modulated by a conditioning vector (timestep/class).

    out = (1 + scale) * GroupNorm(x) + shift, with (scale, shift) predicted
    from `cond` by a zero-initialized linear layer — so the module behaves
    as a plain (non-affine) GroupNorm at initialization.
    """

    def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        # Zero init → scale = shift = 0 → identity modulation at start.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        scale, shift = self.proj(cond).chunk(2, dim=-1)
        scale = scale[..., None, None]  # [B, C] → [B, C, 1, 1]
        shift = shift[..., None, None]
        return normed * (1.0 + scale) + shift
class ZigzagScan1D(nn.Module):
    """
    Quasi-global mixing via a long depthwise 1D conv over a zigzag scan.

    Rows are traversed left-to-right and right-to-left alternately
    (boustrophedon order, as in ZigMa) so the flattened sequence remains
    spatially continuous; the long kernel then gives a large receptive
    field without attention's O(n²) cost. Forward/inverse permutation
    tensors are cached per (H, W, device).
    """

    def __init__(self, channels: int, kernel_size: int = 31):
        super().__init__()
        self.conv1d = nn.Conv1d(channels, channels, kernel_size,
                                padding=kernel_size // 2, groups=channels, bias=False)
        self.pw = nn.Conv1d(channels, channels, 1, bias=True)
        self.act = nn.GELU()
        self._idx_cache = {}

    def _get_indices(self, H: int, W: int, device: torch.device):
        """Return cached (forward, inverse) permutations for an H×W zigzag scan."""
        key = (H, W, device)
        cached = self._idx_cache.get(key)
        if cached is None:
            order = []
            for r in range(H):
                row = range(r * W, (r + 1) * W)
                order.extend(reversed(row) if r % 2 else row)
            fwd = torch.tensor(order, device=device, dtype=torch.long)
            inv = torch.empty_like(fwd)
            inv[fwd] = torch.arange(H * W, device=device)
            cached = (fwd, inv)
            self._idx_cache[key] = cached
        return cached

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        fwd, inv = self._get_indices(H, W, x.device)
        seq = x.reshape(B, C, H * W)[:, :, fwd]         # flatten in zigzag order
        mixed = self.pw(self.act(self.conv1d(seq)))     # long-range 1D mixing
        return mixed[:, :, inv].reshape(B, C, H, W)     # undo the permutation
# =============================================================================
# Liquid Block: The core building block
# =============================================================================
class LiquidBlock(nn.Module):
    """
    Core attention-free denoising block (maps onto the NCP hierarchy).

    Pipeline:
      1. [SENSORY] AdaGN conditioning on the timestep/class embedding
      2. [INTER]   gated depthwise stimulus + optional gated zigzag scan
      3. [COMMAND] liquid time-constant blend #1 (CfC dynamics)
      4. [MOTOR]   AdaGN + channel-mix MLP, then liquid blend #2

    Every step is parallel across spatial positions — no recurrence.
    """

    def __init__(
        self, channels: int, cond_dim: int, spatial_kernel: int = 7,
        scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,
        drop_rate: float = 0.0, use_zigzag: bool = True,
    ):
        super().__init__()
        self.norm1 = AdaptiveGroupNorm(channels, cond_dim)
        self.norm2 = AdaptiveGroupNorm(channels, cond_dim)
        self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)
        self.use_zigzag = use_zigzag
        if use_zigzag:
            self.zigzag = ZigzagScan1D(channels, scan_kernel)
            # Zero-init gate → sigmoid(0) = 0.5 weight on the scan branch at start.
            self.zigzag_gate = nn.Parameter(torch.zeros(1))
        self.liquid = LiquidTimeConstant(channels)
        self.channel_mix = ChannelMixMLP(channels, mlp_ratio)
        self.liquid2 = LiquidTimeConstant(channels)
        self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        normed = self.norm1(x, cond)
        stimulus = self.spatial_stim(normed)
        if self.use_zigzag:
            scan = self.zigzag(normed)
            stimulus = stimulus + torch.sigmoid(self.zigzag_gate) * scan
        x = self.liquid(x, self.drop(stimulus))
        mixed = self.drop(self.channel_mix(self.norm2(x, cond)))
        return self.liquid2(x, mixed)
# =============================================================================
# Timestep and Class Embeddings
# =============================================================================
class TimestepEmbedding(nn.Module):
    """Sinusoidal frequency features of t, projected by a SiLU MLP."""

    def __init__(self, dim: int, freq_dim: int = 256):
        super().__init__()
        # freq_dim should be even: the embedding is cos||sin over freq_dim//2 bands.
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """t: [B] timesteps (in [0, 1] here) → [B, dim] conditioning vectors."""
        half = self.freq_dim // 2
        # Geometric frequency ladder from 1 down to 1/10000 (transformer-style).
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half)
        phases = t.unsqueeze(-1) * freqs.unsqueeze(0)
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        return self.mlp(features)
class ClassEmbedding(nn.Module):
    """Class-label embedding with a learned null vector for classifier-free guidance."""

    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_classes, dim)
        # Learned "no class" embedding substituted when labels are dropped for CFG.
        self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)

    def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:
        """labels: [B] class indices → [B, dim]; drops labels only in training mode."""
        out = self.embed(labels)
        if self.training and drop_prob > 0:
            dropped = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob
            null = self.null_embed.unsqueeze(0).expand_as(out)
            out = torch.where(dropped, null, out)
        return out
# =============================================================================
# LiquidGen: Full Model
# =============================================================================
class LiquidGen(nn.Module):
    """
    LiquidGen: attention-free Liquid-Neural-Network image generator.

    Predicts the flow-matching velocity field over VAE latents with a stack
    of LiquidBlocks (CfC closed-form liquid dynamics) instead of attention.

    Features:
    - NO self-attention anywhere — O(n) complexity
    - NO sequential ODE solving — fully parallelizable
    - Liquid time constants for adaptive information blending
    - Zigzag scanning for global context
    - Depthwise convolutions for local spatial structure
    - Gated stimulus (biologically-inspired from NCP)
    - U-Net long skip connections (from U-ViT/DiM)

    Config presets (see liquidgen_small/base/large):
    - LiquidGen-S: ~55M params (256px, fast training)
    - LiquidGen-B: ~140M params (256/512px, balanced)
    - LiquidGen-L: ~280M params (512px, high quality)
    """

    def __init__(
        self,
        in_channels: int = 4,          # latent channels (4 for SDXL VAE)
        patch_size: int = 2,           # patchify stride on the latent grid
        embed_dim: int = 512,          # backbone width
        depth: int = 16,               # number of LiquidBlocks
        spatial_kernel: int = 7,       # depthwise kernel in the stimulus conv
        scan_kernel: int = 31,         # 1D kernel of the zigzag scan
        expand_ratio: float = 2.0,     # stimulus hidden expansion
        mlp_ratio: float = 4.0,        # channel-mix hidden expansion
        drop_rate: float = 0.0,        # Dropout2d rate inside blocks
        num_classes: int = 0,          # >0 enables class conditioning
        class_drop_prob: float = 0.1,  # CFG label-drop prob during training
        use_zigzag: bool = True,       # enable the global zigzag branch
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_classes = num_classes
        self.class_drop_prob = class_drop_prob
        cond_dim = embed_dim
        self.time_embed = TimestepEmbedding(cond_dim)
        self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)
        # Learned positional grid at a 32x32 reference resolution; bilinearly
        # resized at runtime for other latent sizes (see _interpolate_pos_embed).
        self.pos_embed_size = 32
        self.pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02
        )
        self.input_proj = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),
            nn.Conv2d(embed_dim, embed_dim, 1, bias=True),
            nn.GELU(),
        )
        self.blocks = nn.ModuleList([
            LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,
                        expand_ratio, mlp_ratio, drop_rate, use_zigzag)
            for _ in range(depth)
        ])
        self.final_norm = nn.GroupNorm(32, embed_dim)
        self.final_proj = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),
            nn.GELU(),
        )
        self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)
        # Zero-init the output head so the model predicts zero velocity at init.
        nn.init.zeros_(self.unpatch.weight)
        nn.init.zeros_(self.unpatch.bias)
        self.apply(self._init_weights)
        # BUGFIX: self.apply(_init_weights) above Xavier-initializes EVERY
        # nn.Linear — including AdaptiveGroupNorm.proj, which is deliberately
        # zero-initialized inside AdaptiveGroupNorm.__init__ so each block
        # starts as an identity modulation. Restore those zeros here.
        # (self.unpatch is a ConvTranspose2d, which _init_weights does not
        # match, so its zero-init survives without special handling.)
        for module in self.modules():
            if isinstance(module, AdaptiveGroupNorm):
                nn.init.zeros_(module.proj.weight)
                nn.init.zeros_(module.proj.bias)
        self._gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing to reduce VRAM by ~40-60%.

        Recomputes block activations during backward instead of storing them.
        Slower training (~30%) but allows much larger batch sizes or models.
        """
        self._gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing (store activations as usual)."""
        self._gradient_checkpointing = False

    def _init_weights(self, m):
        """Per-module init: Kaiming for Conv2d, Xavier for Linear, N(0, .02) for Embedding.

        Note: intentionally does not match ConvTranspose2d, so the zero-init
        of `unpatch` is preserved; AdaptiveGroupNorm projections are re-zeroed
        in __init__ after `apply` runs.
        """
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:
        """Resize the learned 32x32 positional grid to the current latent size."""
        if H == self.pos_embed_size and W == self.pos_embed_size:
            return self.pos_embed
        return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Predict velocity field for flow matching.

        Args:
            x: [B, C, H, W] noisy latent (C=4 for SDXL VAE)
            t: [B] timestep in [0, 1]
            class_labels: [B] optional class labels
        Returns:
            v: [B, C, H, W] predicted velocity
        """
        cond = self.time_embed(t)
        if self.class_embed is not None and class_labels is not None:
            # Labels are randomly dropped (CFG) only in training mode.
            drop_p = self.class_drop_prob if self.training else 0.0
            cond = cond + self.class_embed(class_labels, drop_prob=drop_p)
        h = self.patch_embed(x)
        B, C, H_p, W_p = h.shape
        h = h + self._interpolate_pos_embed(H_p, W_p)
        h = self.input_proj(h)
        # U-Net style long skips: the first depth//2 blocks push their inputs,
        # the remaining blocks pop and add them (shallow→deep, LIFO pairing).
        skip_connections = []
        mid = self.depth // 2
        for i, block in enumerate(self.blocks):
            if i < mid:
                skip_connections.append(h)
            elif i >= mid and len(skip_connections) > 0:
                skip = skip_connections.pop()
                h = h + skip
            if self._gradient_checkpointing and self.training:
                # use_reentrant=False: non-reentrant checkpointing, required
                # for keyword args and safer with multiple inputs.
                h = checkpoint(block, h, cond, use_reentrant=False)
            else:
                h = block(h, cond)
        h = self.final_norm(h)
        h = self.final_proj(h)
        v = self.unpatch(h)
        return v

    def count_params(self) -> int:
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# =============================================================================
# Model Presets
# =============================================================================
def liquidgen_small(**kwargs) -> LiquidGen:
    """Build the ~55M-param preset — 256px, fast training/testing."""
    config = dict(
        embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,
        expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True,
    )
    # Caller kwargs override the preset defaults.
    return LiquidGen(**{**config, **kwargs})
def liquidgen_base(**kwargs) -> LiquidGen:
    """Build the ~140M-param preset — 256/512px, balanced (fits T4 16GB easily)."""
    config = dict(
        embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,
        expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True,
    )
    # Caller kwargs override the preset defaults.
    return LiquidGen(**{**config, **kwargs})
def liquidgen_large(**kwargs) -> LiquidGen:
    """Build the ~280M-param preset — 512px, high quality (fits T4 16GB, small batch)."""
    config = dict(
        embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,
        expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True,
    )
    # Caller kwargs override the preset defaults.
    return LiquidGen(**{**config, **kwargs})
if __name__ == "__main__":
    # CPU smoke test: build each preset and check velocity shapes at two latent sizes.
    device = "cpu"
    presets = [("Small", liquidgen_small), ("Base", liquidgen_base), ("Large", liquidgen_large)]
    for name, factory in presets:
        model = factory(num_classes=27).to(device)
        print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params")
        # 256px: image/8 = 32x32 latent, 4 channels (SDXL VAE)
        latents = torch.randn(2, 4, 32, 32, device=device)
        timesteps = torch.rand(2, device=device)
        labels = torch.randint(0, 27, (2,), device=device)
        velocity = model(latents, timesteps, labels)
        assert velocity.shape == latents.shape
        # 512px: image/8 = 64x64 latent
        latents512 = torch.randn(1, 4, 64, 64, device=device)
        velocity512 = model(latents512, timesteps[:1], labels[:1])
        assert velocity512.shape == latents512.shape
        print(f"  256px ✅ 512px ✅")
        del model
    print("\n✅ All tests passed!")