Spaces:
Sleeping
Sleeping
| """ | |
| Liquid Neural Network Layer | |
| Stage 4a: Adaptive Time-Constant Processing | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple | |
| import math | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class LiquidTimeConstant(nn.Module):
    """
    Liquid Time-Constant (LTC) cell.

    Implements continuous-time dynamics with an adaptive, input-dependent
    time constant.  The "liquid" behavior lets the cell adjust its reaction
    speed based on the current input and hidden state: tau is produced by a
    small network and squashed into [tau_min, tau_max].

    Reference: "Liquid Time-constant Networks" (Hasani et al., 2020).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        tau_min: float = 0.1,
        tau_max: float = 10.0,
    ):
        """
        Args:
            input_dim: Width of the input vector x.
            hidden_dim: Width of the hidden state h.
            tau_min: Lower bound of the adaptive time constant.
            tau_max: Upper bound of the adaptive time constant.
        """
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.tau_min = tau_min
        self.tau_max = tau_max
        # Input-to-hidden mapping
        self.W_in = nn.Linear(input_dim, hidden_dim)
        # Hidden-to-hidden (recurrent) mapping
        self.W_h = nn.Linear(hidden_dim, hidden_dim)
        # Time-constant modulation network.
        # Tau is computed as: tau = tau_min + (tau_max - tau_min) * sigmoid(tau_net([x, h]))
        self.tau_net = nn.Sequential(
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Output gate
        self.W_out = nn.Linear(hidden_dim, hidden_dim)
        # Layer normalization keeps the post-update hidden state bounded,
        # stabilizing repeated Euler steps.
        self.ln_h = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        h: Optional[torch.Tensor] = None,
        dt: float = 0.01,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single LTC step (one Euler integration of the hidden-state ODE).

        Args:
            x: Input (batch, input_dim).
            h: Previous hidden state (batch, hidden_dim); zeros if None.
            dt: Integration time step.

        Returns:
            output: Gated output (batch, hidden_dim), in (-1, 1) via tanh.
            new_h: Updated hidden state (batch, hidden_dim).
        """
        batch_size = x.size(0)
        # Initialize hidden state if None.
        # FIX: match x's dtype as well as its device — without dtype, a
        # float64/half input made torch.cat([x, h]) fail on dtype mismatch.
        if h is None:
            h = torch.zeros(
                batch_size, self.hidden_dim, device=x.device, dtype=x.dtype
            )
        # Compute the adaptive time constant from the joint (x, h) state,
        # bounded to [tau_min, tau_max] by a scaled sigmoid.
        tau_input = torch.cat([x, h], dim=-1)
        tau_raw = self.tau_net(tau_input)
        tau = self.tau_min + (self.tau_max - self.tau_min) * torch.sigmoid(tau_raw)
        # Input contribution
        f_x = torch.tanh(self.W_in(x))
        # Recurrent contribution
        f_h = torch.tanh(self.W_h(h))
        # ODE update: dh/dt = (1/tau) * (-h + f(x, h))
        # Euler discretization: h_new = h + dt * (1/tau) * (-h + f_x + f_h)
        activation = f_x + f_h
        dh = (dt / tau) * (-h + activation)
        h_new = h + dh
        # Layer normalization for stability
        h_new = self.ln_h(h_new)
        # Gated output
        output = torch.tanh(self.W_out(h_new))
        return output, h_new
class LiquidLayer(nn.Module):
    """
    Complete Liquid Neural Network layer.

    Key features:
    - Adaptive time constants based on input volatility
    - ODE-based dynamics for robustness to noise
    - Dropout for regularization (Layer 1)

    Expects ``config["model"]["liquid"]`` to provide: input_dim, hidden_dim,
    tau_min, tau_max, dt, num_steps, dropout.
    """

    def __init__(self, config: dict):
        """
        Args:
            config: Nested config dict; reads config["model"]["liquid"].
        """
        super().__init__()
        liquid_config = config["model"]["liquid"]
        self.input_dim = liquid_config["input_dim"]
        self.hidden_dim = liquid_config["hidden_dim"]
        self.tau_min = liquid_config["tau_min"]
        self.tau_max = liquid_config["tau_max"]
        self.dt = liquid_config["dt"]
        self.num_steps = liquid_config["num_steps"]
        self.dropout_rate = liquid_config["dropout"]
        # Input projection (if dimensions don't match).
        # FIX: previously this was nn.Linear(input_dim, input_dim), which
        # never reconciled the dimensions as its comment claimed. Project to
        # hidden_dim and build the cell on the projected width instead.
        if self.input_dim != self.hidden_dim:
            self.input_proj = nn.Linear(self.input_dim, self.hidden_dim)
            cell_input_dim = self.hidden_dim
        else:
            self.input_proj = nn.Identity()
            cell_input_dim = self.input_dim
        # LTC cell operating on the (possibly projected) input width
        self.ltc = LiquidTimeConstant(
            input_dim=cell_input_dim,
            hidden_dim=self.hidden_dim,
            tau_min=self.tau_min,
            tau_max=self.tau_max,
        )
        # Dropout (Layer 1)
        self.dropout = nn.Dropout(self.dropout_rate)
        # Output layer normalization
        self.output_norm = nn.LayerNorm(self.hidden_dim)
        logger.info("Liquid Layer initialized:")
        logger.info(f"  Hidden Dim: {self.hidden_dim}")
        logger.info(f"  Tau Range: [{self.tau_min}, {self.tau_max}]")
        logger.info(f"  Num Steps: {self.num_steps}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process input through the Liquid layer.

        Runs ``num_steps`` ODE steps to "think" about the (static) input,
        with the time constant adapting based on input complexity.

        Args:
            x: Input features (batch, input_dim).

        Returns:
            Output features (batch, hidden_dim).
        """
        # Project input up to hidden_dim when needed
        x = self.input_proj(x)
        # Hidden state starts at zeros (created inside the cell)
        h = None
        # Run multiple liquid steps; only the final step's output is used
        for _ in range(self.num_steps):
            output, h = self.ltc(x, h, self.dt)
        # Regularize, then normalize
        output = self.dropout(output)
        output = self.output_norm(output)
        return output

    def get_time_constants(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the adaptive time constants (at a zero hidden state) for
        interpretability.

        Args:
            x: Input features (batch, input_dim).

        Returns:
            Time constants (batch, hidden_dim), bounded to [tau_min, tau_max].
        """
        # Project first so the concatenated width matches tau_net's input.
        x = self.input_proj(x)
        # FIX: match x's dtype so torch.cat does not fail for non-float32 inputs.
        h = torch.zeros(x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)
        tau_input = torch.cat([x, h], dim=-1)
        tau_raw = self.ltc.tau_net(tau_input)
        tau = self.tau_min + (self.tau_max - self.tau_min) * torch.sigmoid(tau_raw)
        return tau