""" LiquidGen: A Novel Liquid Neural Network Image Generation Model Architecture Overview: - Frozen VAE encoder/decoder (SDXL VAE, 4ch latent, 8x compression, no login needed) - Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE) - Flow matching training objective (velocity prediction) Key Innovation: Replaces attention with Liquid Neural Network dynamics: - CfC-inspired closed-form update: x_new = α·x + (1-α)·h(x) - Per-channel learnable decay rates (liquid time constants) - Depthwise + pointwise convolutions for spatial context (no attention needed) - Zigzag spatial scanning for global receptive field - Gated stimulus with biologically-inspired sign constraints - U-Net style long skip connections from shallow to deep blocks Math Foundation (from Hasani et al., CfC paper): x_{t+1} = exp(-Δt/τ_t) · x_t + (1 - exp(-Δt/τ_t)) · h(x_t, u_t) Our parallelizable adaptation (inspired by LiquidTAD): α = exp(-softplus(ρ)) [per-channel learnable decay] h = gate · stimulus [gated depthwise conv output] out = α · x + (1 - α) · h [liquid relaxation blend] This removes the input-dependent τ (which requires sequential computation) and replaces it with a per-channel learned decay — making it fully parallel while preserving the liquid dynamics' ability to blend old state with new input. Design for 16GB VRAM (Colab free tier): - VAE frozen: ~1GB - Backbone: ~55-280M params (~100-550MB in fp16) - Training overhead (grads + optimizer): ~3-8GB - Batch of latents: ~1-2GB - Total: fits comfortably in 16GB References: - Hasani et al., "Liquid Time-constant Networks" (NeurIPS 2020) - Hasani et al., "Closed-form Continuous-depth Models" (Nature Machine Intelligence 2022) - Lechner et al., "Neural Circuit Policies" (Nature Machine Intelligence 2020) - LiquidTAD (2025) - Parallelized liquid dynamics - ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion - DiMSUM (NeurIPS 2024) - Attention-free diffusion """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint import math from typing import Optional, Tuple # ============================================================================= # Building Blocks # ============================================================================= class LiquidTimeConstant(nn.Module): """ Core liquid time-constant module. Implements the CfC closed-form dynamics in a fully parallelizable way: out = α · x + (1 - α) · stimulus where α = exp(-softplus(ρ)) is a learnable per-channel decay rate, derived from the liquid time constant τ = 1/softplus(ρ). This preserves the key property of Liquid Neural Networks: - Exponential relaxation toward a target (stimulus) - Rate controlled by τ (how fast to adapt) - No sequential ODE solving required Stability guarantee (from LTC Theorem 1): τ_sys ∈ [τ/(1+τW), τ] — time constants NEVER explode """ def __init__(self, channels: int): super().__init__() # ρ parameterizes the decay: λ = softplus(ρ), α = exp(-λ) # Initialize ρ=0 → λ≈0.693 → α≈0.5 (equal blend of old and new) self.rho = nn.Parameter(torch.zeros(channels)) def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor: """ x: [B, C, H, W] - current state (residual path) stimulus: [B, C, H, W] - computed target from context returns: [B, C, H, W] - liquid-blended output """ lam = F.softplus(self.rho) + 1e-5 alpha = torch.exp(-lam).view(1, -1, 1, 1) return alpha * x + (1.0 - alpha) * stimulus class GatedDepthwiseStimulusConv(nn.Module): """ Computes the spatial stimulus using depthwise-separable convolutions with a sigmoid gate (inspired by GLU / gated mechanisms in SSMs). This replaces attention for capturing local spatial context: - Depthwise conv: captures local spatial patterns per channel - Pointwise conv: mixes channel information - Sigmoid gate: controls information flow (like synaptic gating in NCP) Two parallel paths (inspired by NCP inter→command split): 1. Stimulus path: DW-conv → PW-conv → GELU → project back 2. Gate path: DW-conv → PW-conv → sigmoid Output = stimulus * gate """ def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0): super().__init__() hidden = int(channels * expand_ratio) self.stim_dw = nn.Conv2d(channels, channels, kernel_size, padding=kernel_size // 2, groups=channels, bias=False) self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False) self.stim_act = nn.GELU() self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False) self.gate_dw = nn.Conv2d(channels, channels, kernel_size, padding=kernel_size // 2, groups=channels, bias=False) self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: stim = self.stim_proj(self.stim_act(self.stim_pw(self.stim_dw(x)))) gate = torch.sigmoid(self.gate_pw(self.gate_dw(x))) return stim * gate class ChannelMixMLP(nn.Module): """Channel mixing MLP with GELU activation (command neuron processing in NCP).""" def __init__(self, channels: int, expand_ratio: float = 4.0): super().__init__() hidden = int(channels * expand_ratio) self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True) self.act = nn.GELU() self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.fc2(self.act(self.fc1(x))) class AdaptiveGroupNorm(nn.Module): """ Adaptive Group Normalization conditioned on timestep embedding. Applies: out = (1 + scale) * GroupNorm(x) + shift """ def __init__(self, channels: int, cond_dim: int, num_groups: int = 32): super().__init__() self.norm = nn.GroupNorm(num_groups, channels, affine=False) self.proj = nn.Linear(cond_dim, channels * 2) nn.init.zeros_(self.proj.weight) nn.init.zeros_(self.proj.bias) def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: h = self.norm(x) params = self.proj(cond) scale, shift = params.chunk(2, dim=-1) return h * (1.0 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(-1) class ZigzagScan1D(nn.Module): """ 1D global mixing via zigzag-scanned depthwise conv. Gives quasi-global receptive field without attention's O(n²) cost. Zigzag scan preserves spatial continuity (from ZigMa, ECCV 2024). """ def __init__(self, channels: int, kernel_size: int = 31): super().__init__() self.conv1d = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, groups=channels, bias=False) self.pw = nn.Conv1d(channels, channels, 1, bias=True) self.act = nn.GELU() self._idx_cache = {} def _get_indices(self, H: int, W: int, device: torch.device): key = (H, W, device) if key not in self._idx_cache: indices = [] for i in range(H): row = list(range(i * W, (i + 1) * W)) if i % 2 == 1: row = row[::-1] indices.extend(row) fwd = torch.tensor(indices, device=device, dtype=torch.long) inv = torch.empty_like(fwd) inv[fwd] = torch.arange(H * W, device=device) self._idx_cache[key] = (fwd, inv) return self._idx_cache[key] def forward(self, x: torch.Tensor) -> torch.Tensor: B, C, H, W = x.shape zz_idx, inv_idx = self._get_indices(H, W, x.device) x_flat = x.reshape(B, C, H * W) x_zz = x_flat[:, :, zz_idx] x_mixed = self.pw(self.act(self.conv1d(x_zz))) x_restored = x_mixed[:, :, inv_idx] return x_restored.reshape(B, C, H, W) # ============================================================================= # Liquid Block: The core building block # ============================================================================= class LiquidBlock(nn.Module): """ A single Liquid Neural Network block for image denoising. Architecture (maps to NCP hierarchy): 1. [SENSORY] AdaGN conditioning → spatial context extraction 2. [INTER] Zigzag 1D scan for global mixing 3. [COMMAND] Liquid time-constant blend (CfC dynamics) 4. [MOTOR] Channel mixing MLP for output projection All operations are fully parallelizable — no sequential dependencies. """ def __init__( self, channels: int, cond_dim: int, spatial_kernel: int = 7, scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0, drop_rate: float = 0.0, use_zigzag: bool = True, ): super().__init__() self.norm1 = AdaptiveGroupNorm(channels, cond_dim) self.norm2 = AdaptiveGroupNorm(channels, cond_dim) self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio) self.use_zigzag = use_zigzag if use_zigzag: self.zigzag = ZigzagScan1D(channels, scan_kernel) self.zigzag_gate = nn.Parameter(torch.zeros(1)) self.liquid = LiquidTimeConstant(channels) self.channel_mix = ChannelMixMLP(channels, mlp_ratio) self.liquid2 = LiquidTimeConstant(channels) self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity() def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: h = self.norm1(x, cond) stim = self.spatial_stim(h) if self.use_zigzag: zz = self.zigzag(h) stim = stim + torch.sigmoid(self.zigzag_gate) * zz stim = self.drop(stim) x = self.liquid(x, stim) h2 = self.norm2(x, cond) ch_out = self.drop(self.channel_mix(h2)) x = self.liquid2(x, ch_out) return x # ============================================================================= # Timestep and Class Embeddings # ============================================================================= class TimestepEmbedding(nn.Module): """Sinusoidal timestep embedding followed by MLP projection.""" def __init__(self, dim: int, freq_dim: int = 256): super().__init__() self.freq_dim = freq_dim self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) def forward(self, t: torch.Tensor) -> torch.Tensor: half = self.freq_dim // 2 freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half) args = t.unsqueeze(-1) * freqs.unsqueeze(0) emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) return self.mlp(emb) class ClassEmbedding(nn.Module): """Optional class-conditional embedding with CFG null embedding.""" def __init__(self, num_classes: int, dim: int): super().__init__() self.embed = nn.Embedding(num_classes, dim) self.null_embed = nn.Parameter(torch.randn(dim) * 0.02) def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor: emb = self.embed(labels) if self.training and drop_prob > 0: mask = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob emb = torch.where(mask, self.null_embed.unsqueeze(0).expand_as(emb), emb) return emb # ============================================================================= # LiquidGen: Full Model # ============================================================================= class LiquidGen(nn.Module): """ LiquidGen: Liquid Neural Network Image Generator A novel attention-free diffusion model that uses Liquid Neural Network dynamics (CfC closed-form continuous-depth) for image generation. Features: - NO self-attention anywhere — O(n) complexity - NO sequential ODE solving — fully parallelizable - Liquid time constants for adaptive information blending - Zigzag scanning for global context - Depthwise convolutions for local spatial structure - Gated stimulus (biologically-inspired from NCP) - U-Net long skip connections (from U-ViT/DiM) Config Presets: - LiquidGen-S: ~55M params (256px, fast training) - LiquidGen-B: ~140M params (256/512px, balanced) - LiquidGen-L: ~280M params (512px, high quality) """ def __init__( self, in_channels: int = 4, # 4 for SDXL VAE patch_size: int = 2, embed_dim: int = 512, depth: int = 16, spatial_kernel: int = 7, scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0, drop_rate: float = 0.0, num_classes: int = 0, class_drop_prob: float = 0.1, use_zigzag: bool = True, ): super().__init__() self.in_channels = in_channels self.patch_size = patch_size self.embed_dim = embed_dim self.depth = depth self.num_classes = num_classes self.class_drop_prob = class_drop_prob cond_dim = embed_dim self.time_embed = TimestepEmbedding(cond_dim) self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size) self.pos_embed_size = 32 self.pos_embed = nn.Parameter( torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02 ) self.input_proj = nn.Sequential( nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False), nn.Conv2d(embed_dim, embed_dim, 1, bias=True), nn.GELU(), ) self.blocks = nn.ModuleList([ LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel, expand_ratio, mlp_ratio, drop_rate, use_zigzag) for _ in range(depth) ]) self.final_norm = nn.GroupNorm(32, embed_dim) self.final_proj = nn.Sequential( nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True), nn.GELU(), ) self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size) nn.init.zeros_(self.unpatch.weight) nn.init.zeros_(self.unpatch.bias) self.apply(self._init_weights) self._gradient_checkpointing = False def enable_gradient_checkpointing(self): """Enable gradient checkpointing to reduce VRAM by ~40-60%. Recomputes block activations during backward instead of storing them. Slower training (~30%) but allows much larger batch sizes or models.""" self._gradient_checkpointing = True def disable_gradient_checkpointing(self): self._gradient_checkpointing = False def _init_weights(self, m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor: if H == self.pos_embed_size and W == self.pos_embed_size: return self.pos_embed return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False) def forward( self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Predict velocity field for flow matching. Args: x: [B, C, H, W] noisy latent (C=4 for SDXL VAE) t: [B] timestep in [0, 1] class_labels: [B] optional class labels Returns: v: [B, C, H, W] predicted velocity """ cond = self.time_embed(t) if self.class_embed is not None and class_labels is not None: drop_p = self.class_drop_prob if self.training else 0.0 cond = cond + self.class_embed(class_labels, drop_prob=drop_p) h = self.patch_embed(x) B, C, H_p, W_p = h.shape h = h + self._interpolate_pos_embed(H_p, W_p) h = self.input_proj(h) # U-Net style long skip connections skip_connections = [] mid = self.depth // 2 for i, block in enumerate(self.blocks): if i < mid: skip_connections.append(h) elif i >= mid and len(skip_connections) > 0: skip = skip_connections.pop() h = h + skip if self._gradient_checkpointing and self.training: h = checkpoint(block, h, cond, use_reentrant=False) else: h = block(h, cond) h = self.final_norm(h) h = self.final_proj(h) v = self.unpatch(h) return v def count_params(self) -> int: return sum(p.numel() for p in self.parameters() if p.requires_grad) # ============================================================================= # Model Presets # ============================================================================= def liquidgen_small(**kwargs) -> LiquidGen: """~55M params - for 256px, fast training/testing""" defaults = dict( embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True, ) defaults.update(kwargs) return LiquidGen(**defaults) def liquidgen_base(**kwargs) -> LiquidGen: """~140M params - for 256/512px, balanced (fits T4 16GB easily)""" defaults = dict( embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True, ) defaults.update(kwargs) return LiquidGen(**defaults) def liquidgen_large(**kwargs) -> LiquidGen: """~280M params - for 512px, high quality (fits T4 16GB with small batch)""" defaults = dict( embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True, ) defaults.update(kwargs) return LiquidGen(**defaults) if __name__ == "__main__": device = "cpu" for name, factory in [("Small", liquidgen_small), ("Base", liquidgen_base), ("Large", liquidgen_large)]: model = factory(num_classes=27).to(device) print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params") # 256px: image/8 = 32x32 latent, 4 channels (SDXL VAE) x = torch.randn(2, 4, 32, 32, device=device) t = torch.rand(2, device=device) labels = torch.randint(0, 27, (2,), device=device) v = model(x, t, labels) assert v.shape == x.shape # 512px: image/8 = 64x64 latent x512 = torch.randn(1, 4, 64, 64, device=device) v512 = model(x512, t[:1], labels[:1]) assert v512.shape == x512.shape print(f" 256px ✅ 512px ✅") del model print("\n✅ All tests passed!")