| """ |
| LiquidGen: A Novel Liquid Neural Network Image Generation Model |
| |
| Architecture Overview: |
| - Frozen VAE encoder/decoder (SDXL VAE, 4ch latent, 8x compression, no login needed) |
| - Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE) |
| - Flow matching training objective (velocity prediction) |
| |
| Key Innovation: Replaces attention with Liquid Neural Network dynamics: |
| - CfC-inspired closed-form update: x_new = α·x + (1-α)·h(x) |
| - Per-channel learnable decay rates (liquid time constants) |
| - Depthwise + pointwise convolutions for spatial context (no attention needed) |
| - Zigzag spatial scanning for global receptive field |
| - Gated stimulus with biologically-inspired sign constraints |
| - U-Net style long skip connections from shallow to deep blocks |
| |
| Math Foundation (from Hasani et al., CfC paper): |
| x_{t+1} = exp(-Δt/τ_t) · x_t + (1 - exp(-Δt/τ_t)) · h(x_t, u_t) |
| |
| Our parallelizable adaptation (inspired by LiquidTAD): |
| α = exp(-softplus(ρ)) [per-channel learnable decay] |
| h = gate · stimulus [gated depthwise conv output] |
| out = α · x + (1 - α) · h [liquid relaxation blend] |
| |
| This removes the input-dependent τ (which requires sequential computation) |
| and replaces it with a per-channel learned decay — making it fully parallel |
| while preserving the liquid dynamics' ability to blend old state with new input. |
| |
| Design for 16GB VRAM (Colab free tier): |
| - VAE frozen: ~1GB |
| - Backbone: ~55-280M params (~100-550MB in fp16) |
| - Training overhead (grads + optimizer): ~3-8GB |
| - Batch of latents: ~1-2GB |
| - Total: fits comfortably in 16GB |
| |
| References: |
| - Hasani et al., "Liquid Time-constant Networks" (NeurIPS 2020) |
| - Hasani et al., "Closed-form Continuous-depth Models" (Nature Machine Intelligence 2022) |
| - Lechner et al., "Neural Circuit Policies" (Nature Machine Intelligence 2020) |
| - LiquidTAD (2025) - Parallelized liquid dynamics |
| - ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion |
| - DiMSUM (NeurIPS 2024) - Attention-free diffusion |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint |
| import math |
| from typing import Optional, Tuple |
|
|
|
|
| |
| |
| |
|
|
class LiquidTimeConstant(nn.Module):
    """
    Core liquid time-constant module.

    Parallelizable CfC closed-form dynamics:
        out = α · x + (1 - α) · stimulus
    with α = exp(-softplus(ρ)), a learnable per-channel decay derived from the
    liquid time constant τ = 1/softplus(ρ).

    Preserves the key Liquid Neural Network behavior — exponential relaxation
    of the state toward a target at a learned per-channel rate — without any
    sequential ODE solving.

    Stability (LTC Theorem 1): τ_sys ∈ [τ/(1+τW), τ], so time constants
    never explode.
    """

    def __init__(self, channels: int):
        super().__init__()
        # ρ parameterizes the decay rate; softplus keeps it strictly positive.
        self.rho = nn.Parameter(torch.zeros(channels))

    def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:
        """Blend state with stimulus.

        Args:
            x: [B, C, H, W] current state (residual path).
            stimulus: [B, C, H, W] computed target from context.
        Returns:
            [B, C, H, W] liquid-blended output.
        """
        # Small epsilon keeps the rate bounded away from zero.
        rate = F.softplus(self.rho) + 1e-5
        decay = torch.exp(-rate).reshape(1, -1, 1, 1)
        return decay * x + (1.0 - decay) * stimulus
|
|
|
|
class GatedDepthwiseStimulusConv(nn.Module):
    """
    Spatial stimulus via depthwise-separable convolutions with a sigmoid gate
    (GLU-style gating, as used in SSM blocks).

    Attention replacement for local spatial context:
      - depthwise conv: per-channel spatial patterns
      - pointwise conv: channel mixing
      - sigmoid gate: information-flow control (synaptic gating in NCP terms)

    Two parallel paths (mirrors the NCP inter→command split):
      1. stimulus: DW-conv → PW-conv (expand) → GELU → PW-conv (project back)
      2. gate:     DW-conv → PW-conv → sigmoid
    Output is their elementwise product.
    """

    def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):
        super().__init__()
        hidden = int(channels * expand_ratio)

        # Stimulus path (expand/contract bottleneck).
        self.stim_dw = nn.Conv2d(channels, channels, kernel_size,
                                 padding=kernel_size // 2, groups=channels, bias=False)
        self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)
        self.stim_act = nn.GELU()
        self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)

        # Gate path (bias enabled so the gate can learn a resting level).
        self.gate_dw = nn.Conv2d(channels, channels, kernel_size,
                                 padding=kernel_size // 2, groups=channels, bias=False)
        self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        local = self.stim_dw(x)
        expanded = self.stim_act(self.stim_pw(local))
        stimulus = self.stim_proj(expanded)
        gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))
        return stimulus * gate
|
|
|
|
class ChannelMixMLP(nn.Module):
    """Two-layer pointwise (1x1 conv) MLP with GELU — the NCP "command neuron"
    processing stage. Mixes information across channels only; spatially local."""

    def __init__(self, channels: int, expand_ratio: float = 4.0):
        super().__init__()
        width = int(channels * expand_ratio)
        self.fc1 = nn.Conv2d(channels, width, 1, bias=True)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(width, channels, 1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
|
|
|
|
class AdaptiveGroupNorm(nn.Module):
    """
    Group normalization modulated by a conditioning vector (adaLN-zero style):
        out = (1 + scale) * GroupNorm(x) + shift
    where (scale, shift) are projected from the timestep/class embedding.
    The projection is zero-initialized so the module starts as plain GroupNorm.
    """

    def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):
        super().__init__()
        # affine=False: all modulation comes from the conditioning projection.
        self.norm = nn.GroupNorm(num_groups, channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        # Zero init => identity modulation at the start of training.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        normed = self.norm(x)
        # [B, 2C] -> [B, 2C, 1, 1] -> two [B, C, 1, 1] halves.
        scale, shift = self.proj(cond)[:, :, None, None].chunk(2, dim=1)
        return normed * (1.0 + scale) + shift
|
|
|
|
class ZigzagScan1D(nn.Module):
    """
    Quasi-global mixing: flatten the feature map along a zigzag (boustrophedon)
    path, run a long depthwise 1D conv, and scatter back.

    Avoids attention's O(n²) cost; the zigzag order keeps spatially adjacent
    pixels adjacent in the sequence (ZigMa, ECCV 2024).
    """

    def __init__(self, channels: int, kernel_size: int = 31):
        super().__init__()
        self.conv1d = nn.Conv1d(channels, channels, kernel_size,
                                padding=kernel_size // 2, groups=channels, bias=False)
        self.pw = nn.Conv1d(channels, channels, 1, bias=True)
        self.act = nn.GELU()
        # Memoized (forward, inverse) index pairs, keyed by (H, W, device).
        self._idx_cache = {}

    def _get_indices(self, H: int, W: int, device: torch.device):
        """Return (fwd, inv) long tensors: fwd gathers row-major -> zigzag,
        inv undoes it. Cached per (H, W, device)."""
        key = (H, W, device)
        cached = self._idx_cache.get(key)
        if cached is None:
            grid = torch.arange(H * W, device=device, dtype=torch.long).reshape(H, W)
            # Reverse every odd row to form the serpentine path.
            grid[1::2] = grid[1::2].flip(-1)
            fwd = grid.reshape(-1)
            # argsort of a permutation is its inverse.
            inv = torch.argsort(fwd)
            cached = (fwd, inv)
            self._idx_cache[key] = cached
        return cached

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        fwd, inv = self._get_indices(H, W, x.device)
        seq = x.reshape(B, C, H * W).index_select(2, fwd)
        seq = self.pw(self.act(self.conv1d(seq)))
        return seq.index_select(2, inv).reshape(B, C, H, W)
|
|
|
|
| |
| |
| |
|
|
class LiquidBlock(nn.Module):
    """
    One Liquid Neural Network block for image denoising.

    Maps onto the NCP hierarchy:
      [SENSORY] AdaGN conditioning on the timestep embedding
      [INTER]   gated depthwise stimulus + optional zigzag global scan
      [COMMAND] liquid time-constant blend (CfC dynamics)
      [MOTOR]   channel-mixing MLP, merged via a second liquid blend

    Every operation is parallel across spatial positions — no sequential
    dependencies anywhere in the block.
    """

    def __init__(
        self, channels: int, cond_dim: int, spatial_kernel: int = 7,
        scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,
        drop_rate: float = 0.0, use_zigzag: bool = True,
    ):
        super().__init__()
        self.norm1 = AdaptiveGroupNorm(channels, cond_dim)
        self.norm2 = AdaptiveGroupNorm(channels, cond_dim)
        self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)
        self.use_zigzag = use_zigzag
        if use_zigzag:
            self.zigzag = ZigzagScan1D(channels, scan_kernel)
            # Scalar gate on the global path; sigmoid(0)=0.5 at init.
            self.zigzag_gate = nn.Parameter(torch.zeros(1))
        self.liquid = LiquidTimeConstant(channels)
        self.channel_mix = ChannelMixMLP(channels, mlp_ratio)
        self.liquid2 = LiquidTimeConstant(channels)
        self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # Spatial half: local gated stimulus, optionally augmented with the
        # gated zigzag global scan, then blended into the residual state.
        pre = self.norm1(x, cond)
        stimulus = self.spatial_stim(pre)
        if self.use_zigzag:
            stimulus = stimulus + torch.sigmoid(self.zigzag_gate) * self.zigzag(pre)
        x = self.liquid(x, self.drop(stimulus))
        # Channel half: MLP output blended by a second liquid update.
        mixed = self.drop(self.channel_mix(self.norm2(x, cond)))
        return self.liquid2(x, mixed)
|
|
|
|
| |
| |
| |
|
|
class TimestepEmbedding(nn.Module):
    """Sinusoidal frequency features of the (continuous) timestep, projected
    through a small SiLU MLP into the conditioning dimension."""

    def __init__(self, dim: int, freq_dim: int = 256):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """t: [B] timestep in [0, 1] -> [B, dim] embedding."""
        half = self.freq_dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        idx = torch.arange(half, device=t.device, dtype=t.dtype)
        freqs = torch.exp(-math.log(10000.0) * idx / half)
        phases = t.unsqueeze(-1) * freqs.unsqueeze(0)
        features = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        return self.mlp(features)
|
|
|
|
class ClassEmbedding(nn.Module):
    """Class-conditional embedding table plus a learned "null" embedding used
    for classifier-free guidance (CFG) label dropout during training."""

    def __init__(self, num_classes: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_classes, dim)
        self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)

    def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:
        """Look up label embeddings; in training mode, randomly replace each
        row with the null embedding with probability `drop_prob`."""
        emb = self.embed(labels)
        if self.training and drop_prob > 0:
            # Per-sample Bernoulli mask, broadcast over the embedding dim.
            dropped = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob
            null = self.null_embed.unsqueeze(0).expand_as(emb)
            emb = torch.where(dropped, null, emb)
        return emb
|
|
|
|
| |
| |
| |
|
|
class LiquidGen(nn.Module):
    """
    LiquidGen: Liquid Neural Network Image Generator

    A novel attention-free diffusion model that uses Liquid Neural Network
    dynamics (CfC closed-form continuous-depth) for image generation.

    Features:
    - NO self-attention anywhere — O(n) complexity
    - NO sequential ODE solving — fully parallelizable
    - Liquid time constants for adaptive information blending
    - Zigzag scanning for global context
    - Depthwise convolutions for local spatial structure
    - Gated stimulus (biologically-inspired from NCP)
    - U-Net long skip connections (from U-ViT/DiM)

    Config Presets:
    - LiquidGen-S: ~55M params (256px, fast training)
    - LiquidGen-B: ~140M params (256/512px, balanced)
    - LiquidGen-L: ~280M params (512px, high quality)
    """

    def __init__(
        self,
        in_channels: int = 4,
        patch_size: int = 2,
        embed_dim: int = 512,
        depth: int = 16,
        spatial_kernel: int = 7,
        scan_kernel: int = 31,
        expand_ratio: float = 2.0,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        num_classes: int = 0,
        class_drop_prob: float = 0.1,
        use_zigzag: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_classes = num_classes
        self.class_drop_prob = class_drop_prob

        cond_dim = embed_dim

        self.time_embed = TimestepEmbedding(cond_dim)
        self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None

        # Patchify: non-overlapping strided conv downsamples by patch_size.
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)

        # Learned 2D positional embedding at a base 32x32 grid; bilinearly
        # resized for other resolutions (see _interpolate_pos_embed).
        self.pos_embed_size = 32
        self.pos_embed = nn.Parameter(
            torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02
        )

        self.input_proj = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),
            nn.Conv2d(embed_dim, embed_dim, 1, bias=True),
            nn.GELU(),
        )

        self.blocks = nn.ModuleList([
            LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,
                        expand_ratio, mlp_ratio, drop_rate, use_zigzag)
            for _ in range(depth)
        ])

        # NOTE: embed_dim must be divisible by 32 (GroupNorm constraint);
        # all presets satisfy this.
        self.final_norm = nn.GroupNorm(32, embed_dim)
        self.final_proj = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),
            nn.GELU(),
        )

        self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)

        # Global init pass FIRST, targeted zero-inits AFTER.
        # BUGFIX: previously self.apply() ran last, and _init_weights
        # xavier-initializes every nn.Linear — including AdaptiveGroupNorm.proj
        # — which clobbered the adaLN-zero initialization those modules set up
        # in their own __init__ (each block must start as identity modulation).
        self.apply(self._init_weights)
        for m in self.modules():
            if isinstance(m, AdaptiveGroupNorm):
                nn.init.zeros_(m.proj.weight)
                nn.init.zeros_(m.proj.bias)
        # Zero-init the output head so the model predicts v=0 at step 0
        # (stable start for flow-matching training).
        nn.init.zeros_(self.unpatch.weight)
        nn.init.zeros_(self.unpatch.bias)

        self._gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing to reduce VRAM by ~40-60%.
        Recomputes block activations during backward instead of storing them.
        Slower training (~30%) but allows much larger batch sizes or models."""
        self._gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing (store activations as usual)."""
        self._gradient_checkpointing = False

    def _init_weights(self, m):
        """Default init: kaiming for convs, xavier for linears, N(0, 0.02)
        for embeddings. Modules needing special init (AdaGN proj, unpatch)
        are re-initialized after this pass in __init__."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:
        """Return the positional embedding resized to the latent grid (H, W)."""
        if H == self.pos_embed_size and W == self.pos_embed_size:
            return self.pos_embed
        return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Predict velocity field for flow matching.
        Args:
            x: [B, C, H, W] noisy latent (C=4 for SDXL VAE)
            t: [B] timestep in [0, 1]
            class_labels: [B] optional class labels
        Returns:
            v: [B, C, H, W] predicted velocity
        """
        cond = self.time_embed(t)
        if self.class_embed is not None and class_labels is not None:
            # CFG label dropout is only active in training mode.
            drop_p = self.class_drop_prob if self.training else 0.0
            cond = cond + self.class_embed(class_labels, drop_prob=drop_p)

        h = self.patch_embed(x)
        B, C, H_p, W_p = h.shape
        h = h + self._interpolate_pos_embed(H_p, W_p)
        h = self.input_proj(h)

        # U-Net style long skips: the first half of the blocks push their
        # inputs; the second half pop and add them (LIFO pairing).
        skip_connections = []
        mid = self.depth // 2
        for i, block in enumerate(self.blocks):
            if i < mid:
                skip_connections.append(h)
            elif i >= mid and len(skip_connections) > 0:
                skip = skip_connections.pop()
                h = h + skip
            if self._gradient_checkpointing and self.training:
                h = checkpoint(block, h, cond, use_reentrant=False)
            else:
                h = block(h, cond)

        h = self.final_norm(h)
        h = self.final_proj(h)
        v = self.unpatch(h)
        return v

    def count_params(self) -> int:
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
| |
| |
| |
|
|
def liquidgen_small(**kwargs) -> LiquidGen:
    """~55M params - for 256px, fast training/testing"""
    # Caller-supplied kwargs override the preset defaults.
    preset = {
        "embed_dim": 512,
        "depth": 12,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": 2.0,
        "mlp_ratio": 3.0,
        "use_zigzag": True,
    }
    return LiquidGen(**{**preset, **kwargs})
|
|
def liquidgen_base(**kwargs) -> LiquidGen:
    """~140M params - for 256/512px, balanced (fits T4 16GB easily)"""
    # Caller-supplied kwargs override the preset defaults.
    preset = {
        "embed_dim": 640,
        "depth": 18,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": 2.0,
        "mlp_ratio": 4.0,
        "use_zigzag": True,
    }
    return LiquidGen(**{**preset, **kwargs})
|
|
def liquidgen_large(**kwargs) -> LiquidGen:
    """~280M params - for 512px, high quality (fits T4 16GB with small batch)"""
    # Caller-supplied kwargs override the preset defaults.
    preset = {
        "embed_dim": 768,
        "depth": 24,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": 2.5,
        "mlp_ratio": 4.0,
        "use_zigzag": True,
    }
    return LiquidGen(**{**preset, **kwargs})
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build every preset and run a forward pass at two latent
    # resolutions (32x32 latent = 256px image, 64x64 latent = 512px image).
    device = "cpu"
    presets = [("Small", liquidgen_small), ("Base", liquidgen_base), ("Large", liquidgen_large)]
    for name, build in presets:
        model = build(num_classes=27).to(device)
        print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params")

        latents = torch.randn(2, 4, 32, 32, device=device)
        timesteps = torch.rand(2, device=device)
        labels = torch.randint(0, 27, (2,), device=device)
        velocity = model(latents, timesteps, labels)
        assert velocity.shape == latents.shape

        latents_hi = torch.randn(1, 4, 64, 64, device=device)
        velocity_hi = model(latents_hi, timesteps[:1], labels[:1])
        assert velocity_hi.shape == latents_hi.shape
        print(f" 256px ✅ 512px ✅")
        del model

    print("\n✅ All tests passed!")
|
|