# DynaMix/dynamix/dynamix.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GatingNetwork(nn.Module):
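    """Gating network that produces mixture weights over the experts.

    The observable part of the latent state (via projection D, initialized
    as [I_N | 0], so M >= N is assumed) is compared against the context
    frames to form attention weights; the context is encoded with a 1D
    convolution, and the attention-weighted embedding, concatenated with z,
    is passed through a two-layer MLP whose output is softmax-normalized
    into expert weights of shape (Experts, batch_size).
    """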
def __init__(self, N, M, Experts, dtype=torch.float32):
super().__init__()
self.conv = nn.Conv1d(N, N, kernel_size=2, padding=0, bias=True, dtype=dtype)
self.softmax_temp1 = nn.Parameter(torch.tensor([0.1], dtype=dtype))
self.D = nn.Parameter(torch.zeros(N, M, dtype=dtype))
self.D.data[:, :N] = torch.eye(N, dtype=dtype)
self.mlp_layer1 = nn.Linear(M + N, Experts, dtype=dtype)
self.mlp_layer2 = nn.Linear(Experts, Experts, dtype=dtype)
self.softmax_temp2 = nn.Parameter(torch.tensor([0.1], dtype=dtype))
self.sigma = nn.Parameter(torch.ones(N, dtype=dtype) * 0.05, requires_grad=True)
def forward(self, context, z, precomputed_cnn=None):
# context: (seq_length, batch_size, N)
# z: (M, batch_size)
# precomputed_cnn: Optional precomputed CNN features for inference (seq_length-1, batch_size, N)
seq_length, batch_size, N = context.shape
M = z.shape[0]
# Compute attention weights
z_obs = self.D @ z.detach()
z_current = z_obs + self.sigma.unsqueeze(1) * torch.randn(N, batch_size, dtype=z.dtype, device=z.device)
z_current_t = z_current.transpose(0, 1)
        # Drop the last frame so the attention length matches the conv output
        # (kernel_size=2, padding=0 shortens the sequence by one)
        context_frames = context[:-1]
distances = torch.sum(torch.abs(context_frames - z_current_t.unsqueeze(0)), dim=2)
attention_weights = F.softmax(-distances / torch.abs(self.softmax_temp1[0]), dim=0)
# Process context with convolution
# Use precomputed CNN features if provided, otherwise compute them
if precomputed_cnn is not None:
encoded = precomputed_cnn
else:
context_for_conv = context.permute(1, 2, 0)
encoded = self.conv(context_for_conv)
encoded = encoded.permute(2, 0, 1)
# Build weighted embedding
weighted_encoded = encoded * attention_weights.unsqueeze(2)
embedding = torch.sum(weighted_encoded, dim=0)
embedding = embedding.transpose(0, 1)
# Predict expert weights
combined = torch.cat([embedding, z], dim=0)
combined_t = combined.transpose(0, 1)
mlp_output = self.mlp_layer2(F.relu(self.mlp_layer1(combined_t)))
w_exp = F.softmax(-mlp_output.transpose(0, 1) / torch.abs(self.softmax_temp2[0]), dim=0)
return w_exp
def gaussian_init(self, M, N, dtype=torch.float32):
return torch.randn(M, N, dtype=dtype) * 0.01
class ExpertNetwork(nn.Module):
"""Base class for different expert architectures."""
def __init__(self, M, P=0, probabilistic=False, dtype=torch.float32):
super().__init__()
self.M = M
self.P = P
self.probabilistic = probabilistic
self.dtype = dtype
# Parameter for probabilistic experts
if probabilistic:
self.sigma = nn.Parameter(torch.ones(1, dtype=dtype) * 0.05, requires_grad=True)
def forward(self, z):
raise NotImplementedError("Subclasses must implement forward method")
def add_noise(self, z):
"""Add stochasticity to the latent state if in probabilistic mode.
Args:
z: Input tensor
"""
if self.probabilistic:
batch_size = z.shape[1]
noise = torch.randn(self.M, batch_size, dtype=z.dtype, device=z.device)
return z + self.sigma * noise
return z
def gaussian_init(self, M, N):
return torch.randn(M, N, dtype=self.dtype) * 0.01
def normalized_positive_definite(self, M):
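        """Sample a random positive-definite matrix and rescale it so its
        largest eigenvalue has magnitude one; the diagonal of the result is
        used to initialize the A parameter of the experts."""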
R = np.random.randn(M, M).astype(np.float32)
K = R.T @ R / M + np.eye(M)
lambd = np.max(np.abs(np.linalg.eigvals(K)))
return K / lambd
class AlmostLinearRNN(ExpertNetwork):
"""Almost linear RNN expert architecture."""
def __init__(self, M, P, probabilistic=False, dtype=torch.float32):
super().__init__(M, P, probabilistic, dtype=dtype)
self.A, self.W, self.h = self.initialize_A_W_h(M)
def forward(self, z):
# z: (M, batch_size)
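        # Update rule: z' = A * z + W @ [z_{1:M-P}; ReLU(z_{M-P+1:M})] + h,
        # i.e. a diagonal linear map plus a coupling whose last P input
        # dimensions are rectified ("almost linear")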
# Split z into regular and ReLU parts
z1 = z[:-self.P, :]
z2 = F.relu(z[-self.P:, :])
zcat = torch.cat([z1, z2], dim=0)
output = self.A.unsqueeze(-1) * z + self.W @ zcat + self.h.unsqueeze(-1)
# Add stochasticity if probabilistic
if self.probabilistic:
output = self.add_noise(output)
return output
def initialize_A_W_h(self, M):
A = torch.nn.Parameter(torch.diag(torch.tensor(self.normalized_positive_definite(M), dtype=self.dtype)))
W = torch.nn.Parameter(self.gaussian_init(M, M))
h = torch.nn.Parameter(torch.zeros(M, dtype=self.dtype))
return A, W, h
class ClippedShallowPLRNN(ExpertNetwork):
"""Clipped shallow PLRNN expert architecture."""
def __init__(self, M, hidden_dim=50, probabilistic=False, dtype=torch.float32):
super().__init__(M, hidden_dim, probabilistic, dtype=dtype)
self.A = torch.nn.Parameter(torch.diag(torch.tensor(self.normalized_positive_definite(M), dtype=self.dtype)))
self.W1 = torch.nn.Parameter(self.gaussian_init(M, hidden_dim))
self.W2 = torch.nn.Parameter(self.gaussian_init(hidden_dim, M))
self.h1 = torch.nn.Parameter(torch.zeros(M, dtype=self.dtype))
self.h2 = torch.nn.Parameter(torch.zeros(hidden_dim, dtype=self.dtype))
def forward(self, z):
# z: (M, batch_size)
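        # Clipped shallow PLRNN update:
        # z' = A * z + W1 @ (ReLU(W2 @ z + h2) - ReLU(W2 @ z)) + h1
        # The ReLU difference is elementwise bounded by |h2|, which keeps
        # the nonlinearity from blowing up for large |z|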
W2z = self.W2 @ z
output = (self.A.unsqueeze(-1) * z +
self.W1 @ (F.relu(W2z + self.h2.unsqueeze(-1)) - F.relu(W2z)) +
self.h1.unsqueeze(-1))
# Add stochasticity if probabilistic
if self.probabilistic:
output = self.add_noise(output)
return output
class DynaMix(nn.Module):
def __init__(self, M, N, Experts, P=2, hidden_dim=50, expert_type="almost_linear_rnn",
probabilistic_expert=False, dtype=torch.float32):
"""
Initialize a DynaMix model.
Args:
M: Dimension of latent state
N: Dimension of observation space
Experts: Number of experts
P: Number of ReLU dimensions
hidden_dim: Hidden dimension for clipped shallow PLRNN
expert_type: Type of expert to use ("almost_linear_rnn" or "clipped_shallow_plrnn")
probabilistic_expert: Whether to use probabilistic experts
dtype: Data type for model parameters (default: torch.float32)
"""
super().__init__()
self.expert_type = expert_type
self.probabilistic_expert = probabilistic_expert
self.experts = nn.ModuleList()
self.dtype = dtype
for _ in range(Experts):
if expert_type == "almost_linear_rnn":
self.experts.append(AlmostLinearRNN(M, P, probabilistic=probabilistic_expert, dtype=dtype))
elif expert_type == "clipped_shallow_plrnn":
self.experts.append(ClippedShallowPLRNN(M, hidden_dim, probabilistic=probabilistic_expert, dtype=dtype))
else:
raise ValueError(f"Unknown expert type: {expert_type}")
self.gating_network = GatingNetwork(N, M, Experts, dtype=dtype)
self.B = nn.Parameter(self.uniform_init((N, M), dtype=dtype))
self.N = N
self.Experts = Experts
self.P = P
self.hidden_dim = hidden_dim
self.M = M
def step(self, z, context, precomputed_cnn=None):
# z: (M, batch_size)
# context: (seq_length, batch_size, N)
# precomputed_cnn: Optional precomputed CNN features
# Compute expert weights
w_exp = self.gating_network(context, z, precomputed_cnn=precomputed_cnn) # (Experts, batch_size)
results = []
# Compute expert outputs
for i in range(self.Experts):
expert_output = self.experts[i](z)
results.append(expert_output * w_exp[i, :].unsqueeze(0))
# Combine expert outputs
return torch.sum(torch.stack(results, dim=0), dim=0)
def forward(self, z, context, precomputed_cnn=None):
"""
Forward pass through the DynaMix model.
Args:
z: Latent state of shape (M, batch_size)
context: Context data of shape (seq_length, batch_size, N)
precomputed_cnn: Optional precomputed CNN features to avoid redundant computation for inference
Shape should be (seq_length-1, batch_size, N)
Returns:
Updated latent state
"""
return self.step(z, context, precomputed_cnn=precomputed_cnn)
def precompute_cnn(self, context):
"""
Precompute CNN features for more efficient inference.
Args:
context: Context data of shape (seq_length, batch_size, N)
Returns:
Precomputed CNN features of shape (seq_length-1, batch_size, N)
"""
# Process context with convolution
context_for_conv = context.permute(1, 2, 0)
encoded = self.gating_network.conv(context_for_conv)
return encoded.permute(2, 0, 1)
def uniform_init(self, shape, dtype=torch.float32):
din = shape[-1]
r = 1 / np.sqrt(din)
return (torch.rand(shape, dtype=dtype) * 2 - 1) * r
def gaussian_init(self, M, N):
return torch.randn(M, N, dtype=self.dtype) * 0.01
def print_model_parameters(model):
"""Print simplified breakdown of model parameters by component."""
total_params = sum(p.numel() for p in model.parameters())
print("\n" + "-"*60)
print("Model Parameter Summary:")
print(f" Architecture: DynaMix with {model.expert_type} experts")
if model.expert_type == "almost_linear_rnn":
print(f" Dimensions: M={model.M}, N={model.N}, Experts={model.Experts}, P={model.P}")
else:
print(f" Dimensions: M={model.M}, N={model.N}, Experts={model.Experts}, Hidden dim={model.hidden_dim}")
print(f" Probabilistic experts: {model.probabilistic_expert}")
# Count parameters
gating_params = sum(p.numel() for p in model.gating_network.parameters())
expert_params = sum(p.numel() for expert in model.experts for p in expert.parameters())
b_params = model.B.numel()
# Print parameter counts
print(f"\nParameter counts:")
print(f" Gating Network: {gating_params:,} ({gating_params/total_params:.1%})")
print(f" Experts: {expert_params:,} ({expert_params/total_params:.1%})")
print(f" Observation matrix: {b_params:,} ({b_params/total_params:.1%})")
print(f" Total: {total_params:,} parameters")
print("-"*60)