""" LiquidFlow v0.5 — Fixed: patch_size scales with image size → L≤256 always CRITICAL FIX: patch_size now auto-scales so sequence length L stays ≤256. Before: 256px with patch=8 → L=1024 → 21x slower than needed → stuck After: 256px with patch=16 → L=256 → same speed as 128px Config table (all have L=256 tokens): tiny: 128px patch=8, d=192 depth=6 d_state=8 → ~4M params small: 128px patch=8, d=256 depth=8 d_state=8 → ~10M params base: 256px patch=16, d=384 depth=10 d_state=8 → ~24M params 512: 512px patch=32, d=384 depth=10 d_state=8 → ~24M params """ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint try: from mambapy.pscan import pscan as _pscan HAS_PSCAN = True except ImportError: HAS_PSCAN = False def parallel_scan(A, X): if HAS_PSCAN: return _pscan(A, X.clone()) else: B, L, ED, N = A.shape h = torch.zeros(B, ED, N, device=A.device, dtype=A.dtype) ys = [] for i in range(L): h = A[:, i] * h + X[:, i] ys.append(h) return torch.stack(ys, dim=1) class LiquidCfCCell(nn.Module): def __init__(self, input_dim, hidden_dim): super().__init__() self.backbone = nn.Linear(input_dim, hidden_dim) self.gate_proj = nn.Linear(hidden_dim, hidden_dim * 2) self.act = nn.Tanh() def forward(self, x): h = self.act(self.backbone(x)) f_tau, f_x = self.gate_proj(h).chunk(2, dim=-1) gate = torch.sigmoid(-f_tau) return gate * h + (1.0 - gate) * f_x class SelectiveSSM(nn.Module): def __init__(self, d_model, d_state=8, d_conv=4, expand=2): super().__init__() self.d_model = d_model self.d_state = d_state self.d_inner = int(d_model * expand) self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False) self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv, padding=d_conv-1, groups=self.d_inner, bias=True) A = torch.arange(1, d_state + 1, dtype=torch.float32) self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(self.d_inner, -1).clone()) self.D = nn.Parameter(torch.ones(self.d_inner)) self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False) self.dt_proj = nn.Linear(1, self.d_inner, bias=True) self.out_proj = nn.Linear(self.d_inner, d_model, bias=False) with torch.no_grad(): dt_init = torch.exp(torch.rand(self.d_inner) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)) self.dt_proj.bias.copy_(dt_init + torch.log(-torch.expm1(-dt_init))) def forward(self, x): B, L, _ = x.shape xz = self.in_proj(x) x_inner, z = xz.chunk(2, dim=-1) x_conv = F.silu(self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2)) x_ssm = self.x_proj(x_conv) B_sel = x_ssm[:, :, :self.d_state] C_sel = x_ssm[:, :, self.d_state:2*self.d_state] dt = F.softplus(self.dt_proj(x_ssm[:, :, -1:])) A = -torch.exp(self.A_log) A_bar = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)) BX = dt.unsqueeze(-1) * B_sel.unsqueeze(2) * x_conv.unsqueeze(-1) orig_L = L next_pow2 = 1 << (L - 1).bit_length() if next_pow2 != L: pad = next_pow2 - L A_bar = F.pad(A_bar, (0,0, 0,0, 0,pad), value=1.0) BX = F.pad(BX, (0,0, 0,0, 0,pad), value=0.0) h_all = parallel_scan(A_bar, BX)[:, :orig_L] y = (h_all * C_sel.unsqueeze(2)).sum(-1) y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0) return self.out_proj(y * F.silu(z)) def create_scan_patterns(H, W): total = H * W; idx = torch.arange(total); grid = idx.view(H, W) patterns = [idx.clone(), idx.flip(0), grid.t().contiguous().view(-1), torch.cat([grid[i].flip(0) if i % 2 else grid[i] for i in range(H)])] inv = [] for p in patterns: i = torch.zeros_like(p); i[p] = torch.arange(total); inv.append(i) return patterns, inv class 

class LiquidSSMBlock(nn.Module):
    """SSM branch over a scan-reordered sequence, mixed with a liquid branch, then an MLP."""

    def __init__(self, d_model, d_state=8, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.norm2 = nn.LayerNorm(d_model)
        self.liquid = LiquidCfCCell(d_model, d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model), nn.Dropout(dropout))
        self.mix_alpha = nn.Parameter(torch.tensor(0.5))

    def _ssm_fwd(self, x):
        return self.ssm(self.norm1(x))

    def _liq_fwd(self, x):
        return self.liquid(self.norm2(x))

    def forward(self, x, scan_idx=None, unscan_idx=None):
        xs = x[:, scan_idx] if scan_idx is not None else x
        if self.training and x.requires_grad:
            so = checkpoint(self._ssm_fwd, xs, use_reentrant=False)
            lo = checkpoint(self._liq_fwd, x, use_reentrant=False)
        else:
            so = self._ssm_fwd(xs)
            lo = self._liq_fwd(x)
        if unscan_idx is not None:
            so = so[:, unscan_idx]
        a = torch.sigmoid(self.mix_alpha)
        x = x + a * so + (1 - a) * lo
        return x + self.ff(self.norm3(x))


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        h = self.dim // 2
        e = math.log(10000) / (h - 1)
        e = torch.exp(torch.arange(h, device=t.device) * -e)
        args = t.unsqueeze(-1) * e.unsqueeze(0)
        return torch.cat([args.sin(), args.cos()], -1)


class AdaptiveLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector (AdaLN)."""

    def __init__(self, d, c):
        super().__init__()
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(c, d * 2))

    def forward(self, x, cond):
        s, b = self.proj(cond).chunk(2, -1)
        return self.norm(x) * (1 + s.unsqueeze(1)) + b.unsqueeze(1)


class LiquidFlowNet(nn.Module):
    def __init__(self, img_size=128, patch_size=8, in_channels=3, d_model=256, depth=8,
                 d_state=8, d_conv=4, expand=2, dropout=0.0, num_classes=0):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.d_model = d_model
        self.depth = depth
        self.num_classes = num_classes
        self.num_patches_h = img_size // patch_size
        self.num_patches_w = img_size // patch_size
        self.num_patches = self.num_patches_h * self.num_patches_w
        self.patch_dim = in_channels * patch_size * patch_size
        self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, d_model), nn.LayerNorm(d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, d_model) * 0.02)
        self.time_embed = nn.Sequential(SinusoidalPosEmb(d_model), nn.Linear(d_model, d_model * 4),
                                        nn.GELU(), nn.Linear(d_model * 4, d_model))
        self.class_embed = nn.Embedding(num_classes, d_model) if num_classes > 0 else None
        self.blocks = nn.ModuleList([LiquidSSMBlock(d_model, d_state, d_conv, expand, dropout)
                                     for _ in range(depth)])
        self.adaln = nn.ModuleList([AdaptiveLayerNorm(d_model, d_model) for _ in range(depth)])
        self.skips = nn.ModuleList([nn.Linear(d_model * 2, d_model) for _ in range(depth // 2)])
        self.final_norm = nn.LayerNorm(d_model)
        self.final_proj = nn.Linear(d_model, self.patch_dim)
        pats, ipats = create_scan_patterns(self.num_patches_h, self.num_patches_w)
        for i, (p, ip) in enumerate(zip(pats, ipats)):
            self.register_buffer(f'scan_{i}', p)
            self.register_buffer(f'unscan_{i}', ip)
        self.n_scans = len(pats)
        self.pre_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
        self.post_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
        self._init_weights()
    def _init_weights(self):
        for name, m in self.named_modules():
            # Keep the inverse-softplus dt initialisation set in SelectiveSSM.__init__;
            # re-initialising dt_proj here would silently overwrite it.
            if name.endswith('dt_proj'):
                continue
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Zero-init the output projection so the model predicts a zero field at start.
        nn.init.zeros_(self.final_proj.weight)
        nn.init.zeros_(self.final_proj.bias)

    def patchify(self, x):
        B, C, H, W = x.shape
        p = self.patch_size
        return (x.unfold(2, p, p).unfold(3, p, p).contiguous()
                .view(B, C, self.num_patches_h, self.num_patches_w, p * p)
                .permute(0, 2, 3, 1, 4).contiguous()
                .view(B, self.num_patches, self.patch_dim))

    def unpatchify(self, x):
        B = x.shape[0]
        p = self.patch_size
        return (x.view(B, self.num_patches_h, self.num_patches_w, self.in_channels, p, p)
                .permute(0, 3, 1, 4, 2, 5).contiguous()
                .view(B, self.in_channels, self.num_patches_h * p, self.num_patches_w * p))

    def forward(self, x, t, class_label=None):
        B = x.shape[0]
        tok = self.patch_embed(self.patchify(x)) + self.pos_embed
        # Depthwise conv over the token grid before the SSM stack.
        h = tok.view(B, self.num_patches_h, self.num_patches_w, self.d_model).permute(0, 3, 1, 2)
        tok = self.pre_conv(h).permute(0, 2, 3, 1).contiguous().view(B, self.num_patches, self.d_model)
        te = self.time_embed(t)
        if self.class_embed is not None and class_label is not None:
            te = te + self.class_embed(class_label)
        sk = []
        for i, (blk, aln) in enumerate(zip(self.blocks, self.adaln)):
            tok = aln(tok, te)
            si = i % self.n_scans
            if i < self.depth // 2:
                sk.append(tok)
            tok = blk(tok, getattr(self, f'scan_{si}'), getattr(self, f'unscan_{si}'))
            if i >= self.depth // 2:
                # U-Net style long skip from the mirrored early block.
                j = self.depth - 1 - i
                if j < len(sk):
                    tok = self.skips[j](torch.cat([tok, sk[j]], -1))
        h = tok.view(B, self.num_patches_h, self.num_patches_w, self.d_model).permute(0, 3, 1, 2)
        tok = self.post_conv(h).permute(0, 2, 3, 1).contiguous().view(B, self.num_patches, self.d_model)
        return self.unpatchify(self.final_proj(self.final_norm(tok)))

    def count_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# ============================================================
# AUTO PATCH SIZE: keeps L ≤ 256 tokens for ALL image sizes
# ============================================================
def _auto_patch(img_size, max_tokens=256):
    """Pick the smallest patch_size that keeps L = (img_size // patch_size)**2 ≤ max_tokens."""
    for ps in [4, 8, 16, 32, 64]:
        L = (img_size // ps) ** 2
        if L <= max_tokens:
            return ps
    return img_size // int(max_tokens ** 0.5)


def liquidflow_tiny(img_size=128, num_classes=0):
    """~4-5M params. L ≤ 256 for any image size."""
    ps = _auto_patch(img_size)
    return LiquidFlowNet(img_size=img_size, patch_size=ps, d_model=192, depth=6,
                         d_state=8, expand=2, num_classes=num_classes)


def liquidflow_small(img_size=128, num_classes=0):
    """~10M params. L ≤ 256 for any image size."""
    ps = _auto_patch(img_size)
    return LiquidFlowNet(img_size=img_size, patch_size=ps, d_model=256, depth=8,
                         d_state=8, expand=2, num_classes=num_classes)


def liquidflow_base(img_size=256, num_classes=0):
    """~24M params. L ≤ 256 for any image size."""
    ps = _auto_patch(img_size)
    return LiquidFlowNet(img_size=img_size, patch_size=ps, d_model=384, depth=10,
                         d_state=8, expand=2, num_classes=num_classes)


def liquidflow_512(img_size=512, num_classes=0):
    """~24M params. L ≤ 256 for any image size."""
    ps = _auto_patch(img_size)
    return LiquidFlowNet(img_size=img_size, patch_size=ps, d_model=384, depth=10,
                         d_state=8, expand=2, num_classes=num_classes)
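

# ----------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the library API):
# it checks the _auto_patch guarantee from the module docstring and runs
# one forward pass of the tiny config. The batch size of 2 and the random
# inputs below are arbitrary choices for the demo.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    for size in (128, 256, 512):
        ps = _auto_patch(size)
        L = (size // ps) ** 2
        assert L <= 256, f"L={L} > 256 for img_size={size}"
        print(f"img_size={size:4d} -> patch_size={ps:3d}, L={L}")

    model = liquidflow_tiny(img_size=128)
    print(f"liquidflow_tiny params: {model.count_params() / 1e6:.2f}M")

    x = torch.randn(2, 3, 128, 128)   # batch of 2 RGB images
    t = torch.rand(2)                 # flow timesteps in [0, 1)
    with torch.no_grad():
        out = model(x, t)
    print("output shape:", tuple(out.shape))  # expected (2, 3, 128, 128)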