| |
| import numpy as np |
| import torch |
| import scipy |
|
|
|
|
class PlaceCells(object):
    """Population of 2-D place cells with fixed receptive-field centers.

    Activations are a softmax over cells of the negative squared distance
    between a position and each center, optionally turned into a
    difference of Gaussians (DoG) and optionally computed with periodic
    (wrap-around) boundaries.
    """

    def __init__(self, options, us=None):
        """
        Args:
            options: Config object providing the attributes ``Np``,
                ``place_cell_rf``, ``surround_scale``, ``box_width``,
                ``box_height``, ``periodic``, ``DoG`` and ``device``.
            us: Optional place cell centers (array-like or tensor) of shape
                [Np, 2]. When omitted, centers are drawn uniformly at random
                over the box.

        Raises:
            ValueError: If ``us`` does not have shape [Np, 2].
        """
        self.Np = options.Np
        self.sigma = options.place_cell_rf
        self.surround_scale = options.surround_scale
        self.box_width = options.box_width
        self.box_height = options.box_height
        self.is_periodic = options.periodic
        self.DoG = options.DoG
        self.device = options.device
        self.softmax = torch.nn.Softmax(dim=-1)

        if us is None:
            # Tile centers uniformly over the box: x spans the width, y the height.
            usx = np.random.uniform(-self.box_width / 2, self.box_width / 2, (self.Np,))
            usy = np.random.uniform(-self.box_height / 2, self.box_height / 2, (self.Np,))
            us = np.vstack([usx, usy]).T

        # Own a float64 copy so later mutation of the caller's array cannot leak in.
        centers = torch.as_tensor(us, dtype=torch.float64).detach().clone()
        if tuple(centers.shape) != (self.Np, 2):
            raise ValueError(
                f"us must have shape ({self.Np}, 2), got {tuple(centers.shape)}"
            )
        self.us = centers.to(self.device)

    def get_activation(self, pos):
        """
        Get place cell activations for a given position.

        Args:
            pos: Tensor of 2d positions with shape [..., 2], e.g.
                [batch_size, sequence_length, 2]. Must be on the same device
                as the place cell centers.

        Returns:
            outputs: float32 place cell activations with shape [..., Np],
                e.g. [batch_size, sequence_length, Np]; each vector sums to 1.
        """
        # Per-axis absolute displacement to every center: [..., Np, 2].
        d = torch.abs(pos[..., None, :] - self.us).float()

        if self.is_periodic:
            # With wrap-around boundaries the shorter way may cross the edge.
            dx = torch.minimum(d[..., 0], self.box_width - d[..., 0])
            dy = torch.minimum(d[..., 1], self.box_height - d[..., 1])
            d = torch.stack([dx, dy], dim=-1)

        norm2 = (d**2).sum(-1)

        # Softmax over cells normalizes the Gaussian tuning curves.
        outputs = self.softmax(-norm2 / (2 * self.sigma**2))

        if self.DoG:
            # Subtract a wider surround (variance scaled by surround_scale).
            # Out-of-place ops: softmax's backward needs its output unmodified.
            outputs = outputs - self.softmax(
                -norm2 / (2 * self.surround_scale * self.sigma**2)
            )

            # The difference of two softmaxes sums to zero, so its minimum is
            # <= 0; adding |min| shifts the minimum to zero, then renormalize.
            min_output, _ = outputs.min(-1, keepdim=True)
            outputs = outputs + torch.abs(min_output)
            outputs = outputs / outputs.sum(-1, keepdim=True)
        return outputs

    def get_nearest_cell_pos(self, activation, k=3):
        """
        Decode position using centers of k maximally active place cells.

        Args:
            activation: Place cell activations of shape [..., Np], e.g.
                [batch_size, sequence_length, Np].
            k: Number of maximally active place cells with which to decode position.

        Returns:
            pred_pos: Predicted 2d position with shape [..., 2]: the mean of
                the k selected centers (same dtype as the centers).
        """
        # topk works along the last dim -> idxs: [..., k]; us[idxs]: [..., k, 2].
        _, idxs = torch.topk(activation, k=k)
        return self.us[idxs].mean(-2)

    def grid_pc(self, pc_outputs, res=32):
        """Interpolate place cell outputs onto a regular res x res grid.

        Args:
            pc_outputs: Place cell activations (NumPy array or torch tensor,
                on any device) whose trailing size is Np; leading dims are
                flattened into T activation vectors.
            res: Number of grid points along each axis.

        Returns:
            pc: NumPy array of shape [T, res, res]; the row index runs over y
                and the column index over x. Grid points outside the convex
                hull of the centers are NaN (``griddata`` default fill).
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # expose ``scipy.interpolate`` on older SciPy releases.
        import scipy.interpolate

        coordsx = np.linspace(-self.box_width / 2, self.box_width / 2, res)
        coordsy = np.linspace(-self.box_height / 2, self.box_height / 2, res)
        grid_x, grid_y = np.meshgrid(coordsx, coordsy)
        grid = np.stack([grid_x.ravel(), grid_y.ravel()]).T

        # griddata needs host NumPy arrays; move tensors off the device first.
        if isinstance(pc_outputs, torch.Tensor):
            pc_outputs = pc_outputs.detach().cpu().numpy()
        pc_outputs = np.asarray(pc_outputs).reshape(-1, self.Np)
        centers = self.us.detach().cpu().numpy()

        pc = np.zeros([pc_outputs.shape[0], res, res])
        for i, values in enumerate(pc_outputs):
            pc[i] = scipy.interpolate.griddata(centers, values, grid).reshape(res, res)
        return pc

    def compute_covariance(self, res=30):
        """Compute the translation-averaged spatial covariance of place cell outputs.

        Activations are evaluated on a res x res grid and C[p, q] is the dot
        product of the activation vectors at grid points p and q. Each slice
        C[i, j] is circularly shifted so its reference point sits at lag
        (0, 0), and the slices are summed (not divided by res**2). The sum is
        finally rolled so that zero lag lands at index (res // 2, res // 2).

        Args:
            res: Number of grid points along each axis.

        Returns:
            Cmean: NumPy float64 array of shape [res, res].
        """
        xs = np.linspace(-self.box_width / 2, self.box_width / 2, res)
        ys = np.linspace(-self.box_height / 2, self.box_height / 2, res)
        # After the transpose, pos[i, j] = (xs[i], ys[j]); shape [res, res, 2].
        pos = torch.tensor(np.array(np.meshgrid(xs, ys)).T).to(self.device)

        pc_outputs = (
            self.get_activation(pos).reshape(-1, self.Np).detach().cpu().numpy()
        )

        C = pc_outputs @ pc_outputs.T
        Csquare = C.reshape(res, res, res, res)

        Cmean = np.zeros([res, res])
        for i in range(res):
            for j in range(res):
                Cmean += np.roll(Csquare[i, j], (-i, -j), axis=(0, 1))

        return np.roll(Cmean, (res // 2, res // 2), axis=(0, 1))
|
|