| |
| import numpy as np |
| from matplotlib import pyplot as plt |
|
|
| import scipy |
| import scipy.stats |
| from imageio import imsave |
| import cv2 |
|
|
|
|
def concat_images(images, image_width, spacer_size):
    """Place images side by side with a white vertical gap between neighbours.

    Args:
        images: Sequence of image arrays. Because they are joined with
            ``np.hstack`` next to a ``(image_width, spacer_size, 4)`` gap, each
            must have ``image_width`` rows and 4 channels.
        image_width: Height (in pixels) of the gap column; equal to the image
            size for the square maps used in this module.
        spacer_size: Width (in pixels) of each gap.

    Returns:
        The horizontally stacked array. No gap is added after the last image.
    """
    gap = np.full((image_width, spacer_size, 4), 255, dtype=np.uint8)

    pieces = []
    for position, picture in enumerate(images):
        # Put a gap in front of every image except the first one.
        if position:
            pieces.append(gap)
        pieces.append(picture)
    return np.hstack(pieces)
|
|
|
|
def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
    """Tile images into a grid of ``row_size`` rows separated by white gaps.

    Each row holds ``len(images) // row_size`` consecutive images, joined by
    ``concat_images``. Any images beyond ``row_size * (len(images) // row_size)``
    are not drawn.

    Args:
        images: Sequence of RGBA image arrays, each ``image_width`` wide.
        row_size: Number of rows in the grid.
        image_width: Width (and gap height) of a single image, in pixels.
        spacer_size: Thickness in pixels of the gaps between images and rows.

    Returns:
        The vertically stacked grid array.
    """
    per_row = len(images) // row_size
    # A horizontal gap must span one full row: per_row images plus the gaps between them.
    grid_width = per_row * image_width + (per_row - 1) * spacer_size
    row_gap = np.full((spacer_size, grid_width, 4), 255, dtype=np.uint8)

    strips = []
    for r in range(row_size):
        if r:
            strips.append(row_gap)
        first = r * per_row
        strips.append(concat_images(images[first:first + per_row], image_width, spacer_size))
    return np.vstack(strips)
|
|
|
|
def convert_to_colormap(im, cmap):
    """Apply the colormap callable ``cmap`` to ``im`` and scale the result to uint8.

    ``cmap(im)`` is expected to return values in [0, 1]; they are multiplied by
    255 and truncated to unsigned 8-bit integers.
    """
    colored = cmap(im)
    scaled = colored * 255
    return np.uint8(scaled)
|
|
|
|
def rgb(im, cmap='jet', smooth=True):
    """Min-max normalise a 2-D map, optionally blur it, and colour it.

    Args:
        im: 2-D numeric array (e.g. one rate map).
        cmap: Colormap name or object accepted by ``plt.get_cmap``.
        smooth: If True, apply a 3x3 Gaussian blur (sigma 1) after normalising.

    Returns:
        ``uint8`` array of the colormap output scaled to 0..255.

    Note:
        A constant map normalises to 0/0 = NaN. The colormap renders NaN with its
        "bad" colour, and the resulting invalid-value warning is suppressed. The
        suppression is now local (``np.errstate``) instead of changing numpy's
        global error state through ``np.seterr``.
    """
    # plt.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap is the supported equivalent and accepts names or Colormaps.
    colormap = plt.get_cmap(cmap)
    with np.errstate(invalid='ignore'):
        im = (im - np.min(im)) / (np.max(im) - np.min(im))
        if smooth:
            im = cv2.GaussianBlur(im, (3, 3), sigmaX=1, sigmaY=0)
        im = colormap(im)
    return np.uint8(im * 255)
|
|
|
|
def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
    """Colour the first ``n_plots`` rate maps and tile them into one image grid.

    Args:
        activations: Array of 2-D maps indexed along the first axis (e.g. the
            ``[Ng, res, res]`` output of ``compute_ratemaps``).
        n_plots: Number of maps to draw; clamped to ``len(activations)``.
        cmap: Colormap passed to ``rgb``.
        smooth: Whether ``rgb`` blurs each map before colouring.
        width: Maximum number of maps per row.

    Returns:
        The uint8 grid produced by ``concat_images_in_rows``. If the maps do not
        fill the last row, it is padded with white tiles so every requested map
        is shown.

    Raises:
        ValueError: If ``width < 1`` or there is nothing to plot.
    """
    if width < 1:
        raise ValueError("width must be at least 1")
    n_plots = min(n_plots, len(activations))
    if n_plots <= 0:
        raise ValueError("nothing to plot: n_plots and activations must be non-empty")

    images = [rgb(im, cmap, smooth) for im in activations[:n_plots]]

    # Fewer maps than `width` -> a single row. Otherwise use ceil(n_plots / width)
    # rows. (The old `n_plots // width` was 0 for small n_plots, causing
    # ZeroDivisionError, and it silently dropped the remainder for non-multiples.)
    n_cols = min(width, n_plots)
    n_rows = -(-n_plots // n_cols)  # ceiling division

    # concat_images_in_rows puts len(images) // n_rows images in each row, so pad
    # with blank white tiles (same shape/dtype as a real tile) to an exact multiple.
    n_missing = n_rows * n_cols - n_plots
    images.extend([np.full_like(images[0], 255)] * n_missing)

    return concat_images_in_rows(images, n_rows, activations.shape[-1])
|
|
|
|
def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None,
                     Ng=512, idxs=None, return_raw=False):
    '''Compute spatial firing fields by binning unit activity over positions.

    Args:
        model: The RNN model. ``model.g(inputs)`` must return a tensor-like
            object with ``.detach().cpu().numpy()``, indexed ``[:, :, unit]``.
        trajectory_generator: Generator for test trajectories. Its
            ``get_test_batch()`` returns ``(inputs, pos, _)``; ``pos`` must
            support ``.detach().cpu().numpy()`` and reshape to ``[-1, 2]``.
        options: Training options. Reads ``sequence_length``, ``batch_size``,
            ``box_width`` and ``box_height``. Positions are assumed to lie in a
            box centred on 0.
        res: Resolution of the rate map grid (``res x res`` bins).
        n_avg: Number of batches to average over. If falsy, defaults to
            ``max(1, 1000 // options.sequence_length)``.
        Ng: Maximum number of units to analyze.
        idxs: Indices of specific units to analyze. ``None`` or empty means
            ``range(Ng)``; otherwise only the first ``Ng`` entries are used.
        return_raw: If True, also return raw activations (g) and positions (pos).
            Warning: This uses significant memory for large batch_size/n_avg.
            If False, returns None for g and pos to save memory.

    Returns:
        activations: Mean activity per spatial bin ``[n_cells, res, res]``, where
            ``n_cells = len(idxs)`` after truncation (``Ng`` when idxs is None).
            Bins that were never visited are 0.
        rate_map: Flattened rate maps ``[n_cells, res*res]``.
        g: Raw activations ``[n_avg*batch_size*sequence_length, n_cells]``
            (None if return_raw=False).
        pos: Raw positions ``[n_avg*batch_size*sequence_length, 2]``
            (None if return_raw=False).
    '''
    if not n_avg:
        # Guard against 0 batches when sequence_length > 1000.
        n_avg = max(1, 1000 // options.sequence_length)

    # Test for None explicitly: the previous `np.any(idxs)` treated idxs=[0]
    # (a valid selection) as "no selection" and analysed range(Ng) instead.
    if idxs is None or len(idxs) == 0:
        idxs = np.arange(Ng)
    idxs = np.asarray(idxs)[:Ng]
    # The real number of selected units; may be < Ng when fewer idxs are given.
    n_cells = len(idxs)

    samples_per_batch = options.batch_size * options.sequence_length
    if return_raw:
        g = np.zeros([n_avg, samples_per_batch, n_cells])
        pos = np.zeros([n_avg, samples_per_batch, 2])
    else:
        g = None
        pos = None

    activations = np.zeros([n_cells, res, res])
    counts = np.zeros([res, res])

    for index in range(n_avg):
        inputs, pos_batch, _ = trajectory_generator.get_test_batch()
        g_batch = model.g(inputs).detach().cpu().numpy()

        pos_batch = np.reshape(pos_batch.detach().cpu().numpy(), [-1, 2])
        g_batch = g_batch[:, :, idxs].reshape(-1, n_cells)

        if return_raw:
            g[index] = g_batch
            pos[index] = pos_batch

        # Shift from box-centred coordinates to [0, res) bin units.
        x_batch = (pos_batch[:, 0] + options.box_width / 2) / options.box_width * res
        y_batch = (pos_batch[:, 1] + options.box_height / 2) / options.box_height * res

        # Keep only samples inside the grid (NaN positions also compare False).
        in_box = (x_batch >= 0) & (x_batch < res) & (y_batch >= 0) & (y_batch < res)
        x_idx = x_batch[in_box].astype(int)  # values are >= 0, so this is floor
        y_idx = y_batch[in_box].astype(int)

        # np.add.at is unbuffered: repeated bins accumulate every sample.
        np.add.at(counts, (x_idx, y_idx), 1)
        np.add.at(activations, (slice(None), x_idx, y_idx), g_batch[in_box].T)

    # Mean per visited bin; unvisited bins (count 0) are left at 0.
    np.divide(activations, counts, out=activations, where=counts > 0)

    if return_raw:
        g = g.reshape([-1, n_cells])
        pos = pos.reshape([-1, 2])

    rate_map = activations.reshape(n_cells, -1)

    return activations, rate_map, g, pos
|
|
|
|
def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
    """Compute rate maps for all analysed units and save them as one PNG grid.

    The image is written to ``<options.save_dir>/<options.run_ID>/<step>.png``.
    The run directory is created first if it does not exist.

    Args:
        model, trajectory_generator, options: Forwarded to ``compute_ratemaps``.
        step: Used as the file name stem.
        res: Rate-map grid resolution.
        n_avg: Number of batches to average over. If falsy, ``compute_ratemaps``
            applies its own default.
    """
    import os  # local import: the module header does not import os

    # compute_ratemaps handles a falsy n_avg itself, so forward it unchanged
    # rather than duplicating the default computation here.
    activations, _, _, _ = compute_ratemaps(model, trajectory_generator,
                                            options, res=res, n_avg=n_avg)
    rm_fig = plot_ratemaps(activations, n_plots=len(activations))
    imdir = os.path.join(options.save_dir, options.run_ID)
    # Previously the save failed on a fresh run where this directory was missing.
    os.makedirs(imdir, exist_ok=True)
    imsave(os.path.join(imdir, str(step) + ".png"), rm_fig)
|
|
|
|
def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
    """Collect 100 minibatches through a session and hand them to a grid scorer.

    Each minibatch fetches ``model.target_pos`` (as ``'pos_xy'``) and ``model.g``
    (as ``'bottleneck'``) via ``sess.run``, using the feed dict built by
    ``trajectory_generator.feed_dict(flags.box_width, flags.box_height)``. The
    results are merged with ``utils.concat_dict`` and passed to
    ``utils.get_scores_and_plot`` together with the output name
    ``<save_name>/autocorrs_<step>.pdf`` and the directory ``<flags.save_dir>/``.

    NOTE(review): this uses a TF1-style ``sess.run(..., feed_dict=...)`` API and
    references ``scores`` and ``utils``, which are not imported in this module's
    visible header. As written it likely raises NameError unless those names are
    supplied elsewhere. TODO confirm whether this function is still used.
    """
    # Ten mask parameter pairs: start fixed at 0.2, end spaced over [0.4, 1.0].
    mask_starts = [0.2] * 10
    mask_ends = np.linspace(0.4, 1.0, num=10)
    box_range = ((-1.1, 1.1), (-1.1, 1.1))
    # zip is kept lazy, exactly as the scorer received it before.
    mask_pairs = zip(mask_starts, mask_ends.tolist())
    scorer = scores.GridScorer(20, box_range, mask_pairs)

    merged = {}
    n_batches = 100
    for _ in range(n_batches):
        feed = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
        batch_out = sess.run({
            'pos_xy': model.target_pos,
            'bottleneck': model.g,
        }, feed_dict=feed)
        merged = utils.concat_dict(merged, batch_out)

    pdf_name = save_name + '/autocorrs_' + str(step) + '.pdf'
    out_dir = flags.save_dir + '/'
    utils.get_scores_and_plot(
        scorer, merged['pos_xy'], merged['bottleneck'], out_dir, pdf_name)
|
|