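"""
Liquid Time-Constant (LTC) Adaptive Gate.

Computes a dynamic time-constant τ in (0, 1) for every position and channel
and uses it to blend a fast (texture) pathway, which is the input itself,
with a slow (structure) pathway, which is a depthwise-convolved and
normalized copy of the input.
"""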
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| class LTCGate(nn.Module): |
| def __init__(self, dim: int): |
| super().__init__() |
| self.tau_proj = nn.Linear(dim, dim, bias=True) |
| self.slow_branch = nn.Sequential( |
| nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim), |
| nn.GroupNorm(8, dim), |
| nn.SiLU(), |
| ) |
| nn.init.constant_(self.tau_proj.bias, -2.0) |
|
|
| def forward(self, x: torch.Tensor): |
| B, L, dim = x.shape |
| tau = torch.sigmoid(self.tau_proj(x)) |
| fast = x |
| slow = self.slow_branch(x.transpose(1, 2)).transpose(1, 2) |
| out = tau * slow + (1.0 - tau) * fast |
| return out |
|
|