# indicguard/src/models/liquid_layer.py
"""
Liquid Neural Network Layer
Stage 4a: Adaptive Time-Constant Processing
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging
logger = logging.getLogger(__name__)
class LiquidTimeConstant(nn.Module):
"""
Liquid Time-Constant (LTC) Cell
Implements continuous-time dynamics with adaptive time constants.
The "liquid" behavior allows the network to adjust its reaction
speed based on input volatility.
Reference: "Liquid Time-constant Networks" (Hasani et al., 2020)
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
tau_min: float = 0.1,
tau_max: float = 10.0,
):
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.tau_min = tau_min
self.tau_max = tau_max
# Input-to-hidden mapping
self.W_in = nn.Linear(input_dim, hidden_dim)
# Hidden-to-hidden mapping
self.W_h = nn.Linear(hidden_dim, hidden_dim)
# Time constant modulation
# Tau is computed as: tau = tau_min + (tau_max - tau_min) * sigmoid(tau_net)
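        # Worked example with the default range: tau_min=0.1, tau_max=10.0 and a
        # raw score of 0 (sigmoid = 0.5) give tau = 0.1 + 9.9 * 0.5 = 5.05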
self.tau_net = nn.Sequential(
nn.Linear(input_dim + hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
)
# Output gate
self.W_out = nn.Linear(hidden_dim, hidden_dim)
# Layer normalization for stability
self.ln_h = nn.LayerNorm(hidden_dim)
def forward(
self,
x: torch.Tensor,
h: Optional[torch.Tensor] = None,
dt: float = 0.01,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Single LTC step.
Args:
x: Input (batch, input_dim)
h: Previous hidden state (batch, hidden_dim)
dt: Time step
Returns:
output: (batch, hidden_dim)
new_h: (batch, hidden_dim)
"""
batch_size = x.size(0)
# Initialize hidden state if None
        if h is None:
            h = torch.zeros(batch_size, self.hidden_dim, device=x.device, dtype=x.dtype)
# Compute adaptive time constant
tau_input = torch.cat([x, h], dim=-1)
tau_raw = self.tau_net(tau_input)
tau = self.tau_min + (self.tau_max - self.tau_min) * torch.sigmoid(tau_raw)
# Compute input contribution
f_x = torch.tanh(self.W_in(x))
# Compute recurrent contribution
f_h = torch.tanh(self.W_h(h))
# ODE update: dh/dt = (1/tau) * (-h + f(x, h))
# Euler discretization: h_new = h + dt * (1/tau) * (-h + f_x + f_h)
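        # Stability note: the effective step size dt/tau is at most dt/tau_min
        # (0.01/0.1 = 0.1 with the defaults), well below the dt < 2*tau limit
        # of explicit Euler for this contracting ODE.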
activation = f_x + f_h
dh = (dt / tau) * (-h + activation)
h_new = h + dh
# Apply layer normalization for stability
h_new = self.ln_h(h_new)
# Output
output = torch.tanh(self.W_out(h_new))
return output, h_new
class LiquidLayer(nn.Module):
"""
Complete Liquid Neural Network Layer
Key Features:
- Adaptive time constants based on input volatility
- ODE-based dynamics for robustness to noise
- Dropout for regularization (Layer 1)
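    Expected config structure (the keys below are exactly what __init__ reads;
    the values shown are placeholders, not project defaults):
        config = {
            "model": {
                "liquid": {
                    "input_dim": 128,
                    "hidden_dim": 128,
                    "tau_min": 0.1,
                    "tau_max": 10.0,
                    "dt": 0.01,
                    "num_steps": 4,
                    "dropout": 0.1,
                }
            }
        }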
"""
def __init__(self, config: dict):
super().__init__()
liquid_config = config["model"]["liquid"]
self.input_dim = liquid_config["input_dim"]
self.hidden_dim = liquid_config["hidden_dim"]
self.tau_min = liquid_config["tau_min"]
self.tau_max = liquid_config["tau_max"]
self.dt = liquid_config["dt"]
self.num_steps = liquid_config["num_steps"]
self.dropout_rate = liquid_config["dropout"]
# LTC cell
self.ltc = LiquidTimeConstant(
input_dim=self.input_dim,
hidden_dim=self.hidden_dim,
tau_min=self.tau_min,
tau_max=self.tau_max,
)
        # Learned input pre-projection. Note: the LTC cell's W_in already maps
        # input_dim -> hidden_dim, so this projection keeps the input dimension
        # and only adds an extra learned transform when the two dims differ.
        if self.input_dim != self.hidden_dim:
            self.input_proj = nn.Linear(self.input_dim, self.input_dim)
        else:
            self.input_proj = nn.Identity()
# Dropout (Layer 1)
self.dropout = nn.Dropout(self.dropout_rate)
# Output layer normalization
self.output_norm = nn.LayerNorm(self.hidden_dim)
logger.info(f"Liquid Layer initialized:")
logger.info(f" Hidden Dim: {self.hidden_dim}")
logger.info(f" Tau Range: [{self.tau_min}, {self.tau_max}]")
logger.info(f" Num Steps: {self.num_steps}")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Process input through Liquid layer.
The layer runs multiple ODE steps to "think" about the input,
with the time constant adapting based on input complexity.
Args:
x: Input features (batch, input_dim)
Returns:
Output features (batch, hidden_dim)
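        Note: the total simulated time is num_steps * dt (e.g. 4 steps at
        dt = 0.01 integrate 0.04 time units of the hidden-state ODE).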
"""
# Project input
x = self.input_proj(x)
# Initialize hidden state
h = None
# Run multiple liquid steps
for _ in range(self.num_steps):
output, h = self.ltc(x, h, self.dt)
# Apply dropout
output = self.dropout(output)
# Normalize output
output = self.output_norm(output)
return output
def get_time_constants(self, x: torch.Tensor) -> torch.Tensor:
"""
Get the adaptive time constants for interpretability.
Returns:
Time constants (batch, hidden_dim)
"""
h = torch.zeros(x.size(0), self.hidden_dim, device=x.device)
x = self.input_proj(x)
tau_input = torch.cat([x, h], dim=-1)
tau_raw = self.ltc.tau_net(tau_input)
tau = self.tau_min + (self.tau_max - self.tau_min) * torch.sigmoid(tau_raw)
return tau
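# Minimal smoke test (illustrative only; the config values below are
# placeholders, not the project's real settings).
if __name__ == "__main__":
    config = {
        "model": {
            "liquid": {
                "input_dim": 64,
                "hidden_dim": 128,
                "tau_min": 0.1,
                "tau_max": 10.0,
                "dt": 0.01,
                "num_steps": 4,
                "dropout": 0.1,
            }
        }
    }
    layer = LiquidLayer(config)
    x = torch.randn(8, 64)
    out = layer(x)
    tau = layer.get_time_constants(x)
    print(out.shape)  # torch.Size([8, 128])
    print(tau.shape)  # torch.Size([8, 128])
    # Sigmoid gating keeps tau strictly inside [tau_min, tau_max]
    print(tau.min().item() >= 0.1, tau.max().item() <= 10.0)  # True True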