File size: 867 Bytes
28eb1af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Liquid Time-Constant (LTC) Adaptive Gate.
Adds dynamic time-constant τ per channel, routing info between
fast (texture) and slow (structure) pathways.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class LTCGate(nn.Module):
    """Gated blend of an identity ("fast") path and a depthwise-conv ("slow") path.

    For an input ``x`` of shape ``(batch, seq_len, dim)`` the module computes::

        tau  = sigmoid(Linear(x))                   # in (0, 1), per position & channel
        slow = SiLU(GroupNorm(DepthwiseConv1d(x)))  # convolution along seq_len
        out  = tau * slow + (1 - tau) * x

    so ``tau`` acts as a convex mixing weight between the two paths.

    Args:
        dim: Number of channels (size of the last input dimension).
        num_groups: Number of GroupNorm groups in the slow path; ``dim`` must
            be divisible by it. Defaults to 8 (the previously hard-coded value).
        kernel_size: Width of the depthwise convolution. Must be a positive odd
            integer so that ``padding = kernel_size // 2`` keeps ``seq_len``
            unchanged. Defaults to 3.
        tau_bias_init: Initial bias of the gate projection. The default -2.0
            gives ``sigmoid(-2) ~= 0.12``, i.e. the gate initially favours the
            fast (identity) path.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer, or (raised
            by ``nn.GroupNorm``) if ``dim`` is not divisible by ``num_groups``.
    """

    def __init__(
        self,
        dim: int,
        *,
        num_groups: int = 8,
        kernel_size: int = 3,
        tau_bias_init: float = -2.0,
    ):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # Per-position, per-channel gate logits.
        self.tau_proj = nn.Linear(dim, dim, bias=True)
        # Depthwise (groups=dim) conv mixes neighbouring positions within each
        # channel. GroupNorm normalises over channel groups *and* the whole
        # sequence, so every slow-path output depends on the entire input
        # sequence (the slow path is not causal).
        self.slow_branch = nn.Sequential(
            nn.Conv1d(
                dim,
                dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                groups=dim,
            ),
            nn.GroupNorm(num_groups, dim),
            nn.SiLU(),
        )
        nn.init.constant_(self.tau_proj.bias, tau_bias_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blend the fast and slow paths with a learned per-element gate.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3-D.
        """
        if x.dim() != 3:
            raise ValueError(
                "LTCGate expects a 3-D input (batch, seq_len, dim), "
                f"got shape {tuple(x.shape)}"
            )
        if x.shape[1] == 0:
            # A zero-length sequence would reach Conv1d with a padded length
            # smaller than the kernel, which PyTorch rejects; the blended result
            # is empty anyway, so return an empty tensor of the right shape.
            return x.clone()
        tau = torch.sigmoid(self.tau_proj(x))
        # Conv1d / GroupNorm expect channels-first (batch, dim, seq_len).
        slow = self.slow_branch(x.transpose(1, 2)).transpose(1, 2)
        return tau * slow + (1.0 - tau) * x