""" PMA-VAE: Parallel Mobile Artistic Variational Autoencoder ========================================================= Attention-free, mobile-deployable VAE with: - Parallel 2D Mamba/SSM blocks (no sequential pixel loops) - Mobile depthwise-separable convolutions - Multi-scale latents: z_base (H/16), z_detail (H/8), z_style (global vector) - FiLM style conditioning throughout decoder - Designed for: image generation, super-resolution, artifact removal, style transfer Architecture: Image → PixelUnshuffle stem → MobileConv + Parallel 2D Mamba encoder → Multi-scale latent (base + detail + style) → Light parallel decoder with FiLM modulation → Reconstructed image Total params target: ~20-40M (encoder heavier, decoder light for mobile) """ import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange # ============================================================================== # Parallel Scan (Blelloch-style) — Pure PyTorch, no CUDA kernels # Based on: https://github.com/alxndrTL/mamba.py/blob/main/mambapy/pscan.py # ============================================================================== class PScan(torch.autograd.Function): """ Parallel prefix scan (Blelloch algorithm) in pure PyTorch. Computes: y[t] = A[t] * y[t-1] + X[t] for all t in parallel. 
""" @staticmethod def pscan_forward(A, X): B, D, L, N = A.size() # Pad to next power of 2 if needed orig_L = L if L & (L - 1) != 0: # not power of 2 next_pow2 = 1 << (L - 1).bit_length() pad = next_pow2 - L A = F.pad(A, (0, 0, 0, pad), value=1.0) X = F.pad(X, (0, 0, 0, pad), value=0.0) L = next_pow2 num_steps = int(math.log2(L)) # Store intermediate values for down-sweep Aa = A.clone() Xa = X.clone() # Up-sweep (reduce) for k in range(num_steps): step = 1 << (k + 1) half = step // 2 # Indices for even/odd pairs idx = torch.arange(half - 1, L, step, device=A.device) idx_prev = idx - half Xa[:, :, idx] = Aa[:, :, idx] * Xa[:, :, idx_prev] + Xa[:, :, idx] Aa[:, :, idx] = Aa[:, :, idx] * Aa[:, :, idx_prev] # Down-sweep for k in range(num_steps - 2, -1, -1): step = 1 << (k + 1) half = step // 2 idx = torch.arange(step - 1, L, step, device=A.device) if idx.numel() > 0 and (idx + half < L).any(): valid = idx + half valid = valid[valid < L] if valid.numel() > 0: src_idx = valid - half Xa[:, :, valid] = Aa[:, :, valid] * Xa[:, :, src_idx] + Xa[:, :, valid] return Xa[:, :, :orig_L] @staticmethod def forward(ctx, A_in, X_in): A = A_in.clone() X = X_in.clone() result = PScan.pscan_forward(A, X) ctx.save_for_backward(A_in, X_in, result) return result @staticmethod def backward(ctx, grad_output): A_in, X_in, result = ctx.saved_tensors # For backward: reversed scan # dA[t] = grad[t] * y[t-1], dX[t] = cumulative product of future A's * grad # Simplified: use autograd-friendly sequential for backward (still fast enough) B, D, L, N = A_in.size() grad_A = torch.zeros_like(A_in) grad_X = torch.zeros_like(X_in) # Sequential backward (simpler, correct) grad_h = torch.zeros(B, D, N, device=A_in.device, dtype=A_in.dtype) for t in range(L - 1, -1, -1): grad_h = grad_h + grad_output[:, :, t] grad_X[:, :, t] = grad_h if t > 0: # y[t-1] from forward y_prev = result[:, :, t - 1] grad_A[:, :, t] = (grad_h * y_prev).sum(-1, keepdim=True).expand_as(A_in[:, :, t]) grad_h = grad_h * A_in[:, :, t] 
else: grad_A[:, :, 0] = torch.zeros_like(A_in[:, :, 0]) return grad_A, grad_X pscan = PScan.apply # ============================================================================== # Selective State Space (S6) Block — The core Mamba mechanism # ============================================================================== class SelectiveSSM(nn.Module): """ Selective State Space Model (S6) from Mamba paper. Uses parallel scan for O(L) computation without sequential loops. For 2D images: we flatten H*W to sequence, process with SSM, reshape back. """ def __init__(self, d_model, d_state=16, d_conv=4, expand=2, use_parallel_scan=True): super().__init__() self.d_model = d_model self.d_state = d_state self.d_conv = d_conv self.expand = expand self.d_inner = int(expand * d_model) self.use_parallel_scan = use_parallel_scan # Input projection: x → (xz) where x goes through SSM, z is gate self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False) # 1D depthwise conv (local context before SSM) self.conv1d = nn.Conv1d( self.d_inner, self.d_inner, kernel_size=d_conv, bias=True, groups=self.d_inner, padding=d_conv - 1 ) # Input-dependent SSM parameters self.x_proj = nn.Linear(self.d_inner, self.d_state * 2 + 1, bias=False) self.dt_proj = nn.Linear(1, self.d_inner, bias=True) # A matrix (structured, log-parameterized) A = torch.arange(1, d_state + 1, dtype=torch.float32).unsqueeze(0).expand(self.d_inner, -1) self.A_log = nn.Parameter(torch.log(A)) # D skip connection self.D = nn.Parameter(torch.ones(self.d_inner)) # Output projection self.out_proj = nn.Linear(self.d_inner, d_model, bias=False) # Pre-norm self.norm = nn.RMSNorm(d_model) def ssm_parallel(self, x): """Parallel scan SSM — no sequential loops.""" B_size, L, D = x.shape A = -torch.exp(self.A_log.float()) # (d_inner, d_state) D_skip = self.D.float() # Compute input-dependent B, C, dt x_dbl = self.x_proj(x) # (B, L, d_state*2 + 1) dt, B_mat, C_mat = x_dbl.split([1, self.d_state, self.d_state], dim=-1) dt = 
F.softplus(self.dt_proj(dt)) # (B, L, d_inner) # Discretize: dA = exp(dt * A), dB = dt * B dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0)) # (B, L, D, N) dBx = dt.unsqueeze(-1) * B_mat.unsqueeze(2) * x.unsqueeze(-1) # (B, L, D, N) # Rearrange for parallel scan: (B, D, L, N) dA = dA.permute(0, 2, 1, 3).contiguous() dBx = dBx.permute(0, 2, 1, 3).contiguous() if self.use_parallel_scan: # Parallel prefix scan h = pscan(dA, dBx) # (B, D, L, N) else: # Sequential fallback h = torch.zeros_like(dBx) state = torch.zeros(B_size, self.d_inner, self.d_state, device=x.device, dtype=x.dtype) for t in range(L): state = dA[:, :, t] * state + dBx[:, :, t] h[:, :, t] = state # Output: y = C * h + D * x h = h.permute(0, 2, 1, 3) # (B, L, D, N) C_mat_exp = C_mat.unsqueeze(2) # (B, L, 1, N) y = (h * C_mat_exp).sum(-1) # (B, L, D) y = y + D_skip * x return y def forward(self, x): """x: (B, L, d_model)""" residual = x x = self.norm(x) # Input projection + gate split xz = self.in_proj(x) # (B, L, 2*d_inner) x_ssm, z = xz.chunk(2, dim=-1) # 1D conv for local context x_ssm = rearrange(x_ssm, 'b l d -> b d l') x_ssm = self.conv1d(x_ssm)[:, :, :residual.shape[1]] x_ssm = rearrange(x_ssm, 'b d l -> b l d') x_ssm = F.silu(x_ssm) # SSM y = self.ssm_parallel(x_ssm) # Gated output y = y * F.silu(z) return self.out_proj(y) + residual # ============================================================================== # 2D Cross-Scan for Vision — VMamba style # ============================================================================== def cross_scan_2d(x): """ Convert 2D feature map to 4 directional 1D sequences. 
x: (B, H, W, C) Returns: list of 4 tensors, each (B, H*W, C) """ B, H, W, C = x.shape # Direction 1: raster (top-left → bottom-right) d1 = rearrange(x, 'b h w c -> b (h w) c') # Direction 2: reverse raster d2 = rearrange(x.flip([1, 2]), 'b h w c -> b (h w) c') # Direction 3: column-first d3 = rearrange(x.permute(0, 2, 1, 3), 'b w h c -> b (w h) c') # Direction 4: reverse column-first d4 = rearrange(x.permute(0, 2, 1, 3).flip([1, 2]), 'b w h c -> b (w h) c') return [d1, d2, d3, d4] def cross_merge_2d(ys, H, W): """ Merge 4 directional sequences back to 2D. ys: list of 4 tensors (B, H*W, C) Returns: (B, H, W, C) """ d1 = rearrange(ys[0], 'b (h w) c -> b h w c', h=H, w=W) d2 = rearrange(ys[1], 'b (h w) c -> b h w c', h=H, w=W).flip([1, 2]) d3 = rearrange(ys[2], 'b (h w) c -> b w h c', h=H, w=W).permute(0, 2, 1, 3) d4 = rearrange(ys[3], 'b (h w) c -> b w h c', h=H, w=W).permute(0, 2, 1, 3).flip([1, 2]) return (d1 + d2 + d3 + d4) * 0.25 class Mamba2DBlock(nn.Module): """ 2D Mamba block using cross-scan pattern. Processes feature maps with 4 directional SSM scans in parallel. No attention — pure SSM + local conv. 
""" def __init__(self, channels, d_state=16, expand=2, use_parallel_scan=True): super().__init__() self.channels = channels # One SSM shared across all 4 directions (weight sharing saves params) self.ssm = SelectiveSSM( d_model=channels, d_state=d_state, d_conv=4, expand=expand, use_parallel_scan=use_parallel_scan ) self.mix_proj = nn.Linear(channels, channels) self.norm = nn.RMSNorm(channels) def forward(self, x): """x: (B, C, H, W)""" B, C, H, W = x.shape residual = x # Convert to (B, H, W, C) x_hwc = x.permute(0, 2, 3, 1) # Cross-scan: 4 directional 1D sequences seqs = cross_scan_2d(x_hwc) # Process all 4 directions with shared SSM outputs = [self.ssm(s) for s in seqs] # Cross-merge back to 2D merged = cross_merge_2d(outputs, H, W) # (B, H, W, C) merged = self.norm(merged) merged = self.mix_proj(merged) # Back to (B, C, H, W) return merged.permute(0, 3, 1, 2) + residual # ============================================================================== # Mobile Convolution Blocks # ============================================================================== class SqueezeExcitation(nn.Module): """Channel attention via squeeze-excitation.""" def __init__(self, channels, reduction=4): super().__init__() reduced = max(8, channels // reduction) self.pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channels, reduced), nn.SiLU(inplace=True), nn.Linear(reduced, channels), nn.Sigmoid() ) def forward(self, x): B, C, H, W = x.shape w = self.pool(x).view(B, C) w = self.fc(w).view(B, C, 1, 1) return x * w class FiLM(nn.Module): """Feature-wise Linear Modulation for style conditioning.""" def __init__(self, cond_dim, channels): super().__init__() self.proj = nn.Linear(cond_dim, channels * 2) def forward(self, x, cond): """x: (B,C,H,W), cond: (B, cond_dim)""" params = self.proj(cond) # (B, 2*C) gamma, beta = params.chunk(2, dim=-1) # each (B, C) gamma = gamma.view(-1, x.shape[1], 1, 1) beta = beta.view(-1, x.shape[1], 1, 1) return x * (1 + gamma) + beta class 
MobileConvBlock(nn.Module): """ Mobile-friendly inverted residual block with: - Depthwise separable convolution - Squeeze-Excitation - Optional FiLM style conditioning - Reparameterizable for mobile deployment """ def __init__(self, in_ch, out_ch, expand_ratio=4, stride=1, use_se=True, cond_dim=None): super().__init__() mid_ch = in_ch * expand_ratio self.use_residual = (stride == 1 and in_ch == out_ch) layers = [] # Expand if expand_ratio != 1: layers.extend([ nn.Conv2d(in_ch, mid_ch, 1, bias=False), nn.BatchNorm2d(mid_ch), nn.SiLU(inplace=True), ]) # Depthwise layers.extend([ nn.Conv2d(mid_ch, mid_ch, 3, stride=stride, padding=1, groups=mid_ch, bias=False), nn.BatchNorm2d(mid_ch), nn.SiLU(inplace=True), ]) self.conv = nn.Sequential(*layers) # Squeeze-Excitation self.se = SqueezeExcitation(mid_ch) if use_se else nn.Identity() # Project self.project = nn.Sequential( nn.Conv2d(mid_ch, out_ch, 1, bias=False), nn.BatchNorm2d(out_ch), ) # FiLM conditioning self.film = FiLM(cond_dim, out_ch) if cond_dim else None # Skip connection if not self.use_residual and stride == 1: self.skip = nn.Conv2d(in_ch, out_ch, 1, bias=False) elif not self.use_residual: self.skip = nn.Sequential( nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False), nn.BatchNorm2d(out_ch), ) else: self.skip = nn.Identity() def forward(self, x, cond=None): out = self.conv(x) out = self.se(out) out = self.project(out) if self.film is not None and cond is not None: out = self.film(out, cond) if self.use_residual: return out + x else: return out + self.skip(x) if hasattr(self, 'skip') else out class GatedConvBlock(nn.Module): """Gated convolution block — alternative to attention for global mixing.""" def __init__(self, channels): super().__init__() self.norm = nn.GroupNorm(min(32, channels), channels) self.proj = nn.Conv2d(channels, channels * 2, 1) self.dw = nn.Conv2d(channels, channels, 5, padding=2, groups=channels) self.out = nn.Conv2d(channels, channels, 1) def forward(self, x): residual = x x = 
self.norm(x) gate, val = self.proj(x).chunk(2, dim=1) val = self.dw(val) x = val * F.silu(gate) return self.out(x) + residual # ============================================================================== # PMA-VAE Encoder # ============================================================================== class PMAEncoder(nn.Module): """ Encoder with progressive downsampling: H → H/2 → H/4 → H/8 → H/16 Outputs multi-scale latents: - z_base: H/16 x W/16 x latent_base_dim - z_detail: H/8 x W/8 x latent_detail_dim - z_style: 1 x 1 x latent_style_dim (global) """ def __init__(self, in_channels=3, stage_channels=(64, 128, 192, 256), stage_blocks=(2, 2, 4, 4), latent_base_dim=32, latent_detail_dim=8, latent_style_dim=128, d_state=16, use_parallel_scan=True): super().__init__() self.latent_base_dim = latent_base_dim self.latent_detail_dim = latent_detail_dim self.latent_style_dim = latent_style_dim # Stem: PixelUnshuffle (lossless 2x downsample) + Conv self.stem = nn.Sequential( nn.PixelUnshuffle(2), # (B, C*4, H/2, W/2) nn.Conv2d(in_channels * 4, stage_channels[0], 3, padding=1, bias=False), nn.BatchNorm2d(stage_channels[0]), nn.SiLU(inplace=True), ) # Stage 1: H/2 → H/4, MobileConv only self.stage1 = self._make_mobile_stage( stage_channels[0], stage_channels[1], stage_blocks[0], stride=2 ) # Stage 2: H/4 → H/8, MobileConv + some Mamba self.stage2 = self._make_hybrid_stage( stage_channels[1], stage_channels[2], stage_blocks[1], stride=2, d_state=d_state, mamba_ratio=0.5, use_parallel_scan=use_parallel_scan ) # Detail latent head (at H/8 resolution) self.detail_head_mu = nn.Conv2d(stage_channels[2], latent_detail_dim, 1) self.detail_head_logvar = nn.Conv2d(stage_channels[2], latent_detail_dim, 1) # Stage 3: H/8 → H/16, Mamba-heavy self.stage3 = self._make_hybrid_stage( stage_channels[2], stage_channels[3], stage_blocks[2], stride=2, d_state=d_state, mamba_ratio=0.75, use_parallel_scan=use_parallel_scan ) # One global mixing block at H/16 self.global_mix = 
GatedConvBlock(stage_channels[3]) # Base latent head (at H/16 resolution) self.base_head_mu = nn.Conv2d(stage_channels[3], latent_base_dim, 1) self.base_head_logvar = nn.Conv2d(stage_channels[3], latent_base_dim, 1) # Style latent head (global) self.style_pool = nn.AdaptiveAvgPool2d(1) self.style_head_mu = nn.Linear(stage_channels[3], latent_style_dim) self.style_head_logvar = nn.Linear(stage_channels[3], latent_style_dim) def _make_mobile_stage(self, in_ch, out_ch, num_blocks, stride=1): blocks = [MobileConvBlock(in_ch, out_ch, stride=stride)] for _ in range(num_blocks - 1): blocks.append(MobileConvBlock(out_ch, out_ch)) return nn.Sequential(*blocks) def _make_hybrid_stage(self, in_ch, out_ch, num_blocks, stride=1, d_state=16, mamba_ratio=0.5, use_parallel_scan=True): blocks = nn.ModuleList() # First block handles stride blocks.append(MobileConvBlock(in_ch, out_ch, stride=stride)) num_mamba = max(1, int((num_blocks - 1) * mamba_ratio)) num_mobile = (num_blocks - 1) - num_mamba for _ in range(num_mobile): blocks.append(MobileConvBlock(out_ch, out_ch)) for _ in range(num_mamba): blocks.append(Mamba2DBlock(out_ch, d_state=d_state, expand=2, use_parallel_scan=use_parallel_scan)) return blocks def forward(self, x): """ x: (B, 3, H, W) Returns: dict with mu/logvar for base, detail, style latents """ # Stem: H → H/2 x = self.stem(x) # Stage 1: H/2 → H/4 x = self.stage1(x) # Stage 2: H/4 → H/8 for block in self.stage2: x = block(x) # Detail latent at H/8 detail_mu = self.detail_head_mu(x) detail_logvar = self.detail_head_logvar(x) # Stage 3: H/8 → H/16 for block in self.stage3: x = block(x) # Global mixing x = self.global_mix(x) # Base latent at H/16 base_mu = self.base_head_mu(x) base_logvar = self.base_head_logvar(x) # Style latent (global) style_feat = self.style_pool(x).flatten(1) style_mu = self.style_head_mu(style_feat) style_logvar = self.style_head_logvar(style_feat) return { 'base_mu': base_mu, 'base_logvar': base_logvar, 'detail_mu': detail_mu, 'detail_logvar': 
detail_logvar, 'style_mu': style_mu, 'style_logvar': style_logvar, } # ============================================================================== # PMA-VAE Decoder # ============================================================================== class UpsampleBlock(nn.Module): """Efficient 2x upsample with pixel shuffle.""" def __init__(self, in_ch, out_ch): super().__init__() self.conv = nn.Conv2d(in_ch, out_ch * 4, 3, padding=1, bias=False) self.ps = nn.PixelShuffle(2) self.norm = nn.BatchNorm2d(out_ch) self.act = nn.SiLU(inplace=True) def forward(self, x): return self.act(self.norm(self.ps(self.conv(x)))) class PMADecoder(nn.Module): """ Lightweight decoder for mobile deployment. Takes multi-scale latents and reconstructs image: z_base (H/16) + z_style → decode → fuse z_detail (H/8) → upsample → image """ def __init__(self, out_channels=3, stage_channels=(256, 192, 128, 96, 64), latent_base_dim=32, latent_detail_dim=8, latent_style_dim=128, d_state=16, use_parallel_scan=True): super().__init__() # Initial projection from latent to feature space self.base_proj = nn.Sequential( nn.Conv2d(latent_base_dim, stage_channels[0], 3, padding=1, bias=False), nn.BatchNorm2d(stage_channels[0]), nn.SiLU(inplace=True), ) # Stage 1: H/16, Mamba blocks with FiLM style conditioning self.stage1_blocks = nn.ModuleList([ MobileConvBlock(stage_channels[0], stage_channels[0], cond_dim=latent_style_dim), Mamba2DBlock(stage_channels[0], d_state=d_state, use_parallel_scan=use_parallel_scan), ]) # Upsample H/16 → H/8 self.up1 = UpsampleBlock(stage_channels[0], stage_channels[1]) # Fuse detail latent at H/8 self.detail_fuse = nn.Sequential( nn.Conv2d(stage_channels[1] + latent_detail_dim, stage_channels[1], 1, bias=False), nn.BatchNorm2d(stage_channels[1]), nn.SiLU(inplace=True), ) # Stage 2: H/8, MobileConv with FiLM self.stage2_blocks = nn.ModuleList([ MobileConvBlock(stage_channels[1], stage_channels[1], cond_dim=latent_style_dim), MobileConvBlock(stage_channels[1], 
stage_channels[1], cond_dim=latent_style_dim), Mamba2DBlock(stage_channels[1], d_state=d_state, use_parallel_scan=use_parallel_scan), ]) # Upsample H/8 → H/4 self.up2 = UpsampleBlock(stage_channels[1], stage_channels[2]) # Stage 3: H/4 self.stage3_blocks = nn.ModuleList([ MobileConvBlock(stage_channels[2], stage_channels[2], cond_dim=latent_style_dim), MobileConvBlock(stage_channels[2], stage_channels[2]), ]) # Upsample H/4 → H/2 self.up3 = UpsampleBlock(stage_channels[2], stage_channels[3]) # Stage 4: H/2 self.stage4_blocks = nn.ModuleList([ MobileConvBlock(stage_channels[3], stage_channels[3]), MobileConvBlock(stage_channels[3], stage_channels[3]), ]) # Upsample H/2 → H (PixelShuffle) self.up4 = UpsampleBlock(stage_channels[3], stage_channels[4]) # Final output head self.head = nn.Sequential( nn.Conv2d(stage_channels[4], stage_channels[4], 3, padding=1), nn.SiLU(inplace=True), nn.Conv2d(stage_channels[4], out_channels, 3, padding=1), nn.Tanh(), # output [-1, 1] ) def forward(self, z_base, z_detail, z_style): """ z_base: (B, latent_base_dim, H/16, W/16) z_detail: (B, latent_detail_dim, H/8, W/8) z_style: (B, latent_style_dim) """ # Project base latent x = self.base_proj(z_base) # Stage 1: H/16 with style conditioning for block in self.stage1_blocks: if isinstance(block, MobileConvBlock): x = block(x, cond=z_style) else: x = block(x) # Upsample to H/8 x = self.up1(x) # Fuse detail latent x = self.detail_fuse(torch.cat([x, z_detail], dim=1)) # Stage 2: H/8 for block in self.stage2_blocks: if isinstance(block, MobileConvBlock): x = block(x, cond=z_style) else: x = block(x) # Upsample to H/4 x = self.up2(x) # Stage 3: H/4 for block in self.stage3_blocks: if isinstance(block, MobileConvBlock): x = block(x, cond=z_style) else: x = block(x) # Upsample to H/2 x = self.up3(x) # Stage 4: H/2 for block in self.stage4_blocks: x = block(x) # Upsample to H x = self.up4(x) # Output return self.head(x) # 
# ==============================================================================
# Full PMA-VAE Model
# ==============================================================================


class PMAVAE(nn.Module):
    """
    Parallel Mobile Artistic VAE — full model.

    Features:
      - Attention-free (Mamba SSM + mobile convolutions)
      - Multi-scale latent space (base + detail + style)
      - FiLM style conditioning in the decoder
      - Parallel scan training (no sequential pixel loops)
      - Mobile-deployable decoder (design target ~15-20M params; check a
        concrete config with ``count_parameters()``)

    Args:
        in_channels: Input image channels (3 for RGB); also the output channels.
        enc_channels: Channel widths of the encoder stem and its three stages.
        dec_channels: Channel widths of the decoder at H/16, H/8, H/4, H/2, H.
        enc_blocks: Blocks per encoder stage. The encoder reads the first three
            entries (stages ending at H/4, H/8 and H/16).
        latent_base_dim: Channels for the H/16 base latent.
        latent_detail_dim: Channels for the H/8 detail latent.
        latent_style_dim: Dimension of the global style vector.
        d_state: SSM state dimension.
        use_parallel_scan: Use the Blelloch parallel scan (True) or the
            sequential loop (False).
    """

    def __init__(self, in_channels=3, enc_channels=(64, 128, 192, 256),
                 dec_channels=(256, 192, 128, 96, 64), enc_blocks=(2, 2, 4, 4),
                 latent_base_dim=32, latent_detail_dim=8, latent_style_dim=128,
                 d_state=16, use_parallel_scan=True):
        super().__init__()

        self.encoder = PMAEncoder(
            in_channels=in_channels,
            stage_channels=enc_channels,
            stage_blocks=enc_blocks,
            latent_base_dim=latent_base_dim,
            latent_detail_dim=latent_detail_dim,
            latent_style_dim=latent_style_dim,
            d_state=d_state,
            use_parallel_scan=use_parallel_scan,
        )

        self.decoder = PMADecoder(
            out_channels=in_channels,
            stage_channels=dec_channels,
            latent_base_dim=latent_base_dim,
            latent_detail_dim=latent_detail_dim,
            latent_style_dim=latent_style_dim,
            d_state=d_state,
            use_parallel_scan=use_parallel_scan,
        )

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: ``z = mu + eps * std`` in training mode.

        In eval mode ``mu`` is returned unchanged. ``logvar`` is clamped to
        [-30, 20] before exponentiation so an extreme encoder output cannot
        overflow ``std`` to inf and turn the sample into inf/NaN. The
        posterior dict returned by ``forward`` still holds the raw values.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
        return mu + torch.randn_like(std) * std

    def encode(self, x):
        """Encode an image to the multi-scale posterior parameters (dict of mu/logvar)."""
        return self.encoder(x)

    def decode(self, z_base, z_detail, z_style):
        """Decode latents to an image."""
        return self.decoder(z_base, z_detail, z_style)

    def forward(self, x):
        """
        Full forward pass: encode → sample → decode.

        Returns: (recon, posteriors_dict)
        """
        posteriors = self.encode(x)

        # Sample from each latent distribution
        z_base = self.reparameterize(posteriors['base_mu'], posteriors['base_logvar'])
        z_detail = self.reparameterize(posteriors['detail_mu'], posteriors['detail_logvar'])
        z_style = self.reparameterize(posteriors['style_mu'], posteriors['style_logvar'])

        recon = self.decode(z_base, z_detail, z_style)
        return recon, posteriors

    def get_last_decoder_layer(self):
        """Weight of the decoder's final conv (``head[-2]``), for adaptive discriminator weight balancing."""
        return self.decoder.head[-2].weight

    @torch.no_grad()
    def encode_to_latent(self, x):
        """Encode to deterministic latents (posterior means, no sampling)."""
        posteriors = self.encode(x)
        return (posteriors['base_mu'], posteriors['detail_mu'], posteriors['style_mu'])

    @torch.no_grad()
    def decode_from_latent(self, z_base, z_detail, z_style):
        """Decode from latents without tracking gradients."""
        return self.decode(z_base, z_detail, z_style)

    def count_parameters(self):
        """Return encoder/decoder/total parameter counts, raw and in millions."""
        enc_params = sum(p.numel() for p in self.encoder.parameters())
        dec_params = sum(p.numel() for p in self.decoder.parameters())
        total = enc_params + dec_params
        return {
            'encoder': enc_params,
            'decoder': dec_params,
            'total': total,
            'encoder_M': enc_params / 1e6,
            'decoder_M': dec_params / 1e6,
            'total_M': total / 1e6,
        }


# ==============================================================================
# Model Configs
# ==============================================================================


def _build_preset(preset, overrides):
    """Build a PMAVAE from ``preset``; keys in ``overrides`` take precedence."""
    return PMAVAE(**{**preset, **overrides})


def pmavae_tiny(**kwargs):
    """Tiny config for testing. ~5M params. Keyword args override preset values."""
    return _build_preset(dict(
        enc_channels=(32, 64, 96, 128),
        dec_channels=(128, 96, 64, 48, 32),
        enc_blocks=(1, 1, 2, 2),
        latent_base_dim=16,
        latent_detail_dim=4,
        latent_style_dim=64,
        d_state=8,
    ), kwargs)


def pmavae_small(**kwargs):
    """Small config for Colab free tier. ~20M params. Keyword args override preset values."""
    return _build_preset(dict(
        enc_channels=(48, 96, 144, 192),
        dec_channels=(192, 144, 96, 72, 48),
        enc_blocks=(2, 2, 3, 3),
        latent_base_dim=24,
        latent_detail_dim=6,
        latent_style_dim=96,
        d_state=16,
    ), kwargs)


def pmavae_base(**kwargs):
    """Base config. ~40M params. Keyword args override preset values."""
    return _build_preset(dict(
        enc_channels=(64, 128, 192, 256),
        dec_channels=(256, 192, 128, 96, 64),
        enc_blocks=(2, 2, 4, 4),
        latent_base_dim=32,
        latent_detail_dim=8,
        latent_style_dim=128,
        d_state=16,
    ), kwargs)


if __name__ == '__main__':
    # Quick test
    device = 'cpu'
    model = pmavae_tiny(use_parallel_scan=False).to(device)
    x = torch.randn(2, 3, 256, 256, device=device)

    recon, posteriors = model(x)
    print(f"Input: {x.shape}")
    print(f"Recon: {recon.shape}")
    for k, v in posteriors.items():
        print(f"  {k}: {v.shape}")

    params = model.count_parameters()
    print(f"\nParams: {params['total_M']:.2f}M (enc: {params['encoder_M']:.2f}M, dec: {params['decoder_M']:.2f}M)")