| |
| import torch |
|
|
|
|
class RNN(torch.nn.Module):
    """Single-layer, bias-free RNN that maps velocities to a place cell code.

    The initial hidden state is a linear encoding of the starting place cell
    activations; the hidden states ("grid cell activations") are linearly
    decoded back into place cell logits.

    Layout note: ``torch.nn.RNN`` is constructed with its default
    ``batch_first=False``, so every sequence tensor handled by this class is
    time-major, i.e. ``[sequence_length, batch_size, ...]``.
    """

    def __init__(self, options, place_cells):
        """
        Build the encoder, recurrent core and decoder.
        Args:
            options: Object exposing ``Ng`` (hidden units), ``Np`` (place
                cells), ``sequence_length``, ``weight_decay`` and
                ``activation`` (forwarded to ``torch.nn.RNN`` as its
                ``nonlinearity``).
            place_cells: Object providing ``get_nearest_cell_pos(preds)``;
                only used by ``compute_loss``.
        """
        super().__init__()
        self.Ng = options.Ng
        self.Np = options.Np
        self.sequence_length = options.sequence_length
        self.weight_decay = options.weight_decay
        self.place_cells = place_cells

        # Input weights: initial place cell code -> initial hidden state.
        self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
        # Recurrent core driven by 2d velocity (time-major: batch_first is
        # left at its default of False).
        self.RNN = torch.nn.RNN(
            input_size=2,
            hidden_size=self.Ng,
            nonlinearity=options.activation,
            bias=False,
        )
        # Linear read-out weights: hidden state -> place cell logits.
        self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)

        # Not used by the methods of this class; kept as a public attribute.
        self.softmax = torch.nn.Softmax(dim=-1)

    def g(self, inputs):
        """
        Compute grid cell activations.
        Args:
            inputs: Pair ``(v, p0)`` where ``v`` holds the 2d velocities with
                shape [sequence_length, batch_size, 2] and ``p0`` holds the
                initial place cell activations with shape [batch_size, Np].

        Returns:
            g: Grid cell activations (RNN hidden states at every step) with
                shape [sequence_length, batch_size, Ng].
        """
        v, p0 = inputs
        # [None] adds the leading num_layers axis that torch.nn.RNN expects
        # on h_0: [batch_size, Ng] -> [1, batch_size, Ng].
        init_state = self.encoder(p0)[None]
        g, _ = self.RNN(v, init_state)
        return g

    def predict(self, inputs):
        """
        Predict place cell code.
        Args:
            inputs: Pair ``(v, p0)``; see ``g`` for the expected shapes.

        Returns:
            place_preds: Predicted place cell logits (no softmax applied)
                with shape [sequence_length, batch_size, Np].
        """
        return self.decoder(self.g(inputs))

    def set_weights(self, weights):
        """
        Load weights from numpy arrays (e.g. from the provided example weights).
        Assumes weights are in the order: [encoder, rnn_ih, rnn_hh, decoder]
        and transposed (TF/Keras format), i.e. each entry has the reversed
        shape of the parameter it is loaded into. Entries may be numpy arrays
        or torch tensors; entries past the fourth are ignored.

        Raises:
            RuntimeError: If any entry does not have the expected shape. All
                four entries are checked before anything is written, so the
                model is never left partially updated.
        """
        targets = (
            ("encoder", self.encoder.weight),
            ("rnn_ih", self.RNN.weight_ih_l0),
            ("rnn_hh", self.RNN.weight_hh_l0),
            ("decoder", self.decoder.weight),
        )

        staged = []
        for index, (label, param) in enumerate(targets):
            source = torch.as_tensor(weights[index])
            expected = tuple(param.shape)[::-1]
            # Explicit check: Tensor.copy_ broadcasts its source, so a wrongly
            # shaped (but broadcastable) array would otherwise load silently.
            if tuple(source.shape) != expected:
                raise RuntimeError(
                    f"size mismatch for {label}: expected source shape "
                    f"{expected}, got {tuple(source.shape)}"
                )
            staged.append((param, source.t().float()))

        with torch.no_grad():
            for param, source in staged:
                param.copy_(source)

    def compute_loss(self, inputs, pc_outputs, pos):
        """
        Compute avg. loss and decoding error.
        Args:
            inputs: Pair ``(v, p0)``; see ``g`` for the expected shapes.
            pc_outputs: Ground truth place cell activations with the same
                shape as the predictions, [sequence_length, batch_size, Np].
                They are passed to ``cross_entropy`` as targets, which treats
                floating point targets of this shape as class probabilities.
            pos: Ground truth 2d position; must be broadcastable against the
                output of ``place_cells.get_nearest_cell_pos(preds)``.

        Returns:
            loss: Cross-entropy averaged over all time steps and batch
                entries, plus ``weight_decay`` times the sum of squared
                recurrent (hidden-to-hidden) weights.
            err: Mean Euclidean distance between ``pos`` and the decoded
                position, in the same units as ``pos``.
        """
        preds = self.predict(inputs)

        # Merge the time and batch axes so that every step is one sample.
        loss = torch.nn.functional.cross_entropy(
            preds.flatten(0, 1), pc_outputs.flatten(0, 1)
        )

        # Weight regularization on the recurrent matrix only.
        loss = loss + self.weight_decay * (self.RNN.weight_hh_l0 ** 2).sum()

        # Compute decoding error.
        pred_pos = self.place_cells.get_nearest_cell_pos(preds)
        err = torch.sqrt(((pos - pred_pos) ** 2).sum(-1)).mean()

        return loss, err
|
|