File size: 5,869 Bytes
714db4b
 
f4749f1
714db4b
 
 
f4749f1
 
714db4b
f4749f1
 
 
 
 
 
 
 
 
714db4b
 
 
 
 
 
 
 
 
f4749f1
 
 
 
 
 
 
 
 
 
714db4b
 
f4749f1
 
 
714db4b
 
f4749f1
714db4b
 
 
 
f4749f1
714db4b
f4749f1
714db4b
f4749f1
 
714db4b
 
f4749f1
 
 
 
 
714db4b
f4749f1
 
 
 
 
 
 
 
 
 
 
 
714db4b
 
 
 
 
 
f4749f1
714db4b
 
 
f4749f1
714db4b
f4749f1
 
714db4b
f4749f1
 
 
 
 
714db4b
 
 
 
f4749f1
714db4b
f4749f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714db4b
 
 
f4749f1
 
 
714db4b
f4749f1
 
 
 
 
 
714db4b
 
f4749f1
714db4b
f4749f1
714db4b
 
f4749f1
 
 
 
 
714db4b
 
f4749f1
 
 
 
 
 
 
714db4b
 
 
f4749f1
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""
CfC Cell — Closed-form Continuous-time neural network cell.
FULLY PARALLEL implementation — no sequential loops.

From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022)

Core CfC equation (Eq. 10 from paper):
    x(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)

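Key insight for parallelization:
    The CfC equation is a CLOSED-FORM expression: it gives the state at time t
    directly, without numerically integrating an ODE. In this module the f/g/h
    networks read only the current position's features (no recurrent hidden
    state is carried from one position to the next), so the expression maps
    (input, time) → output independently per position. For image processing
    this means ALL spatial positions can be computed in a single parallel pass.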
    
    We use it as an adaptive gating mechanism:
    - f network produces position-dependent time constants
    - g/h networks produce two candidate feature maps
    - The sigmoid gate blends them adaptively per-position
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class CfCCell(nn.Module):
    """
    Parallel CfC cell — processes ALL positions simultaneously.

    No hidden state is carried between positions: each output position is a
    function of that position's input features and the time parameter only.
    The f, g and h networks are therefore applied to all positions in one
    vectorized pass.

    For a sequence [B, L, D]:
        - f, g, h networks are applied to ALL L positions in parallel
        - The time parameter t (plus a learnable per-channel bias) scales the
          gate pre-activation
        - Output is computed in a single vectorized operation

    Args:
        dim: Feature dimension D (last axis of the input).
        dropout: Dropout rate applied inside the shared backbone.
        time_scale: (t_min, t_max) range for the time parameter. When ``t``
            is not passed to ``forward``, it is sampled uniformly from this
            range per sample in training mode, and set to the midpoint in
            eval mode.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.1, 1.0)):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared backbone: per-position MLP expanding D -> 4*D features.
        self.backbone = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.LayerNorm(dim * 4),
            nn.SiLU(),
            nn.Dropout(dropout),
        )

        # f head: gate "rate" term. tanh bounds it to [-1, 1], so the gate
        # pre-activation magnitude is at most |t + time_bias|.
        self.f_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
            nn.Tanh(),
        )

        # g head: candidate weighted by the gate. It dominates when the gate
        # approaches 1, i.e. when -f * (t + time_bias) is large and positive.
        self.g_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )

        # h head: candidate weighted by (1 - gate). It dominates when the gate
        # approaches 0, i.e. when f * (t + time_bias) is large and positive.
        self.h_head = nn.Sequential(
            nn.Linear(dim * 4, dim),
        )

        # Learnable per-channel offset added to t before gating.
        self.time_bias = nn.Parameter(torch.zeros(dim))

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform (gain 0.02) weights and zero biases for every Linear."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # NOTE(review): gain=0.02 on top of Xavier's fan scaling gives
                # very small weights, so f starts near 0 (gate near 0.5) and
                # the g/h outputs start near 0. Possibly intended as a
                # std=0.02 normal init; kept unchanged — confirm.
                nn.init.xavier_uniform_(m.weight, gain=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x, t=None):
        """
        Fully parallel CfC forward pass.

        Computes ``gate * g + (1 - gate) * h`` with
        ``gate = sigmoid(-f * (t + time_bias))``.

        Args:
            x: [B, L, D] — all positions processed simultaneously.
            t: Optional time parameter: a Python number or a tensor that
               broadcasts against [B, L, D] (e.g. [B, 1, 1] or a scalar).
               It is converted to x's dtype and device. If None, it is
               sampled uniformly from ``time_scale`` per sample in training
               mode, and fixed to the midpoint of ``time_scale`` in eval mode.
        Returns:
            out: [B, L, D], same dtype as x.
        """
        # Unpacking enforces a 3-D input; only the batch size is needed.
        batch, _, _ = x.shape
        t_min, t_max = self.time_scale

        # Build t in x's dtype/device. A default-dtype (float32) t would
        # promote reduced-precision (fp16/bf16) activations to float32 and
        # make the output dtype differ from the input dtype.
        if t is None:
            if self.training:
                # Random time per sample during training (data augmentation).
                t = torch.rand(batch, 1, 1, device=x.device, dtype=x.dtype)
                t = t * (t_max - t_min) + t_min
            else:
                # Fixed midpoint during inference.
                t = torch.full(
                    (batch, 1, 1),
                    0.5 * (t_min + t_max),
                    device=x.device,
                    dtype=x.dtype,
                )
        else:
            # as_tensor is a no-op when dtype/device already match and keeps
            # the autograd graph when a copy is needed.
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)

        # Shared backbone (parallel over all B*L positions).
        features = self.backbone(x)  # [B, L, 4*D]

        # Three heads (all parallel).
        f_out = self.f_head(features)  # [B, L, D], in [-1, 1]
        g_out = self.g_head(features)  # [B, L, D]
        h_out = self.h_head(features)  # [B, L, D]

        # CfC gating: sigmoid(-f * (t + time_bias)); time_bias is per-channel.
        effective_t = t + self.time_bias.view(1, 1, -1)
        gate = torch.sigmoid(-f_out * effective_t)  # [B, L, D]

        # Closed-form blend of the two candidates.
        return gate * g_out + (1 - gate) * h_out


class CfCBlock(nn.Module):
    """
    CfC block for 2D feature maps or token sequences.
    Fully parallel — no sequential loops.

    Architecture:
        Input [B, C, H, W] → flatten to [B, HW, C] → CfC (parallel) → FF
        → reshape back to [B, C, H, W]
        With: pre-norm and a residual connection around both the CfC cell
        and the feed-forward sub-layer. A [B, L, C] input skips the
        flatten/reshape steps.

    Args:
        dim: Channel / feature dimension (C of a 4-D input, last axis of a
            3-D input).
        dropout: Dropout rate used in the CfC cell and the feed-forward MLP.
        expansion_factor: Hidden-width multiplier of the feed-forward MLP.
        time_scale: Optional (t_min, t_max) range forwarded to ``CfCCell``.
            None keeps ``CfCCell``'s own default.
    """

    def __init__(self, dim, dropout=0.0, expansion_factor=2, time_scale=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # Only override the cell's time range when explicitly requested, so
        # the cell's default stays the single source of truth.
        cfc_kwargs = {} if time_scale is None else {"time_scale": time_scale}
        self.cfc = CfCCell(dim=dim, dropout=dropout, **cfc_kwargs)

        self.norm2 = nn.LayerNorm(dim)
        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, t=None):
        """
        Args:
            x: [B, C, H, W] or [B, L, C]
            t: Optional time parameter passed through to ``CfCCell.forward``
               (e.g. a scalar or a [B, 1, 1] tensor). None lets the cell
               choose t (random in training, midpoint in eval).
        Returns:
            Same shape as input
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            x = x.flatten(2).transpose(1, 2)  # [B, HW, C]

        # Pre-norm + CfC + residual
        x = x + self.cfc(self.norm1(x), t=t)
        # Pre-norm + FF + residual
        x = x + self.ff(self.norm2(x))

        if is_2d:
            # transpose yields a non-contiguous view; reshape copies as needed.
            x = x.transpose(1, 2).reshape(B, C, H, W)
        return x