File size: 7,924 Bytes
40a4412
 
3214be6
40a4412
 
3214be6
40a4412
3214be6
 
40a4412
 
 
 
 
3214be6
40a4412
3214be6
 
40a4412
 
 
 
 
 
3214be6
 
 
 
40a4412
3214be6
 
40a4412
 
 
 
 
 
 
3214be6
 
 
40a4412
3214be6
 
 
 
40a4412
3214be6
 
 
 
 
40a4412
 
 
 
 
 
 
 
 
 
 
 
 
 
3214be6
40a4412
 
 
 
 
 
3214be6
40a4412
3214be6
 
 
40a4412
3214be6
 
 
 
40a4412
3214be6
 
 
 
40a4412
3214be6
 
40a4412
3214be6
 
40a4412
3214be6
 
40a4412
 
3214be6
 
40a4412
 
 
3214be6
40a4412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3214be6
40a4412
3214be6
40a4412
3214be6
 
 
 
 
 
 
 
40a4412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3214be6
40a4412
 
3214be6
40a4412
 
 
 
 
 
3214be6
 
 
 
 
 
 
40a4412
 
 
 
 
3214be6
40a4412
 
 
 
 
 
 
 
3214be6
40a4412
3214be6
40a4412
3214be6
40a4412
3214be6
 
 
 
 
 
 
40a4412
 
3214be6
 
40a4412
3214be6
40a4412
3214be6
 
40a4412
 
 
 
3214be6
 
40a4412
3214be6
40a4412
 
3214be6
40a4412
3214be6
40a4412
3214be6
40a4412
3214be6
 
40a4412
3214be6
 
40a4412
3214be6
 
 
40a4412
3214be6
 
40a4412
 
 
3214be6
40a4412
 
3214be6
 
 
40a4412
3214be6
 
40a4412
3214be6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
"""
LiquidFlow Block — Hybrid CfC + Mamba-2 SSD architecture.
CORRECTED VERSION: proper dimensions, no sequential loops.

Architecture per block:
    Input → Mamba-2 SSD (bidirectional) → CfC adaptive gate → gated residual → feed-forward → Output
    
The CfC provides adaptive gating that modulates the SSM output
based on input-dependent "liquid" time constants.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .cfc_cell import CfCCell, CfCBlock
from .mamba2_ssd import Mamba2SSD, Mamba2Block


class LiquidMambaBlock(nn.Module):
    """
    CfC-gated, bidirectional Mamba-2 SSD block (pre-norm, residual).

    Data flow:
    1. ``norm_ssm`` -> two Mamba2SSD scans, one over the sequence and one over
       the reversed sequence (re-reversed afterwards); the two outputs are
       concatenated and merged back to ``dim`` channels.
    2. ``norm_gate`` -> CfCCell -> ``gate_proj`` -> sigmoid yields a gate in
       (0, 1) that scales how much of the merged SSM output is added to the
       residual stream.
    3. A pre-norm GELU feed-forward (hidden width ``dim * expand``) with its
       own residual connection.

    Args:
        dim: channel width of the token stream.
        d_state, d_conv, expand: forwarded to both Mamba2SSD modules;
            ``expand`` also sets the feed-forward hidden width.
        dropout: passed to the CfCCell and used inside the feed-forward.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.dim = dim
        ssd_cfg = {"dim": dim, "d_state": d_state, "d_conv": d_conv, "expand": expand}

        # Pre-norms for the SSM branch, the gate branch and the feed-forward.
        # NOTE: submodule creation order is deliberate — it fixes both the
        # parameter ordering and the sequence of random initialisations.
        self.norm_ssm = nn.LayerNorm(dim)
        self.norm_gate = nn.LayerNorm(dim)
        self.norm_ff = nn.LayerNorm(dim)

        # One Mamba-2 SSD per scan direction; a bias-free linear fuses them.
        self.ssd_fwd = Mamba2SSD(**ssd_cfg)
        self.ssd_bwd = Mamba2SSD(**ssd_cfg)
        self.merge = nn.Linear(2 * dim, dim, bias=False)

        # Content-dependent gate: CfC features -> per-channel gate logits.
        self.cfc_gate = CfCCell(dim=dim, dropout=dropout)
        self.gate_proj = nn.Linear(dim, dim)

        hidden = expand * dim
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """
        Apply the block to a feature map or a token sequence.

        Args:
            x: ``[B, C, H, W]`` (flattened row-major to ``[B, H*W, C]``
               internally) or ``[B, L, C]``; the last/channel dim is
               normalised with ``LayerNorm(dim)``, so ``C`` must equal ``dim``.
        Returns:
            Tensor with the same shape as ``x``.
        """
        image_shape = x.shape if x.dim() == 4 else None
        tokens = x if image_shape is None else x.flatten(2).transpose(1, 2)

        # SSM branch: a forward scan plus a scan over the time-reversed
        # sequence, flipped back so both outputs align position-by-position.
        normed = self.norm_ssm(tokens)
        forward_scan = self.ssd_fwd(normed)
        backward_scan = self.ssd_bwd(normed.flip(1)).flip(1)
        mixed = self.merge(torch.cat((forward_scan, backward_scan), dim=-1))

        # Gate in (0, 1), computed from the normalised SSM output itself,
        # controls how much of that output enters the residual stream.
        gate_logits = self.gate_proj(self.cfc_gate(self.norm_gate(mixed)))
        tokens = tokens + torch.sigmoid(gate_logits) * mixed

        # Pre-norm feed-forward with residual.
        tokens = tokens + self.ff(self.norm_ff(tokens))

        if image_shape is None:
            return tokens
        return tokens.transpose(1, 2).reshape(image_shape)


class LiquidFlowStage(nn.Module):
    """
    Sequential stack of ``num_blocks`` LiquidMambaBlocks sharing one width.

    Args:
        dim: channel width of every block.
        num_blocks: number of blocks, applied in order.
        d_state, expand, dropout: forwarded to each LiquidMambaBlock.
        d_conv: convolution width forwarded to each LiquidMambaBlock (and
            from there to its Mamba-2 SSD modules). Defaults to 4, the
            LiquidMambaBlock default. Appended last so existing positional
            callers are unaffected.
    """

    def __init__(self, dim, num_blocks=4, d_state=16, expand=2, dropout=0.0, d_conv=4):
        super().__init__()
        self.blocks = nn.ModuleList([
            LiquidMambaBlock(
                dim=dim,
                d_state=d_state,
                d_conv=d_conv,
                expand=expand,
                dropout=dropout,
            )
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Run ``x`` through every block in order and return the result."""
        for block in self.blocks:
            x = block(x)
        return x


class LiquidFlowBackbone(nn.Module):
    """
    Complete LiquidFlow backbone — DiT-style noise predictor.

    Output shape == input shape: there is no patchification; every latent
    pixel becomes one token.

    Architecture:
        Input [B, in_ch, H, W]
        → 1x1 Conv2d projection to hidden_dim, flattened to [B, H*W, hidden_dim]
        → scale/shift modulation from an MLP over a sinusoidal timestep embedding
        → + learnable positional embedding (first H*W rows of the table)
        → num_stages × LiquidFlowStage, applied to the token sequence
        → LayerNorm + per-token Linear back to in_ch
        → Output [B, in_ch, H, W]

    Args:
        in_channels: channels of the input latent and of the prediction.
        hidden_dim: token width used throughout the backbone.
        num_stages: number of LiquidFlowStage modules.
        blocks_per_stage: LiquidMambaBlocks per stage.
        d_state, expand, dropout: forwarded to every stage.
        max_positions: rows in the learnable positional table; inputs with
            H * W > max_positions are rejected with ValueError. Default 4096
            (e.g. a 64×64 latent).
    """

    def __init__(
        self,
        in_channels=4,
        hidden_dim=256,
        num_stages=4,
        blocks_per_stage=4,
        d_state=16,
        expand=2,
        dropout=0.0,
        max_positions=4096,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.max_positions = max_positions

        # NOTE: submodule creation order is kept fixed; it determines the
        # sequence of random initialisations and the parameter ordering.

        # Input projection: per-pixel (1x1) linear map to the hidden width.
        self.in_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)

        # MLP applied to the sinusoidal timestep embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )

        # Timestep -> (scale, shift). Applied once to the input tokens in
        # forward(); there is no per-block norm modulation.
        self.t_cond = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
        )

        # Learnable positional table indexed by flattened (row-major) pixel
        # position; forward() uses its first H*W rows.
        self.pos_embed = nn.Parameter(torch.randn(1, max_positions, hidden_dim) * 0.02)

        # LiquidFlow stages
        self.stages = nn.ModuleList([
            LiquidFlowStage(
                dim=hidden_dim,
                num_blocks=blocks_per_stage,
                d_state=d_state,
                expand=expand,
                dropout=dropout,
            )
            for _ in range(num_stages)
        ])

        # Output head: per-token LayerNorm + Linear back to in_channels.
        self.out_norm = nn.LayerNorm(hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, in_channels)

        self._init_weights()

    def _init_weights(self):
        """Zero-init the output projection so the untrained model predicts zeros."""
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def _sinusoidal_embedding(self, timesteps, dim):
        """
        Sinusoidal embedding of diffusion timesteps.

        Args:
            timesteps: tensor of shape [N] (cast to float).
            dim: embedding width.
        Returns:
            [N, dim] tensor: cos features followed by sin features at
            frequencies 10000 ** (-i / (dim // 2)), with one zero column
            appended when ``dim`` is odd.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(10000.0) * torch.arange(half, device=timesteps.device).float() / half
        )
        args = timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            emb = F.pad(emb, (0, 1))
        return emb

    def forward(self, x, t):
        """
        Predict noise for a batch of noisy latents.

        Args:
            x: [B, in_channels, H, W] — noisy latent; H * W must not exceed
               the positional table size (``max_positions``).
            t: [B] diffusion timesteps (integers 0..T-1), or a scalar
               (Python number / 0-d tensor) shared by the whole batch.
               It is moved to ``x``'s device before embedding.
        Returns:
            [B, in_channels, H, W] — predicted noise (same shape as input).
        Raises:
            ValueError: if H * W exceeds the positional table size.
        """
        B, C, H, W = x.shape
        L = H * W
        num_positions = self.pos_embed.shape[1]
        if L > num_positions:
            # Slicing pos_embed[:, :L] would silently return fewer rows and
            # fail later with an opaque broadcasting error.
            raise ValueError(
                f"input has {H}x{W} = {L} positions but the positional table "
                f"only has {num_positions}; construct LiquidFlowBackbone with a "
                f"larger max_positions"
            )

        # Accept scalar timesteps and put t on x's device so the timestep MLP
        # (whose weights live with the model) receives a matching tensor.
        t = torch.as_tensor(t, device=x.device)
        if t.dim() == 0:
            t = t.expand(B)

        # Project to hidden dim and flatten to a token sequence.
        x = self.in_proj(x)  # [B, hidden_dim, H, W]
        x = x.flatten(2).transpose(1, 2)  # [B, HW, hidden_dim]

        # Timestep conditioning: one global scale/shift per sample.
        t_emb = self._sinusoidal_embedding(t, self.hidden_dim)  # [B, hidden_dim]
        t_emb = self.time_embed(t_emb)  # [B, hidden_dim]
        scale, shift = self.t_cond(t_emb).chunk(2, dim=-1)  # each [B, hidden_dim]

        # Apply conditioning + positional encoding.
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = x + self.pos_embed[:, :L, :]

        # Stages consume the [B, HW, hidden_dim] sequence directly
        # (LiquidMambaBlock accepts [B, L, C]); reshaping to [B, D, H, W] here
        # would only be flattened straight back inside every block, costing a
        # transpose + reshape copy per block for identical values.
        for stage in self.stages:
            x = stage(x)

        # Output head, then back to image layout [B, in_channels, H, W].
        x = self.out_norm(x)
        x = self.out_proj(x)  # [B, HW, in_channels]
        return x.transpose(1, 2).reshape(B, self.in_channels, H, W)