Spaces:
Runtime error
Runtime error
File size: 3,168 Bytes
28dbd6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
"""
Liquid Neural Network Cell - Discrete-time approximation of continuous-time dynamics.
Implements a liquid cell with learnable per-neuron time constants.
The cell updates hidden state using a differential equation approximation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class LiquidCell(nn.Module):
    """
    A discrete-time liquid neural network cell.

    Hidden state update rule (one explicit Euler step of
    dh/dt = (tanh(W_hh h + W_xh x + b) - h) / tau):

        h_{t+1,i} = h_{t,i} + dt / tau_i * ( tanh( W_hh[i]·h_t + W_xh[i]·x_t + b[i] ) - h_{t,i} )

    where tau_i = softplus(tau_raw_i) + tau_min is a learnable, strictly
    positive per-neuron time constant. The result is clamped elementwise to
    [-state_bound, state_bound].

    Args:
        hidden_size: Number of hidden neurons
        input_size: Size of input vector
        dt: Time step for discrete approximation (default: 0.1)
        tau_min: Floor added to softplus(tau_raw) so tau never reaches zero
            (default: 1e-3). Keyword-only; must be > 0.
        state_bound: Symmetric clamp applied to the next hidden state
            (default: 5.0). Keyword-only; must be > 0. Pass float("inf")
            to disable clamping.

    Raises:
        ValueError: If tau_min or state_bound is not strictly positive.

    Note:
        For the leak term alone, the explicit Euler step overshoots when
        dt / tau_i > 1 and oscillates with growing amplitude when
        dt / tau_i > 2. With the default tau_min, dt / tau_i can reach about
        dt / 1e-3, so the clamp is what keeps the state bounded in that
        regime. Clamped elements receive zero gradient from torch.clamp.
    """

    def __init__(
        self,
        hidden_size: int,
        input_size: int,
        dt: float = 0.1,
        *,
        tau_min: float = 1e-3,
        state_bound: float = 5.0,
    ):
        super().__init__()
        if tau_min <= 0:
            raise ValueError(f"tau_min must be > 0, got {tau_min}")
        if state_bound <= 0:
            raise ValueError(f"state_bound must be > 0, got {state_bound}")
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.dt = dt
        self.tau_min = tau_min
        self.state_bound = state_bound
        # Parameters are created in a fixed order (W_hh, W_xh, b, tau_raw) so
        # RNG consumption and state_dict keys stay stable across versions.
        # Recurrent weight matrix: (hidden_size, hidden_size)
        self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)
        # Input weight matrix: (hidden_size, input_size)
        self.W_xh = nn.Parameter(torch.randn(hidden_size, input_size) * 0.1)
        # Bias vector: (hidden_size,)
        self.b = nn.Parameter(torch.zeros(hidden_size))
        # Unconstrained time constants, shape (hidden_size,); forward() maps
        # them to softplus(tau_raw) + tau_min > 0. Init 1.0 gives tau ≈ 1.31.
        self.tau_raw = nn.Parameter(torch.ones(hidden_size))

    def forward(self, h: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Advance the hidden state by one time step.

        Args:
            h: Hidden state of shape (batch, hidden_size), or (hidden_size,)
                for a single unbatched sample.
            x: Input of shape (batch, input_size), or (input_size,) when h
                is unbatched.

        Returns:
            Next hidden state: (batch, hidden_size) for batched inputs,
            (hidden_size,) for unbatched inputs.
        """
        # Strictly positive per-neuron time constants, shape (hidden_size,).
        tau = F.softplus(self.tau_raw) + self.tau_min
        # h @ W^T is the row-vector form of W · h. It applies the per-neuron
        # sums from the class docstring to every row of h at once and also
        # works for a 1-D h.
        h_proj = torch.matmul(h, self.W_hh.t())
        x_proj = torch.matmul(x, self.W_xh.t())
        preact = torch.tanh(h_proj + x_proj + self.b)
        # tau broadcasts against the trailing (hidden) dimension. Adding a
        # leading axis here would turn an unbatched result into
        # (1, hidden_size).
        h_next = h + self.dt * (preact - h) / tau
        # Bound the state; see the class Note on Euler stability.
        return torch.clamp(h_next, -self.state_bound, self.state_bound)

    def extra_repr(self) -> str:
        """Summarize the configuration shown by repr(module)."""
        return (
            f"hidden_size={self.hidden_size}, input_size={self.input_size}, "
            f"dt={self.dt}, tau_min={self.tau_min}, state_bound={self.state_bound}"
        )
|