import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # Used for GLU if not in modules

# Local imports (assumed to provide MiniGridBackbone and ClassicControlBackbone)
from models.modules import *
from models.utils import *  # Assuming compute_decay, compute_normalized_entropy are here
class LSTMBaseline(nn.Module):
    """
    LSTM Baseline

    Args:
        iterations (int): Number of internal 'thought' steps (T, in the paper).
        d_model (int): Dimensionality of the LSTM hidden and cell states (D, in the paper).
        d_input (int): Dimensionality of the backbone's output features.
        backbone_type (str): Feature-extraction backbone; one of
            'navigation-backbone' or 'classic-control-backbone'.
    """
    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 backbone_type,
                 ):
        super().__init__()

        # --- Core Parameters ---
        self.iterations = iterations
        self.d_model = d_model
        self.backbone_type = backbone_type

        # --- Input Assertions ---
        assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
            f"Invalid backbone_type: {backbone_type}"
        # --- Backbone / Feature Extraction ---
        if self.backbone_type == 'navigation-backbone':
            grid_size = 7
            self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
            lstm_cell_input_dim = grid_size * grid_size * d_input
        elif self.backbone_type == 'classic-control-backbone':
            self.backbone = ClassicControlBackbone(d_input=d_input)
            lstm_cell_input_dim = d_input
        else:
            raise NotImplementedError('The only backbones supported for RL are navigation '
                                      '(symbolic C x H x W inputs) and classic control '
                                      '(vectors of length D).')
        # --- Core LSTM Modules ---
        self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)

        # Learned initial hidden/cell states, initialized uniformly in
        # [-1/sqrt(d_model), 1/sqrt(d_model)], matching nn.LSTMCell's weight init.
        bound = math.sqrt(1 / d_model)
        self.register_parameter('start_hidden_state',
                                nn.Parameter(torch.zeros(d_model).uniform_(-bound, bound), requires_grad=True))
        self.register_parameter('start_cell_state',
                                nn.Parameter(torch.zeros(d_model).uniform_(-bound, bound), requires_grad=True))
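
    # Illustrative helper (not part of the original module): forward() expects
    # the caller to supply (hidden_state, cell_state), so the learned start
    # parameters above are assumed to be expanded to a batch by the caller.
    # A minimal sketch of that expansion:
    def init_hidden_states(self, batch_size):
        """Expand the learned start states to a (batch_size, d_model) tuple."""
        h0 = self.start_hidden_state.unsqueeze(0).expand(batch_size, -1)
        c0 = self.start_cell_state.unsqueeze(0).expand(batch_size, -1)
        return (h0, c0)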
    def compute_features(self, x):
        """Applies the feature-extraction backbone to the raw input."""
        return self.backbone(x)
    def forward(self, x, hidden_states, track=False):
        """
        Forward pass: featurise the input once, then run the LSTM cell for
        T = iterations internal steps over the same features.

        Args:
            x: Raw input (symbolic grid or state vector, depending on the backbone).
            hidden_states: Tuple of (hidden_state, cell_state), each of shape (B, d_model).
            track (bool): If True, also return per-step hidden-state activations.
        """
        # --- Tracking Initialization ---
        activations_tracking = []

        # --- Featurise Input Data ---
        features = self.compute_features(x)
        # Flatten once; the features are constant across internal steps.
        lstm_input = features.reshape(x.size(0), -1)

        hidden_state, cell_state = hidden_states

        # --- Recurrent Loop ---
        for stepi in range(self.iterations):
            hidden_state, cell_state = self.lstm_cell(lstm_input, (hidden_state, cell_state))

            # --- Tracking ---
            if track:
                activations_tracking.append(hidden_state.detach().cpu().numpy())
        hidden_states = (
            hidden_state,
            cell_state
        )

        # --- Return Values ---
        if track:
            # The activations array is returned twice to preserve the existing
            # tracked-return signature (both slots carry the same trace here).
            return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking)
        return hidden_state, hidden_states
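

# Minimal smoke test (illustrative sketch, not part of the original module).
# Assumes ClassicControlBackbone maps flat state vectors of shape (B, obs_dim)
# to (B, d_input); the obs_dim of 4 below is a CartPole-style assumption --
# adjust it to whatever your backbone actually expects.
if __name__ == "__main__":
    model = LSTMBaseline(iterations=3, d_model=64, d_input=16,
                         backbone_type='classic-control-backbone')
    x = torch.randn(2, 4)  # (batch, obs_dim); assumed input shape
    h0, c0 = model.init_hidden_states(batch_size=2)  # helper sketched above
    out, (h, c) = model(x, (h0, c0))
    print(out.shape)  # expected: torch.Size([2, 64])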