| """ |
| LiquidFlow v0.5 — Fixed: patch_size scales with image size → L≤256 always |
| |
| CRITICAL FIX: patch_size now auto-scales so sequence length L stays ≤256. |
| Before: 256px with patch=8 → L=1024 → 21x slower than needed → stuck |
| After: 256px with patch=16 → L=256 → same speed as 128px |
| |
| Config table (all have L=256 tokens): |
| tiny: 128px patch=8, d=192 depth=6 d_state=8 → ~4M params |
| small: 128px patch=8, d=256 depth=8 d_state=8 → ~10M params |
| base: 256px patch=16, d=384 depth=10 d_state=8 → ~24M params |
| 512: 512px patch=32, d=384 depth=10 d_state=8 → ~24M params |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint |
|
|
try:
    # Optional accelerated scan from the third-party ``mambapy`` package;
    # parallel_scan falls back to a sequential loop when it is missing.
    from mambapy.pscan import pscan as _pscan
    HAS_PSCAN = True
except ImportError:
    HAS_PSCAN = False


def parallel_scan(A, X):
    """Evaluate the linear recurrence ``h[t] = A[:, t] * h[t-1] + X[:, t]`` with ``h[-1] = 0``.

    Args:
        A: per-step multiplicative factors, unpacked as ``(B, L, ED, N)``.
        X: per-step additive inputs; expected to have the same shape as ``A``.

    Returns:
        Every state ``h[0..L-1]`` stacked along dim 1, i.e. shape ``(B, L, ED, N)``.
        A length-0 sequence yields an empty tensor instead of raising.

    Uses ``mambapy``'s ``pscan`` when it could be imported; otherwise runs a
    Python loop over ``L`` (one step per iteration: correct, but slow for long
    sequences).
    """
    if A.shape[1] == 0:
        # Nothing to scan; torch.stack([]) in the fallback would raise here.
        return torch.zeros_like(X, dtype=torch.promote_types(A.dtype, X.dtype))
    if HAS_PSCAN:
        # Defensive copy: the caller's X is never modified, whatever pscan does internally.
        return _pscan(A, X.clone())
    B, L, ED, N = A.shape
    h = torch.zeros(B, ED, N, device=A.device, dtype=A.dtype)
    states = []
    for t in range(L):
        h = A[:, t] * h + X[:, t]
        states.append(h)
    return torch.stack(states, dim=1)
|
|
|
|
class LiquidCfCCell(nn.Module):
    """Gated per-token blend of a tanh feature and a linear candidate.

    Computes ``feat = tanh(backbone(x))``, splits ``gate_proj(feat)`` into
    ``(tau_logit, candidate)`` and returns
    ``sigmoid(-tau_logit) * feat + (1 - sigmoid(-tau_logit)) * candidate``.
    Named after closed-form continuous-time (CfC) cells; unlike that
    formulation, no explicit time argument enters the gate here.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: size of the output's last dimension.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.backbone = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.act = nn.Tanh()

    def forward(self, x):
        feat = self.act(self.backbone(x))
        tau_logit, candidate = torch.chunk(self.gate_proj(feat), 2, dim=-1)
        keep = torch.sigmoid(-tau_logit)  # weight on feat; 1 - keep goes to candidate
        return keep * feat + (1.0 - keep) * candidate
|
|
|
|
class SelectiveSSM(nn.Module):
    """Mamba-style selective state-space mixer over a token sequence.

    Maps ``x`` of shape (batch, length, d_model) to the same shape:
    ``in_proj`` splits into ``(u, z)``; ``u`` goes through a causal depthwise
    conv and SiLU; per-token ``B``, ``C`` and step size ``dt`` are projected
    from ``u``; the recurrence ``h[t] = exp(dt*A) * h[t-1] + dt * B * u[t]`` is
    evaluated with ``parallel_scan``; ``y = <C, h> + D * u`` is gated by
    ``SiLU(z)`` and projected back to ``d_model``.

    Args:
        d_model: model width.
        d_state: state size N per inner channel.
        d_conv: depthwise conv kernel size.
        expand: inner width multiplier (``d_inner = int(d_model * expand)``).
    """

    def __init__(self, d_model, d_state=8, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(d_model * expand)
        inner = self.d_inner

        # One projection yields both the SSM input u and the gate z.
        self.in_proj = nn.Linear(d_model, 2 * inner, bias=False)
        # padding=d_conv-1 plus trimming to the first L outputs in forward -> causal conv.
        self.conv1d = nn.Conv1d(
            inner, inner, d_conv, padding=d_conv - 1, groups=inner, bias=True
        )

        # A is stored as log(1..d_state) per channel; forward uses A = -exp(A_log) < 0.
        log_a = torch.arange(1, d_state + 1, dtype=torch.float32).log()
        self.A_log = nn.Parameter(log_a.repeat(inner, 1))
        self.D = nn.Parameter(torch.ones(inner))

        # x_proj emits [B (d_state) | C (d_state) | dt (1)] for every token.
        self.x_proj = nn.Linear(inner, 2 * d_state + 1, bias=False)
        self.dt_proj = nn.Linear(1, inner, bias=True)
        self.out_proj = nn.Linear(inner, d_model, bias=False)

        # dt bias = softplus^{-1}(dt0) with dt0 log-uniform in [1e-3, 1e-1],
        # so softplus(dt_proj(.)) starts out around dt0.
        lo, hi = math.log(0.001), math.log(0.1)
        with torch.no_grad():
            dt0 = torch.exp(lo + torch.rand(inner) * (hi - lo))
            self.dt_proj.bias.copy_(dt0 + torch.log(-torch.expm1(-dt0)))

    def forward(self, x):
        _, length, _ = x.shape
        u_raw, z = self.in_proj(x).chunk(2, dim=-1)

        # Causal depthwise conv along the sequence axis, then SiLU.
        conv = self.conv1d(u_raw.transpose(1, 2))[:, :, :length]
        u = F.silu(conv.transpose(1, 2))

        # Input-dependent SSM parameters.
        n = self.d_state
        proj = self.x_proj(u)
        B_sel = proj[..., :n]
        C_sel = proj[..., n:2 * n]
        dt = F.softplus(self.dt_proj(proj[..., -1:]))  # (batch, length, d_inner)
        A = -self.A_log.exp()                          # (d_inner, d_state)

        # Discretized transition and input term, both (batch, length, d_inner, d_state).
        dt4 = dt.unsqueeze(-1)
        A_bar = torch.exp(dt4 * A[None, None])
        BX = dt4 * B_sel.unsqueeze(2) * u.unsqueeze(-1)

        # Pad time to a power of two with identity steps (decay 1, input 0);
        # they come after every real step, so slicing them off is exact.
        extra = (1 << (length - 1).bit_length()) - length
        if extra:
            A_bar = F.pad(A_bar, (0, 0, 0, 0, 0, extra), value=1.0)
            BX = F.pad(BX, (0, 0, 0, 0, 0, extra), value=0.0)
        states = parallel_scan(A_bar, BX)[:, :length]

        y = (states * C_sel.unsqueeze(2)).sum(-1)
        y = y + u * self.D[None, None]
        return self.out_proj(y * F.silu(z))
|
|
|
|
def create_scan_patterns(H, W):
    """Build four token orderings over an H x W grid plus their inverse permutations.

    Orders, as flat indices into the row-major grid:
      0. row-major,
      1. reversed row-major,
      2. column-major,
      3. row-wise snake (odd rows reversed).

    Returns:
        ``(patterns, inverses)``: two lists of 1-D int64 tensors of length
        ``H * W`` such that ``patterns[k][inverses[k]]`` is ``0..H*W-1``.
    """
    n = H * W
    flat = torch.arange(n)
    grid = flat.view(H, W)
    snake_rows = [row.flip(0) if r % 2 else row for r, row in enumerate(grid)]
    patterns = [
        flat.clone(),
        flat.flip(0),
        grid.t().contiguous().view(-1),
        torch.cat(snake_rows),
    ]
    # The argsort of a permutation is its inverse permutation.
    inverses = [torch.argsort(p) for p in patterns]
    return patterns, inverses
|
|
|
|
class LiquidSSMBlock(nn.Module):
    """Residual block: blended SelectiveSSM / LiquidCfCCell branches, then an MLP.

    The SSM branch sees the tokens reordered by ``scan_idx`` and its output is
    put back in place with ``unscan_idx``. The liquid branch is applied to the
    tokens in their original order. The branches are mixed with weight
    ``sigmoid(mix_alpha)`` on the SSM side, added residually, and followed by a
    pre-norm 4x GELU MLP residual.
    """

    def __init__(self, d_model, d_state=8, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.norm2 = nn.LayerNorm(d_model)
        self.liquid = LiquidCfCCell(d_model, d_model)
        self.norm3 = nn.LayerNorm(d_model)
        hidden = d_model * 4
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        # Raw mixing logit; sigmoid(0.5) ~= 0.62 initial weight on the SSM branch.
        self.mix_alpha = nn.Parameter(torch.tensor(0.5))

    def _ssm_fwd(self, x):
        """Pre-norm SSM branch."""
        return self.ssm(self.norm1(x))

    def _liq_fwd(self, x):
        """Pre-norm liquid branch."""
        return self.liquid(self.norm2(x))

    def forward(self, x, scan_idx=None, unscan_idx=None):
        seq = x if scan_idx is None else x[:, scan_idx]
        if self.training and x.requires_grad:
            # Activation checkpointing: branch activations are recomputed in backward.
            ssm_out = checkpoint(self._ssm_fwd, seq, use_reentrant=False)
            liq_out = checkpoint(self._liq_fwd, x, use_reentrant=False)
        else:
            ssm_out = self._ssm_fwd(seq)
            liq_out = self._liq_fwd(x)
        if unscan_idx is not None:
            ssm_out = ssm_out[:, unscan_idx]
        weight = torch.sigmoid(self.mix_alpha)
        x = x + weight * ssm_out + (1 - weight) * liq_out
        return x + self.ff(self.norm3(x))
|
|
|
|
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar values ``t`` with output width exactly ``dim``."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Return ``[sin(t * w), cos(t * w)]`` along a new last axis.

        ``half = dim // 2`` frequencies ``w_k = 10000 ** (-k / (half - 1))``.
        For odd ``dim`` one trailing zero column is appended, so the output's
        last dimension always equals ``dim``. Previously it was ``dim - 1``,
        which broke any ``nn.Linear(dim, ...)`` that follows.
        """
        half = self.dim // 2
        # Guard half == 1 (dim 2 or 3): half - 1 == 0 would divide by zero.
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -scale)
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)  # computed once, used for sin and cos
        emb = torch.cat([args.sin(), args.cos()], -1)
        if self.dim % 2:
            emb = F.pad(emb, (0, 1))
        return emb
|
|
class AdaptiveLayerNorm(nn.Module):
    """LayerNorm without learned affine, modulated by a conditioning vector.

    ``out = LN(x) * (1 + scale) + shift`` where ``[scale, shift] = Linear(SiLU(cond))``.
    Presumably ``x`` is (batch, tokens, d) and ``cond`` is (batch, c). Scale and
    shift get a singleton dim 1 so they broadcast over the token axis.

    Args:
        d: feature width of ``x``.
        c: feature width of ``cond``.
    """

    def __init__(self, d, c):
        super().__init__()
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(c, 2 * d))

    def forward(self, x, cond):
        scale, shift = torch.chunk(self.proj(cond), 2, dim=-1)
        return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
|
|
|
|
class LiquidFlowNet(nn.Module):
    """Patch-token image network built from LiquidSSMBlocks, conditioned on ``t``.

    Pipeline: patchify -> linear patch embedding + learned positional embedding
    -> depthwise 3x3 conv on the patch grid -> ``depth`` x (AdaLN(time [+ class])
    -> LiquidSSMBlock), with U-Net-style long skips pairing layer ``i`` with
    layer ``depth - 1 - i`` -> depthwise 3x3 conv -> LayerNorm -> linear head
    -> unpatchify. Each block's SSM branch uses one of four scan orders,
    cycling with depth. The output has the input's (B, C, H, W) shape.

    Args:
        img_size: side length of the square input; must be a positive multiple
            of ``patch_size``.
        patch_size: side length of each square patch.
        in_channels: number of image channels.
        d_model: token width.
        depth: number of LiquidSSMBlocks.
        d_state, d_conv, expand, dropout: forwarded to each LiquidSSMBlock.
        num_classes: if > 0, a class embedding is added to the time embedding.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of ``patch_size``.
            Otherwise patchify would silently drop edge pixels and unpatchify
            would return a smaller image.
    """

    def __init__(self, img_size=128, patch_size=8, in_channels=3, d_model=256,
                 depth=8, d_state=8, d_conv=4, expand=2, dropout=0.0, num_classes=0):
        super().__init__()
        if patch_size <= 0 or img_size < patch_size or img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be a positive multiple of "
                f"patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.d_model = d_model
        self.depth = depth
        self.num_classes = num_classes
        self.num_patches_h = img_size // patch_size
        self.num_patches_w = img_size // patch_size
        self.num_patches = self.num_patches_h * self.num_patches_w
        self.patch_dim = in_channels * patch_size * patch_size

        # Token embedding and conditioning.
        self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, d_model), nn.LayerNorm(d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, d_model) * 0.02)
        self.time_embed = nn.Sequential(
            SinusoidalPosEmb(d_model),
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )
        self.class_embed = nn.Embedding(num_classes, d_model) if num_classes > 0 else None

        # Backbone: one AdaLN per block and depth // 2 skip-fusion projections.
        self.blocks = nn.ModuleList(
            [LiquidSSMBlock(d_model, d_state, d_conv, expand, dropout) for _ in range(depth)]
        )
        self.adaln = nn.ModuleList([AdaptiveLayerNorm(d_model, d_model) for _ in range(depth)])
        self.skips = nn.ModuleList([nn.Linear(d_model * 2, d_model) for _ in range(depth // 2)])
        self.final_norm = nn.LayerNorm(d_model)
        self.final_proj = nn.Linear(d_model, self.patch_dim)

        # Scan orders and their inverses, stored as buffers so they move with .to(device).
        pats, ipats = create_scan_patterns(self.num_patches_h, self.num_patches_w)
        for i, (p, ip) in enumerate(zip(pats, ipats)):
            self.register_buffer(f'scan_{i}', p)
            self.register_buffer(f'unscan_{i}', ip)
        self.n_scans = len(pats)

        # Depthwise 3x3 convs mix each token with its neighbours on the patch grid.
        self.pre_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
        self.post_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
        self._init_weights()

    def _init_weights(self):
        """Xavier-init Linear/Conv weights and zero their biases; zero-init the output head.

        SelectiveSSM sets ``dt_proj.bias`` to an inverse-softplus step-size
        initialisation. Those biases are left untouched here; zeroing them
        (as a blanket pass would) erases that initialisation.
        """
        keep_bias = {m.dt_proj for m in self.modules() if isinstance(m, SelectiveSSM)}
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None and m not in keep_bias:
                    nn.init.zeros_(m.bias)
        # With a zero head the freshly built network outputs all zeros.
        nn.init.zeros_(self.final_proj.weight)
        nn.init.zeros_(self.final_proj.bias)

    def patchify(self, x):
        """(B, C, H, W) image -> (B, num_patches, C * p * p) tokens in row-major patch order."""
        B, C, H, W = x.shape
        p = self.patch_size
        patches = x.unfold(2, p, p).unfold(3, p, p)
        patches = patches.contiguous().view(B, C, self.num_patches_h, self.num_patches_w, p * p)
        return patches.permute(0, 2, 3, 1, 4).contiguous().view(B, self.num_patches, self.patch_dim)

    def unpatchify(self, x):
        """Inverse of ``patchify``: (B, num_patches, C * p * p) -> (B, C, H, W)."""
        B = x.shape[0]
        p = self.patch_size
        grid = x.view(B, self.num_patches_h, self.num_patches_w, self.in_channels, p, p)
        return grid.permute(0, 3, 1, 4, 2, 5).contiguous().view(
            B, self.in_channels, self.num_patches_h * p, self.num_patches_w * p
        )

    def _grid_conv(self, conv, tok):
        """Apply a 2D conv to (B, num_patches, d_model) tokens laid out on the patch grid."""
        B = tok.shape[0]
        grid = tok.view(B, self.num_patches_h, self.num_patches_w, self.d_model).permute(0, 3, 1, 2)
        return conv(grid).permute(0, 2, 3, 1).contiguous().view(B, self.num_patches, self.d_model)

    def forward(self, x, t, class_label=None):
        tok = self.patch_embed(self.patchify(x)) + self.pos_embed
        tok = self._grid_conv(self.pre_conv, tok)
        te = self.time_embed(t)
        if self.class_embed is not None and class_label is not None:
            te = te + self.class_embed(class_label)
        half = self.depth // 2
        saved = []  # first-half activations (post-AdaLN, pre-block) for the long skips
        for i, (blk, aln) in enumerate(zip(self.blocks, self.adaln)):
            tok = aln(tok, te)
            if i < half:
                saved.append(tok)
            si = i % self.n_scans
            tok = blk(tok, getattr(self, f'scan_{si}'), getattr(self, f'unscan_{si}'))
            if i >= half:
                j = self.depth - 1 - i  # mirror layer in the first half
                if j < len(saved):
                    tok = self.skips[j](torch.cat([tok, saved[j]], -1))
        tok = self._grid_conv(self.post_conv, tok)
        return self.unpatchify(self.final_proj(self.final_norm(tok)))

    def count_params(self):
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
| |
| |
| |
|
|
| def _auto_patch(img_size, max_tokens=256): |
| """Pick smallest patch_size that keeps L ≤ max_tokens.""" |
| for ps in [4, 8, 16, 32, 64]: |
| L = (img_size // ps) ** 2 |
| if L <= max_tokens: |
| return ps |
| return img_size // int(max_tokens ** 0.5) |
|
|
def liquidflow_tiny(img_size=128, num_classes=0):
    """Tiny LiquidFlowNet (~4-5M params): d_model=192, depth=6, d_state=8, expand=2.

    The patch size comes from ``_auto_patch(img_size)``, which caps the token
    count (default budget: 256 tokens).
    """
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=192,
        depth=6,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )
|
|
def liquidflow_small(img_size=128, num_classes=0):
    """Small LiquidFlowNet (~10M params): d_model=256, depth=8, d_state=8, expand=2.

    The patch size comes from ``_auto_patch(img_size)``, which caps the token
    count (default budget: 256 tokens).
    """
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=256,
        depth=8,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )
|
|
def liquidflow_base(img_size=256, num_classes=0):
    """Base LiquidFlowNet (~24M params): d_model=384, depth=10, d_state=8, expand=2.

    The patch size comes from ``_auto_patch(img_size)``, which caps the token
    count (default budget: 256 tokens).
    """
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=384,
        depth=10,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )
|
|
def liquidflow_512(img_size=512, num_classes=0):
    """512px-default LiquidFlowNet (~24M params): d_model=384, depth=10, d_state=8, expand=2.

    Same widths as the base config; only the default ``img_size`` differs.
    The patch size comes from ``_auto_patch(img_size)``, which caps the token
    count (default budget: 256 tokens).
    """
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=384,
        depth=10,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )