"""
Liquid Neural Network Cell - Discrete-time approximation of continuous-time dynamics.
Implements a liquid cell with learnable per-neuron time constants.
The cell updates hidden state using a differential equation approximation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class LiquidCell(nn.Module):
"""
A discrete-time liquid neural network cell.
    Hidden state update rule:

        h_{t+1,i} = h_{t,i} + (dt / tau_i) * ( tanh( W_hh[i]·h_t + W_xh[i]·x_t + b[i] ) - h_{t,i} )

    where tau_i is a learnable per-neuron time constant.
Args:
hidden_size: Number of hidden neurons
input_size: Size of input vector
dt: Time step for discrete approximation (default: 0.1)
"""
def __init__(self, hidden_size: int, input_size: int, dt: float = 0.1):
super().__init__()
self.hidden_size = hidden_size
self.input_size = input_size
self.dt = dt
# Recurrent weight matrix: (hidden_size, hidden_size)
self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)
# Input weight matrix: (hidden_size, input_size)
self.W_xh = nn.Parameter(torch.randn(hidden_size, input_size) * 0.1)
# Bias vector: (hidden_size,)
self.b = nn.Parameter(torch.zeros(hidden_size))
        # Raw time constants; mapped through softplus in forward() so the
        # effective tau is always positive. Shape: (hidden_size,)
        self.tau_raw = nn.Parameter(torch.ones(hidden_size))

    def forward(self, h: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the liquid cell.
Args:
h: Hidden state tensor of shape (batch, hidden_size)
x: Input tensor of shape (batch, input_size)
Returns:
Next hidden state tensor of shape (batch, hidden_size)
"""
        # Time constants: tau = softplus(tau_raw) + 1e-3.
        # softplus keeps tau strictly positive, and the 1e-3 floor guards the
        # division in the update step against a near-zero tau.
        tau = F.softplus(self.tau_raw) + 1e-3
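        # Worked value (not stated in the original): at init tau_raw = 1.0,
        # so tau = softplus(1.0) + 1e-3 = ln(1 + e) + 1e-3 ≈ 1.314 per neuron.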
        # Project the hidden state and the input into hidden space.
        # h @ W_hh.T computes W_hh @ h_i for every row h_i of the batch,
        # yielding a (batch, hidden_size) result directly; no explicit
        # transpose of the activations is needed.
        h_proj = torch.matmul(h, self.W_hh.t())  # (batch, hidden_size)
        x_proj = torch.matmul(x, self.W_xh.t())  # (batch, input_size) -> (batch, hidden_size)
# Add bias and apply tanh
preact = torch.tanh(h_proj + x_proj + self.b) # (batch, hidden_size)
        # Forward-Euler update: h_next = h + dt * (preact - h) / tau.
        # tau has shape (hidden_size,) and broadcasts across the batch
        # dimension, so no explicit unsqueeze is needed.
        h_next = h + self.dt * (preact - h) / tau  # (batch, hidden_size)
        # Clamp to [-5, 5] to keep the hidden state bounded for stability.
        h_next = torch.clamp(h_next, -5.0, 5.0)
return h_next
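

# Minimal usage sketch (illustrative; the batch, size, and sequence-length
# values below are assumptions, not taken from this repo's training setup):
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, input_size, hidden_size, seq_len = 4, 8, 16, 20

    cell = LiquidCell(hidden_size=hidden_size, input_size=input_size, dt=0.1)
    h = torch.zeros(batch, hidden_size)           # initial hidden state
    xs = torch.randn(seq_len, batch, input_size)  # a random input sequence

    # Roll the cell over the sequence, one Euler step per time step.
    for x_t in xs:
        h = cell(h, x_t)

    print(h.shape)  # torch.Size([4, 16])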