Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
class GatingNetwork(nn.Module):
    """Attention-based gating network mapping (context, latent state) to
    per-expert mixture weights.

    A noisy observation-space projection of the latent state is compared
    against each context frame (L1 distance -> softmax over time); the
    resulting attention weights pool temporally convolved context features
    into a single embedding, which — concatenated with the latent state —
    is passed through a two-layer MLP and softmax-normalized over experts.
    """

    def __init__(self, N, M, Experts, dtype=torch.float32):
        """
        Args:
            N: Observation-space dimension.
            M: Latent-state dimension (eye init below assumes M >= N — TODO confirm).
            Experts: Number of experts to produce weights for.
            dtype: Parameter dtype.
        """
        super().__init__()
        # Temporal conv along the sequence axis; kernel_size=2 with no
        # padding shortens the sequence by exactly one frame.
        self.conv = nn.Conv1d(N, N, kernel_size=2, padding=0, bias=True, dtype=dtype)
        # Learnable temperature for the attention softmax (used via abs()).
        self.softmax_temp1 = nn.Parameter(torch.tensor([0.1], dtype=dtype))
        # Projection from latent (M) to observation (N) space, initialized
        # to identity on the first N latent dimensions.
        self.D = nn.Parameter(torch.zeros(N, M, dtype=dtype))
        self.D.data[:, :N] = torch.eye(N, dtype=dtype)
        self.mlp_layer1 = nn.Linear(M + N, Experts, dtype=dtype)
        self.mlp_layer2 = nn.Linear(Experts, Experts, dtype=dtype)
        # Learnable temperature for the expert-weight softmax (used via abs()).
        self.softmax_temp2 = nn.Parameter(torch.tensor([0.1], dtype=dtype))
        # Per-dimension scale of the noise added to the projected state.
        self.sigma = nn.Parameter(torch.ones(N, dtype=dtype) * 0.05, requires_grad=True)

    def forward(self, context, z, precomputed_cnn=None):
        # context: (seq_length, batch_size, N)
        # z: (M, batch_size)
        # precomputed_cnn: Optional precomputed CNN features for inference (seq_length-1, batch_size, N)
        seq_length, batch_size, N = context.shape
        M = z.shape[0]
        # Compute attention weights
        # Project latent state into observation space; detached so the
        # attention path does not backpropagate into z.
        z_obs = self.D @ z.detach()
        # Perturb with learnable per-dimension noise. NOTE(review): this is
        # stochastic at eval time too — confirm that is intended.
        z_current = z_obs + self.sigma.unsqueeze(1) * torch.randn(N, batch_size, dtype=z.dtype, device=z.device)
        z_current_t = z_current.transpose(0, 1)  # (batch_size, N)
        # Drop the last frame so lengths match the conv output (seq_length-1).
        context_frames = context[:-1]
        # L1 distance of each context frame to the (noisy) current state.
        distances = torch.sum(torch.abs(context_frames - z_current_t.unsqueeze(0)), dim=2)
        # Softmax over the time axis: closer frames get larger weight;
        # abs() keeps the learnable temperature positive.
        attention_weights = F.softmax(-distances / torch.abs(self.softmax_temp1[0]), dim=0)
        # Process context with convolution
        # Use precomputed CNN features if provided, otherwise compute them
        if precomputed_cnn is not None:
            encoded = precomputed_cnn
        else:
            context_for_conv = context.permute(1, 2, 0)  # (batch, N, seq) for Conv1d
            encoded = self.conv(context_for_conv)
            encoded = encoded.permute(2, 0, 1)  # back to (seq_length-1, batch, N)
        # Build weighted embedding
        weighted_encoded = encoded * attention_weights.unsqueeze(2)
        embedding = torch.sum(weighted_encoded, dim=0)  # attention-pool over time -> (batch, N)
        embedding = embedding.transpose(0, 1)  # (N, batch_size)
        # Predict expert weights
        combined = torch.cat([embedding, z], dim=0)  # (N + M, batch_size)
        combined_t = combined.transpose(0, 1)
        mlp_output = self.mlp_layer2(F.relu(self.mlp_layer1(combined_t)))
        # Softmax over experts; the negative sign means smaller MLP outputs
        # receive larger mixture weights.
        w_exp = F.softmax(-mlp_output.transpose(0, 1) / torch.abs(self.softmax_temp2[0]), dim=0)
        return w_exp

    def gaussian_init(self, M, N, dtype=torch.float32):
        # Small-variance Gaussian initializer; not used within this class.
        return torch.randn(M, N, dtype=dtype) * 0.01
class ExpertNetwork(nn.Module):
    """Common base class for expert dynamics architectures.

    Stores the latent dimension, an architecture-specific size parameter,
    and — when ``probabilistic`` — a learnable scalar noise scale shared
    by ``add_noise``.
    """

    def __init__(self, M, P=0, probabilistic=False, dtype=torch.float32):
        super().__init__()
        self.M = M
        self.P = P
        self.probabilistic = probabilistic
        self.dtype = dtype
        if probabilistic:
            # Scalar noise scale, learned jointly with the dynamics.
            self.sigma = nn.Parameter(torch.ones(1, dtype=dtype) * 0.05, requires_grad=True)

    def forward(self, z):
        raise NotImplementedError("Subclasses must implement forward method")

    def add_noise(self, z):
        """Add stochasticity to the latent state if in probabilistic mode.

        Args:
            z: Latent-state tensor of shape (M, batch_size).
        """
        if not self.probabilistic:
            return z
        eps = torch.randn(self.M, z.shape[1], dtype=z.dtype, device=z.device)
        return z + self.sigma * eps

    def gaussian_init(self, M, N):
        """Small-variance Gaussian initializer of shape (M, N)."""
        return 0.01 * torch.randn(M, N, dtype=self.dtype)

    def normalized_positive_definite(self, M):
        """Random symmetric positive-definite matrix rescaled so its
        largest-magnitude eigenvalue is 1 (numpy, float32)."""
        R = np.random.randn(M, M).astype(np.float32)
        K = R.T @ R / M + np.eye(M)
        spectral_radius = np.max(np.abs(np.linalg.eigvals(K)))
        return K / spectral_radius
class AlmostLinearRNN(ExpertNetwork):
    """Almost linear RNN expert: z' = A ⊙ z + W @ phi(z) + h, where phi
    applies ReLU only to the last P latent dimensions."""

    def __init__(self, M, P, probabilistic=False, dtype=torch.float32):
        """
        Args:
            M: Latent-state dimension.
            P: Number of trailing dimensions passed through ReLU (>= 0).
            probabilistic: Whether to add learnable Gaussian noise.
            dtype: Parameter dtype.
        """
        super().__init__(M, P, probabilistic, dtype=dtype)
        self.A, self.W, self.h = self.initialize_A_W_h(M)

    def forward(self, z):
        # z: (M, batch_size)
        if self.P > 0:
            # Keep the first M-P dims linear; ReLU the last P dims.
            z1 = z[:-self.P, :]
            z2 = F.relu(z[-self.P:, :])
            zcat = torch.cat([z1, z2], dim=0)
        else:
            # P == 0: no ReLU dimensions, so the update is fully linear.
            # (The original slicing z[:-0]/z[-0:] would wrongly apply ReLU
            # to every dimension in this case.)
            zcat = z
        # Diagonal self-coupling + full coupling on zcat + bias.
        output = self.A.unsqueeze(-1) * z + self.W @ zcat + self.h.unsqueeze(-1)
        # Add stochasticity if probabilistic
        if self.probabilistic:
            output = self.add_noise(output)
        return output

    def initialize_A_W_h(self, M):
        """Create the diagonal (A), coupling (W), and bias (h) parameters.

        A takes the diagonal of a normalized positive-definite matrix
        (torch.diag on a 2-D input extracts the diagonal vector).
        """
        A = torch.nn.Parameter(torch.diag(torch.tensor(self.normalized_positive_definite(M), dtype=self.dtype)))
        W = torch.nn.Parameter(self.gaussian_init(M, M))
        h = torch.nn.Parameter(torch.zeros(M, dtype=self.dtype))
        return A, W, h
class ClippedShallowPLRNN(ExpertNetwork):
    """Clipped shallow PLRNN expert architecture."""

    def __init__(self, M, hidden_dim=50, probabilistic=False, dtype=torch.float32):
        super().__init__(M, hidden_dim, probabilistic, dtype=dtype)
        # Diagonal self-coupling: torch.diag on a 2-D matrix extracts its
        # diagonal as a length-M vector.
        npd = torch.tensor(self.normalized_positive_definite(M), dtype=self.dtype)
        self.A = nn.Parameter(torch.diag(npd))
        self.W1 = nn.Parameter(self.gaussian_init(M, hidden_dim))
        self.W2 = nn.Parameter(self.gaussian_init(hidden_dim, M))
        self.h1 = nn.Parameter(torch.zeros(M, dtype=self.dtype))
        self.h2 = nn.Parameter(torch.zeros(hidden_dim, dtype=self.dtype))

    def forward(self, z):
        """One clipped shallow PLRNN update of the latent state.

        Args:
            z: Latent state of shape (M, batch_size).
        """
        pre = self.W2 @ z
        # "Clipped" nonlinearity: ReLU(pre + h2) - ReLU(pre).
        clipped = F.relu(pre + self.h2.unsqueeze(-1)) - F.relu(pre)
        new_z = self.A.unsqueeze(-1) * z + self.W1 @ clipped + self.h1.unsqueeze(-1)
        # Add stochasticity only in probabilistic mode.
        return self.add_noise(new_z) if self.probabilistic else new_z
class DynaMix(nn.Module):
    """Mixture-of-experts dynamics model: a gating network blends the
    one-step latent predictions of several expert networks."""

    def __init__(self, M, N, Experts, P=2, hidden_dim=50, expert_type="almost_linear_rnn",
                 probabilistic_expert=False, dtype=torch.float32):
        """
        Initialize a DynaMix model.

        Args:
            M: Dimension of latent state
            N: Dimension of observation space
            Experts: Number of experts
            P: Number of ReLU dimensions
            hidden_dim: Hidden dimension for clipped shallow PLRNN
            expert_type: Type of expert to use ("almost_linear_rnn" or "clipped_shallow_plrnn")
            probabilistic_expert: Whether to use probabilistic experts
            dtype: Data type for model parameters (default: torch.float32)

        Raises:
            ValueError: If expert_type is not one of the supported names.
        """
        super().__init__()
        self.expert_type = expert_type
        self.probabilistic_expert = probabilistic_expert
        self.dtype = dtype
        self.experts = nn.ModuleList()
        for _ in range(Experts):
            if expert_type == "almost_linear_rnn":
                expert = AlmostLinearRNN(M, P, probabilistic=probabilistic_expert, dtype=dtype)
            elif expert_type == "clipped_shallow_plrnn":
                expert = ClippedShallowPLRNN(M, hidden_dim, probabilistic=probabilistic_expert, dtype=dtype)
            else:
                raise ValueError(f"Unknown expert type: {expert_type}")
            self.experts.append(expert)
        self.gating_network = GatingNetwork(N, M, Experts, dtype=dtype)
        # Observation matrix mapping latent (M) to observation (N) space.
        self.B = nn.Parameter(self.uniform_init((N, M), dtype=dtype))
        self.M = M
        self.N = N
        self.Experts = Experts
        self.P = P
        self.hidden_dim = hidden_dim

    def step(self, z, context, precomputed_cnn=None):
        """One latent-state update: weight each expert's prediction with the
        gating network's output and sum.

        Args:
            z: Latent state of shape (M, batch_size).
            context: Context data of shape (seq_length, batch_size, N).
            precomputed_cnn: Optional precomputed CNN features.
        """
        # (Experts, batch_size) mixture weights.
        w_exp = self.gating_network(context, z, precomputed_cnn=precomputed_cnn)
        weighted = [expert(z) * w_exp[idx, :].unsqueeze(0)
                    for idx, expert in enumerate(self.experts)]
        # Stack-then-sum keeps the reduction order explicit.
        return torch.sum(torch.stack(weighted, dim=0), dim=0)

    def forward(self, z, context, precomputed_cnn=None):
        """
        Forward pass through the DynaMix model (delegates to step).

        Args:
            z: Latent state of shape (M, batch_size)
            context: Context data of shape (seq_length, batch_size, N)
            precomputed_cnn: Optional precomputed CNN features to avoid redundant computation for inference
                             Shape should be (seq_length-1, batch_size, N)
        Returns:
            Updated latent state
        """
        return self.step(z, context, precomputed_cnn=precomputed_cnn)

    def precompute_cnn(self, context):
        """
        Precompute CNN features for more efficient inference.

        Args:
            context: Context data of shape (seq_length, batch_size, N)
        Returns:
            Precomputed CNN features of shape (seq_length-1, batch_size, N)
        """
        # Conv1d expects (batch, channels, seq); restore (seq-1, batch, N) after.
        permuted = context.permute(1, 2, 0)
        return self.gating_network.conv(permuted).permute(2, 0, 1)

    def uniform_init(self, shape, dtype=torch.float32):
        """Uniform init in [-1/sqrt(din), 1/sqrt(din)], din = shape[-1]."""
        din = shape[-1]
        r = 1 / np.sqrt(din)
        return (torch.rand(shape, dtype=dtype) * 2 - 1) * r

    def gaussian_init(self, M, N):
        """Small-variance Gaussian initializer of shape (M, N)."""
        return torch.randn(M, N, dtype=self.dtype) * 0.01
def print_model_parameters(model):
    """Print simplified breakdown of model parameters by component."""
    divider = "-" * 60
    total_params = sum(p.numel() for p in model.parameters())

    print("\n" + divider)
    print("Model Parameter Summary:")
    print(f" Architecture: DynaMix with {model.expert_type} experts")
    if model.expert_type == "almost_linear_rnn":
        dims = f" Dimensions: M={model.M}, N={model.N}, Experts={model.Experts}, P={model.P}"
    else:
        dims = f" Dimensions: M={model.M}, N={model.N}, Experts={model.Experts}, Hidden dim={model.hidden_dim}"
    print(dims)
    print(f" Probabilistic experts: {model.probabilistic_expert}")

    # Per-component parameter counts.
    gating_params = sum(p.numel() for p in model.gating_network.parameters())
    expert_params = sum(p.numel() for expert in model.experts for p in expert.parameters())
    b_params = model.B.numel()

    print("\nParameter counts:")
    print(f" Gating Network: {gating_params:,} ({gating_params/total_params:.1%})")
    print(f" Experts: {expert_params:,} ({expert_params/total_params:.1%})")
    print(f" Observation matrix: {b_params:,} ({b_params/total_params:.1%})")
    print(f" Total: {total_params:,} parameters")
    print(divider)