# LiquidFlow — liquidflow/model.py
# Author: krystv — commit 1c47b5f (verified)
# v0.5 CRITICAL: Fix patch_size to scale with image size (L≤256 always),
# reduce d_state for speed, add config table
"""
LiquidFlow v0.5 — Fixed: patch_size scales with image size → L≤256 always
CRITICAL FIX: patch_size now auto-scales so sequence length L stays ≤256.
Before: 256px with patch=8 → L=1024 → 21x slower than needed → stuck
After: 256px with patch=16 → L=256 → same speed as 128px
Config table (all have L=256 tokens):
tiny: 128px patch=8, d=192 depth=6 d_state=8 → ~4M params
small: 128px patch=8, d=256 depth=8 d_state=8 → ~10M params
base: 256px patch=16, d=384 depth=10 d_state=8 → ~24M params
512: 512px patch=32, d=384 depth=10 d_state=8 → ~24M params
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
try:
from mambapy.pscan import pscan as _pscan
HAS_PSCAN = True
except ImportError:
HAS_PSCAN = False
def parallel_scan(A, X):
    """Associative scan h[t] = A[t] * h[t-1] + X[t] along the length axis.

    Uses the fused mambapy `pscan` kernel when it imported successfully;
    otherwise falls back to a sequential Python loop (slow but correct).

    A, X: tensors of shape (B, L, ED, N). Returns (B, L, ED, N) of states.
    """
    if HAS_PSCAN:
        # pscan mutates its second argument, hence the clone.
        return _pscan(A, X.clone())
    batch, seq_len, d_inner, n_state = A.shape
    state = torch.zeros(batch, d_inner, n_state, device=A.device, dtype=A.dtype)
    states = []
    for t in range(seq_len):
        state = A[:, t] * state + X[:, t]
        states.append(state)
    return torch.stack(states, dim=1)
class LiquidCfCCell(nn.Module):
    """Closed-form continuous-time (CfC) cell, applied position-wise.

    A tanh backbone produces a candidate h; a linear head then emits a
    time-constant logit f_tau and an alternative branch f_x. The output
    blends the two: sigmoid(-f_tau) * h + (1 - sigmoid(-f_tau)) * f_x.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.backbone = nn.Linear(input_dim, hidden_dim)
        self.gate_proj = nn.Linear(hidden_dim, hidden_dim * 2)
        self.act = nn.Tanh()

    def forward(self, x):
        hidden = self.act(self.backbone(x))
        tau_logit, branch = self.gate_proj(hidden).chunk(2, dim=-1)
        mix = torch.sigmoid(-tau_logit)
        return mix * hidden + (1.0 - mix) * branch
class SelectiveSSM(nn.Module):
    """Mamba-style selective state-space layer (single scan direction).

    Input/output: (B, L, d_model). Internally expands to
    d_inner = expand * d_model, applies a depthwise causal conv, then a
    data-dependent (selective) SSM recurrence with a diagonal state matrix
    of size d_state per inner channel, and a SiLU-gated output projection.
    """
    def __init__(self, d_model, d_state=8, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(d_model * expand)
        # Joint projection producing the SSM stream and the gate stream z.
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Depthwise causal conv; the (d_conv-1) right overhang is trimmed in forward.
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                padding=d_conv-1, groups=self.d_inner, bias=True)
        # S4D-real style init: state n gets A_n = -n; stored as log for positivity.
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).expand(self.d_inner, -1).clone())
        # Per-channel skip coefficient on the conv features.
        self.D = nn.Parameter(torch.ones(self.d_inner))
        # Per-token projections: B (d_state), C (d_state), and a dt scalar (1).
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        with torch.no_grad():
            # Sample dt log-uniformly in [1e-3, 1e-1]; the bias stores its
            # inverse softplus (y + log(-expm1(-y))) so that
            # softplus(dt_proj(...)) starts near dt_init.
            dt_init = torch.exp(torch.rand(self.d_inner) * (math.log(0.1) - math.log(0.001)) + math.log(0.001))
            self.dt_proj.bias.copy_(dt_init + torch.log(-torch.expm1(-dt_init)))
    def forward(self, x):
        """Run the selective scan. x: (B, L, d_model) -> (B, L, d_model)."""
        B, L, _ = x.shape
        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)
        # Causal depthwise conv over time; [:, :, :L] drops the right overhang.
        x_conv = F.silu(self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2))
        x_ssm = self.x_proj(x_conv)
        B_sel = x_ssm[:, :, :self.d_state]                 # data-dependent input matrix
        C_sel = x_ssm[:, :, self.d_state:2*self.d_state]   # data-dependent output matrix
        dt = F.softplus(self.dt_proj(x_ssm[:, :, -1:]))    # positive step size per token/channel
        A = -torch.exp(self.A_log)                         # diagonal, strictly negative
        # Zero-order-hold discretization: A_bar = exp(dt*A); input term dt*B*u.
        # Shapes: A_bar, BX are (B, L, d_inner, d_state).
        A_bar = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
        BX = dt.unsqueeze(-1) * B_sel.unsqueeze(2) * x_conv.unsqueeze(-1)
        # Pad L up to the next power of two for the parallel scan kernel,
        # using the scan identity (A=1, X=0), then crop back afterwards.
        orig_L = L
        next_pow2 = 1 << (L - 1).bit_length()
        if next_pow2 != L:
            pad = next_pow2 - L
            A_bar = F.pad(A_bar, (0,0, 0,0, 0,pad), value=1.0)
            BX = F.pad(BX, (0,0, 0,0, 0,pad), value=0.0)
        h_all = parallel_scan(A_bar, BX)[:, :orig_L]
        y = (h_all * C_sel.unsqueeze(2)).sum(-1)           # contract over the state dim
        y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0)  # skip path
        return self.out_proj(y * F.silu(z))                # gated output projection
def create_scan_patterns(H, W):
    """Build four 1-D traversal orders over an H*W grid plus their inverses.

    Orders: row-major, reversed row-major, column-major, and boustrophedon
    (snake) rows. Returns (patterns, inverses) such that for any sequence x,
    x[patterns[k]][inverses[k]] restores the original order.
    """
    total = H * W
    flat = torch.arange(total)
    grid = flat.view(H, W)
    snake = torch.cat([row.flip(0) if r % 2 else row for r, row in enumerate(grid)])
    patterns = [
        flat.clone(),                 # row-major
        flat.flip(0),                 # reversed row-major
        grid.t().contiguous().view(-1),  # column-major
        snake,                        # alternate row direction
    ]
    inverses = []
    for perm in patterns:
        inverse = torch.zeros_like(perm)
        inverse[perm] = torch.arange(total)
        inverses.append(inverse)
    return patterns, inverses
class LiquidSSMBlock(nn.Module):
    """Residual block blending a selective-SSM branch with a liquid CfC branch.

    The SSM branch runs over a permuted (scan-ordered) token sequence while
    the liquid branch sees the original order; a learned sigmoid weight mixes
    the two, followed by a pre-norm 4x MLP. Both branches are gradient-
    checkpointed during training to save activation memory.
    """

    def __init__(self, d_model, d_state=8, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state, d_conv, expand)
        self.norm2 = nn.LayerNorm(d_model)
        self.liquid = LiquidCfCCell(d_model, d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout),
        )
        # Learned branch-mixing logit; sigmoid(0.5) ~ 0.62 toward the SSM at init.
        self.mix_alpha = nn.Parameter(torch.tensor(0.5))

    def _ssm_fwd(self, x):
        return self.ssm(self.norm1(x))

    def _liq_fwd(self, x):
        return self.liquid(self.norm2(x))

    def forward(self, x, scan_idx=None, unscan_idx=None):
        scanned = x if scan_idx is None else x[:, scan_idx]
        if self.training and x.requires_grad:
            # Checkpoint both branches: recompute in backward instead of storing.
            ssm_out = checkpoint(self._ssm_fwd, scanned, use_reentrant=False)
            liq_out = checkpoint(self._liq_fwd, x, use_reentrant=False)
        else:
            ssm_out = self._ssm_fwd(scanned)
            liq_out = self._liq_fwd(x)
        if unscan_idx is not None:
            ssm_out = ssm_out[:, unscan_idx]  # restore original token order
        blend = torch.sigmoid(self.mix_alpha)
        x = x + blend * ssm_out + (1 - blend) * liq_out
        return x + self.ff(self.norm3(x))
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal timestep embedding (transformer-style).

    Maps scalars t of shape (...,) to embeddings of shape (..., dim): the
    first half holds sin components and the second half cos components over
    log-spaced frequencies from 1 down to 1/10000.

    Fixes vs. the previous version: guards the ``dim // 2 == 1`` case (which
    divided by zero) and zero-pads odd ``dim`` so the output width always
    equals ``dim`` (previously odd dims silently returned dim-1 channels).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # max(..., 1) avoids ZeroDivisionError when half == 1 (dim of 2 or 3).
        step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -step)
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        if emb.shape[-1] < self.dim:
            # Odd dim: pad one trailing zero so the width matches self.dim.
            emb = F.pad(emb, (0, self.dim - emb.shape[-1]))
        return emb
class AdaptiveLayerNorm(nn.Module):
    """LayerNorm whose scale and shift are predicted from a conditioning vector.

    x: (B, L, d); cond: (B, c). The conditioning vector is passed through
    SiLU + Linear to produce per-channel (scale, shift), applied after an
    affine-free LayerNorm as norm(x) * (1 + scale) + shift.
    """

    def __init__(self, d, c):
        super().__init__()
        self.norm = nn.LayerNorm(d, elementwise_affine=False)
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(c, d * 2))

    def forward(self, x, cond):
        scale, shift = self.proj(cond).chunk(2, dim=-1)
        normed = self.norm(x)
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class LiquidFlowNet(nn.Module):
    """Patch-token flow/diffusion backbone built from Liquid-SSM blocks.

    Pipeline: patchify -> linear embed + learned positional embedding ->
    depthwise 3x3 pre-conv on the token grid -> `depth` LiquidSSMBlocks
    (each preceded by AdaLN conditioning on time/class, rotating through the
    4 precomputed scan orders, with U-Net-style long skips between the two
    halves of the stack) -> depthwise post-conv -> final norm + projection ->
    unpatchify back to the input image shape.
    """
    def __init__(self, img_size=128, patch_size=8, in_channels=3, d_model=256,
                 depth=8, d_state=8, d_conv=4, expand=2, dropout=0.0, num_classes=0):
        super().__init__()
        self.img_size = img_size; self.patch_size = patch_size
        self.in_channels = in_channels; self.d_model = d_model
        self.depth = depth; self.num_classes = num_classes
        # Token grid dimensions; patchify/unpatchify assume patch_size divides img_size.
        self.num_patches_h = img_size // patch_size
        self.num_patches_w = img_size // patch_size
        self.num_patches = self.num_patches_h * self.num_patches_w
        self.patch_dim = in_channels * patch_size * patch_size
        self.patch_embed = nn.Sequential(nn.Linear(self.patch_dim, d_model), nn.LayerNorm(d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, d_model) * 0.02)
        # Sinusoidal time embedding expanded through a 2-layer MLP.
        self.time_embed = nn.Sequential(SinusoidalPosEmb(d_model), nn.Linear(d_model, d_model*4), nn.GELU(), nn.Linear(d_model*4, d_model))
        self.class_embed = nn.Embedding(num_classes, d_model) if num_classes > 0 else None
        self.blocks = nn.ModuleList([LiquidSSMBlock(d_model, d_state, d_conv, expand, dropout) for _ in range(depth)])
        self.adaln = nn.ModuleList([AdaptiveLayerNorm(d_model, d_model) for _ in range(depth)])
        # One fusion linear per long skip (first half of blocks -> second half).
        self.skips = nn.ModuleList([nn.Linear(d_model*2, d_model) for _ in range(depth//2)])
        self.final_norm = nn.LayerNorm(d_model)
        self.final_proj = nn.Linear(d_model, self.patch_dim)
        # Fixed scan orders (row, reversed, column, snake) and their inverses,
        # registered as buffers so they follow the module across devices.
        pats, ipats = create_scan_patterns(self.num_patches_h, self.num_patches_w)
        for i,(p,ip) in enumerate(zip(pats, ipats)):
            self.register_buffer(f'scan_{i}', p); self.register_buffer(f'unscan_{i}', ip)
        self.n_scans = len(pats)
        # Depthwise 3x3 convs add local spatial mixing before/after the stack.
        self.pre_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
        self.post_conv = nn.Conv2d(d_model, d_model, 3, padding=1, groups=d_model)
        self._init_weights()
    def _init_weights(self):
        """Xavier-init all linear/conv weights, zero all biases, then zero the
        final projection so the model's initial output is exactly zero."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
        nn.init.zeros_(self.final_proj.weight); nn.init.zeros_(self.final_proj.bias)
    def patchify(self, x):
        """(B, C, H, W) -> (B, num_patches, patch_dim) non-overlapping patches."""
        B,C,H,W = x.shape; p = self.patch_size
        return x.unfold(2,p,p).unfold(3,p,p).contiguous().view(B,C,self.num_patches_h,self.num_patches_w,p*p).permute(0,2,3,1,4).contiguous().view(B,self.num_patches,self.patch_dim)
    def unpatchify(self, x):
        """(B, num_patches, patch_dim) -> (B, C, H, W); exact inverse of patchify."""
        B=x.shape[0]; p=self.patch_size
        return x.view(B,self.num_patches_h,self.num_patches_w,self.in_channels,p,p).permute(0,3,1,4,2,5).contiguous().view(B,self.in_channels,self.num_patches_h*p,self.num_patches_w*p)
    def forward(self, x, t, class_label=None):
        """Predict a same-shaped output for images x at timesteps t.

        x: (B, C, img_size, img_size); t: (B,) timesteps;
        class_label: optional (B,) int64 ids (used only when num_classes > 0).
        Returns a tensor shaped like x.
        """
        B = x.shape[0]
        tok = self.patch_embed(self.patchify(x)) + self.pos_embed
        # Local mixing on the 2-D token grid, then back to a token sequence.
        h = tok.view(B,self.num_patches_h,self.num_patches_w,self.d_model).permute(0,3,1,2)
        tok = self.pre_conv(h).permute(0,2,3,1).contiguous().view(B,self.num_patches,self.d_model)
        te = self.time_embed(t)
        if self.class_embed is not None and class_label is not None: te = te + self.class_embed(class_label)
        sk = []  # activations stashed from the first half, consumed by the second
        for i,(blk,aln) in enumerate(zip(self.blocks, self.adaln)):
            tok = aln(tok, te); si = i % self.n_scans  # rotate scan direction per block
            if i < self.depth//2: sk.append(tok)
            tok = blk(tok, getattr(self,f'scan_{si}'), getattr(self,f'unscan_{si}'))
            if i >= self.depth//2:
                # U-Net pairing: block i fuses with the stash from block depth-1-i.
                j = self.depth-1-i
                if j < len(sk): tok = self.skips[j](torch.cat([tok, sk[j]], -1))
        h = tok.view(B,self.num_patches_h,self.num_patches_w,self.d_model).permute(0,3,1,2)
        tok = self.post_conv(h).permute(0,2,3,1).contiguous().view(B,self.num_patches,self.d_model)
        return self.unpatchify(self.final_proj(self.final_norm(tok)))
    def count_params(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ============================================================
# AUTO PATCH SIZE: keeps L ≤ 256 tokens for ALL image sizes
# ============================================================
def _auto_patch(img_size, max_tokens=256):
"""Pick smallest patch_size that keeps L ≤ max_tokens."""
for ps in [4, 8, 16, 32, 64]:
L = (img_size // ps) ** 2
if L <= max_tokens:
return ps
return img_size // int(max_tokens ** 0.5)
def liquidflow_tiny(img_size=128, num_classes=0):
    """Tiny config (~4-5M params): d_model=192, depth=6; L <= 256 tokens."""
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=192,
        depth=6,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )
def liquidflow_small(img_size=128, num_classes=0):
    """Small config (~10M params): d_model=256, depth=8; L <= 256 tokens."""
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=256,
        depth=8,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )
def liquidflow_base(img_size=256, num_classes=0):
    """Base config (~24M params): d_model=384, depth=10; L <= 256 tokens."""
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=384,
        depth=10,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )
def liquidflow_512(img_size=512, num_classes=0):
    """512px config (~24M params, same trunk as base); L <= 256 tokens."""
    return LiquidFlowNet(
        img_size=img_size,
        patch_size=_auto_patch(img_size),
        d_model=384,
        depth=10,
        d_state=8,
        expand=2,
        num_classes=num_classes,
    )