# Source: artigen / ltc_gate.py (uploaded by krystv, commit 28eb1af).
"""
Liquid Time-Constant (LTC) Adaptive Gate.
Adds dynamic time-constant τ per channel, routing info between
fast (texture) and slow (structure) pathways.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class LTCGate(nn.Module):
    """Input-dependent gate mixing a fast (identity) path with a slow (conv) path.

    For an input ``x`` the module computes, element-wise::

        tau  = sigmoid(Linear(x))                    # in (0, 1)
        slow = SiLU(GroupNorm(DepthwiseConv1d(x)))
        out  = tau * slow + (1 - tau) * x

    ``tau`` is computed independently for every position and channel. Despite
    the "time-constant" naming, no ODE is integrated here. ``tau`` acts as a
    learned convex-mixing weight between the two pathways.

    The slow path is not causal. The convolution uses symmetric padding, so
    each output position sees neighbours on both sides. ``GroupNorm`` takes its
    statistics over the whole length axis, so padded positions (if any) also
    influence them.

    Args:
        dim: Number of channels. It must match the last axis of the input.
        num_groups: Number of groups for the ``GroupNorm`` in the slow path.
            ``dim`` must be divisible by it. Defaults to 8, the value that was
            previously hard-coded.
        kernel_size: Width of the depthwise convolution. It must be a positive
            odd integer so that ``padding=kernel_size // 2`` preserves the
            length axis. Defaults to 3.
        tau_bias_init: Constant initial bias of the gate projection. The
            default of -2.0 pushes the gate towards the fast path at
            initialization (``sigmoid(-2) ~= 0.12`` before the weight term).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(
        self,
        dim: int,
        *,
        num_groups: int = 8,
        kernel_size: int = 3,
        tau_bias_init: float = -2.0,
    ) -> None:
        super().__init__()
        # An even kernel with padding k // 2 yields length L + 1. The mix in
        # forward() would then fail on mismatched shapes, so reject it up front.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        # Per-position, per-channel gate logits. This acts on the last axis of x.
        self.tau_proj = nn.Linear(dim, dim, bias=True)
        # With groups=dim the convolution is depthwise: it mixes neighbouring
        # positions within each channel and never mixes across channels.
        self.slow_branch = nn.Sequential(
            nn.Conv1d(
                dim,
                dim,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                groups=dim,
            ),
            nn.GroupNorm(num_groups, dim),
            nn.SiLU(),
        )
        # Set after construction so that only the bias (not the weight) is
        # overridden. The RNG draw order of the default inits is unchanged.
        nn.init.constant_(self.tau_proj.bias, tau_bias_init)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blend the slow and fast pathways with the learned gate.

        Args:
            x: Rank-3 tensor laid out as ``(batch, length, dim)``. The last axis
                holds the channels consumed by ``tau_proj``, and axis 1 is the
                axis the depthwise convolution slides over.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not rank 3.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a rank-3 (batch, length, dim) tensor, got shape "
                f"{tuple(x.shape)}"
            )
        tau = torch.sigmoid(self.tau_proj(x))
        fast = x
        # Conv1d and GroupNorm expect channels on axis 1, so move them there
        # and move them back afterwards.
        slow = self.slow_branch(x.transpose(1, 2)).transpose(1, 2)
        return tau * slow + (1.0 - tau) * fast