# -*- coding: utf-8 -*-
# hw/place_cells.py
# Uploaded by violet1723 using huggingface_hub (commit 00c2650).
import numpy as np
import torch
import scipy
class PlaceCells(object):
    """Population of Gaussian (or difference-of-Gaussians) place cells.

    Centers live in ``self.us``: a float64 tensor of shape [Np, 2] on
    ``options.device``.  Activations are normalized with a softmax over the
    population.
    """

    def __init__(self, options, us=None):
        """Build the place cell population.

        Args:
            options: Config object read for the attributes ``Np``,
                ``place_cell_rf``, ``surround_scale``, ``box_width``,
                ``box_height``, ``periodic``, ``DoG`` and ``device``.
            us: Optional precomputed centers, array-like or tensor of shape
                [Np, 2], e.g. ``np.load('models/example_pc_centers.npy')``.
                When None (default) centers are drawn uniformly at random
                inside the box.

        Raises:
            ValueError: if ``us`` does not have shape [Np, 2].
        """
        self.Np = options.Np
        self.sigma = options.place_cell_rf
        self.surround_scale = options.surround_scale
        self.box_width = options.box_width
        self.box_height = options.box_height
        self.is_periodic = options.periodic
        self.DoG = options.DoG
        self.device = options.device
        self.softmax = torch.nn.Softmax(dim=-1)

        if us is None:
            # Randomly tile place cell centers across the environment; each
            # axis is sampled over its own extent so non-square boxes work.
            usx = np.random.uniform(-self.box_width / 2, self.box_width / 2, (self.Np,))
            usy = np.random.uniform(-self.box_height / 2, self.box_height / 2, (self.Np,))
            us = np.stack([usx, usy], axis=-1)

        # Copy so later mutation of a caller-supplied array cannot move centers.
        us = torch.as_tensor(us, dtype=torch.float64).detach().clone()
        if tuple(us.shape) != (self.Np, 2):
            raise ValueError(
                f"place cell centers must have shape ({self.Np}, 2), "
                f"got {tuple(us.shape)}"
            )
        # If using a GPU, put on GPU.
        self.us = us.to(self.device)

    def get_activation(self, pos):
        """
        Get place cell activations for given positions.

        Args:
            pos: Tensor of 2d positions with shape [..., 2], typically
                [batch_size, sequence_length, 2]; must be on the same device
                as ``self.us``.

        Returns:
            outputs: float32 activations with shape [..., Np]; each slice
                along the last axis is non-negative and sums to 1.
        """
        # Per-axis distance from every position to every center: [..., Np, 2].
        d = torch.abs(pos[..., None, :] - self.us).float()

        if self.is_periodic:
            # On a torus the distance along an axis is the shorter of the
            # direct and the wrap-around distance.
            dx = torch.minimum(d[..., 0], self.box_width - d[..., 0])
            dy = torch.minimum(d[..., 1], self.box_height - d[..., 1])
            d = torch.stack([dx, dy], dim=-1)
        norm2 = (d ** 2).sum(-1)

        # Normalize place cell outputs with prefactor alpha=1/2/np.pi/self.sigma**2,
        # or, simply normalize with softmax, which yields same normalization on
        # average and seems to speed up training.
        outputs = self.softmax(-norm2 / (2 * self.sigma ** 2))

        if self.DoG:
            # Again, normalize with prefactor
            # beta=1/2/np.pi/self.sigma**2/self.surround_scale, or use softmax.
            surround = self.softmax(-norm2 / (2 * self.surround_scale * self.sigma ** 2))
            # Out-of-place ops: the softmax result is saved for backward, so
            # mutating it in place would break autograd.
            outputs = outputs - surround
            # Shift and scale outputs so that they lie in [0, 1].
            min_output, _ = outputs.min(-1, keepdim=True)
            outputs = outputs + torch.abs(min_output)
            outputs = outputs / outputs.sum(-1, keepdim=True)
        return outputs

    def get_nearest_cell_pos(self, activation, k=3):
        """
        Decode position using centers of k maximally active place cells.

        Args:
            activation: Place cell activations of shape [..., Np], typically
                [batch_size, sequence_length, Np].
            k: Number of maximally active place cells with which to decode position.

        Returns:
            pred_pos: Predicted 2d position with shape [..., 2] (mean of the
                selected centers).
        """
        _, idxs = torch.topk(activation, k=k)
        return self.us[idxs].mean(-2)

    def grid_pc(self, pc_outputs, res=32):
        """Linearly interpolate place cell outputs onto a regular grid.

        Args:
            pc_outputs: Activations whose trailing axis has size Np (numpy
                array or torch tensor on any device); all leading axes are
                flattened into T frames.
            res: Number of grid points along each axis.

        Returns:
            pc: float64 numpy array of shape [T, res, res] (rows index y,
                columns index x).  Grid points outside the convex hull of the
                centers are NaN.
        """
        from scipy.interpolate import LinearNDInterpolator

        coordsx = np.linspace(-self.box_width / 2, self.box_width / 2, res)
        coordsy = np.linspace(-self.box_height / 2, self.box_height / 2, res)
        grid_x, grid_y = np.meshgrid(coordsx, coordsy)
        grid = np.stack([grid_x.ravel(), grid_y.ravel()]).T

        # Convert to numpy (handles GPU and grad-tracking tensors).
        if isinstance(pc_outputs, torch.Tensor):
            pc_outputs = pc_outputs.detach().cpu().numpy()
        pc_outputs = np.asarray(pc_outputs).reshape(-1, self.Np)
        n_frames = pc_outputs.shape[0]  # number of flattened frames
        pc = np.zeros([n_frames, res, res])
        if n_frames == 0:
            return pc

        # One triangulation of the centers shared by every frame; this is the
        # same interpolator griddata(method='linear') builds per call.
        centers = self.us.detach().cpu().numpy()
        interp = LinearNDInterpolator(centers, pc_outputs.T)
        gridvals = interp(grid)  # [res * res, T]
        pc[:] = gridvals.T.reshape(n_frames, res, res)
        return pc

    def compute_covariance(self, res=30):
        """Compute the translation-averaged spatial correlation of place cell outputs.

        Activations are evaluated on a res x res grid spanning the box, the
        (uncentred) Gram matrix between grid points is formed, and each grid
        point's row is rolled so that point sits at the origin before
        summing.  The result is rolled so zero offset is at [res//2, res//2].

        Args:
            res: Number of grid points along each axis.

        Returns:
            Cmean: float64 numpy array of shape [res, res].
        """
        xs = np.linspace(-self.box_width / 2, self.box_width / 2, res)
        ys = np.linspace(-self.box_height / 2, self.box_height / 2, res)
        # meshgrid -> [2, res, res]; transposing gives pos[a, b] = (xs[a], ys[b])
        # with shape [res, res, 2], consumed as a [batch, sequence, 2] input.
        pos = torch.tensor(np.array(np.meshgrid(xs, ys)).T)
        # Put on GPU if available.
        pos = pos.to(self.device)

        pc_outputs = self.get_activation(pos).reshape(-1, self.Np).detach().cpu()
        C = (pc_outputs @ pc_outputs.T).numpy()
        Csquare = C.reshape(res, res, res, res)

        Cmean = np.zeros([res, res])
        for i in range(res):
            for j in range(res):
                Cmean += np.roll(Csquare[i, j], (-i, -j), axis=(0, 1))
        return np.roll(Cmean, (res // 2, res // 2), axis=(0, 1))