File size: 6,180 Bytes
# -*- coding: utf-8 -*-
import numpy as np
from matplotlib import pyplot as plt
import scipy
import scipy.stats
from imageio import imsave
import cv2
def concat_images(images, image_width, spacer_size):
    """Join images side by side, separated by opaque white spacer strips.

    Args:
        images: Sequence of arrays that `np.hstack` can join with the spacer
            (the spacer has 4 channels and dtype uint8).
        image_width: Used as the spacer's height (its first dimension), so
            the images are expected to have this many rows.
        spacer_size: Width in pixels of each spacer strip.

    Returns:
        The horizontally stacked array; a spacer sits between every pair of
        neighbouring images, never before the first or after the last.
    """
    gap = np.full((image_width, spacer_size, 4), 255, dtype=np.uint8)
    pieces = []
    for tile in images:
        # A spacer goes in front of every tile except the first one.
        if pieces:
            pieces.append(gap)
        pieces.append(tile)
    return np.hstack(pieces)
def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
    """Arrange images in a grid of `row_size` rows separated by white spacers.

    Args:
        images: Sequence of images; each row receives
            `len(images) // row_size` consecutive entries, so any remainder
            at the end of the sequence is not drawn.
        row_size: Number of rows in the grid.
        image_width: Side length handed to `concat_images` and used to size
            the horizontal spacer between rows.
        spacer_size: Thickness in pixels of the spacers (default 4).

    Returns:
        The vertically stacked grid as a single array.
    """
    per_row = len(images) // row_size
    strip_width = image_width * per_row + (per_row - 1) * spacer_size
    divider = np.full((spacer_size, strip_width, 4), 255, dtype=np.uint8)

    stacked = []
    for row_idx in range(row_size):
        # A horizontal divider precedes every row except the first one.
        if row_idx:
            stacked.append(divider)
        begin = per_row * row_idx
        stacked.append(
            concat_images(images[begin:begin + per_row], image_width, spacer_size))
    return np.vstack(stacked)
def convert_to_colormap(im, cmap):
    """Apply the callable `cmap` to `im` and return the result as uint8.

    The output of `cmap(im)` is multiplied by 255 before being passed to
    `np.uint8`.
    """
    colored = cmap(im)
    return np.uint8(colored * 255)
def rgb(im, cmap='jet', smooth=True):
    """Min-max normalise `im`, optionally blur it, and colour-map it to uint8.

    Args:
        im: Numeric array to render (presumably a single 2-D rate map --
            TODO confirm against callers).
        cmap: Colormap name (or object) accepted by `plt.get_cmap`.
        smooth: If True, apply a 3x3 Gaussian blur with sigmaX=1 after
            normalisation.

    Returns:
        The colour-mapped image, scaled by 255 and converted with `np.uint8`.
    """
    # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is the
    # lookup that works on both old and new releases.
    cmap = plt.get_cmap(cmap)
    lowest = np.min(im)
    highest = np.max(im)
    # A constant image gives 0/0 here. Silence that warning for this one
    # expression only, instead of changing NumPy's global error state.
    with np.errstate(invalid='ignore'):
        im = (im - lowest) / (highest - lowest)
    if smooth:
        im = cv2.GaussianBlur(im, (3, 3), sigmaX=1, sigmaY=0)
    return np.uint8(cmap(im) * 255)
def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
    """Render the first `n_plots` rate maps into one tiled image.

    Args:
        activations: Array of rate maps indexed along its first axis; its
            last dimension is used as the tile size.
        n_plots: Number of maps to draw.
        cmap: Colormap forwarded to `rgb`.
        smooth: Smoothing flag forwarded to `rgb`.
        width: The grid gets `n_plots // width` rows (at least one).

    Returns:
        The tiled image produced by `concat_images_in_rows`.
    """
    images = [rgb(im, cmap, smooth) for im in activations[:n_plots]]
    # With fewer than `width` plots the integer division yields 0 rows, which
    # would make `concat_images_in_rows` divide by zero; fall back to one row.
    n_rows = max(1, n_plots // width)
    return concat_images_in_rows(images, n_rows, activations.shape[-1])
def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None, return_raw=False):
    '''Compute spatial firing fields

    Args:
        model: The RNN model; `model.g(inputs)` must return a tensor whose
            last axis indexes units.
        trajectory_generator: Generator for test trajectories; its
            `get_test_batch()` must return `(inputs, positions, _)`.
        options: Training options (`sequence_length`, `batch_size`,
            `box_width` and `box_height` are read).
        res: Resolution of the rate map grid
        n_avg: Number of batches to average over; defaults to
            `1000 // options.sequence_length` when falsy.
        Ng: Maximum number of grid cells to analyze
        idxs: Indices of specific grid cells to analyze. `None` (or an empty
            sequence) selects the first `Ng` units; otherwise at most the
            first `Ng` entries are used.
        return_raw: If True, also return raw activations (g) and positions (pos).
            Warning: This uses significant memory for large batch_size/n_avg.
            If False, returns None for g and pos to save memory.

    Returns:
        activations: Spatial firing fields [n_cells, res, res], where n_cells
            is the number of selected units (Ng when `idxs` is not given).
        rate_map: Flattened rate maps [n_cells, res*res]
        g: Raw activations (None if return_raw=False)
        pos: Raw positions (None if return_raw=False)
    '''
    if not n_avg:
        n_avg = 1000 // options.sequence_length

    # Test for "not provided" explicitly: a truthiness test such as
    # `np.any(idxs)` would also discard a legitimate selection like [0].
    if idxs is None or len(idxs) == 0:
        idxs = np.arange(Ng)
    idxs = np.asarray(idxs)[:Ng]
    # Keep every downstream shape in step with the number of selected units.
    Ng = len(idxs)

    n_samples = options.batch_size * options.sequence_length

    # Only allocate large arrays if return_raw is True
    if return_raw:
        g = np.zeros([n_avg, n_samples, Ng])
        pos = np.zeros([n_avg, n_samples, 2])
    else:
        g = None
        pos = None

    activations = np.zeros([Ng, res, res])
    counts = np.zeros([res, res])

    for index in range(n_avg):
        inputs, pos_batch, _ = trajectory_generator.get_test_batch()
        g_batch = model.g(inputs).detach().cpu().numpy()
        pos_batch = pos_batch.detach().cpu().numpy().reshape(-1, 2)
        g_batch = g_batch[:, :, idxs].reshape(-1, Ng)

        if return_raw:
            g[index] = g_batch
            pos[index] = pos_batch

        # Map box coordinates, centred on the origin, to fractional bin indices.
        x_batch = (pos_batch[:, 0] + options.box_width / 2) / options.box_width * res
        y_batch = (pos_batch[:, 1] + options.box_height / 2) / options.box_height * res

        # Samples that fall outside the grid are dropped.
        in_box = (x_batch >= 0) & (x_batch < res) & (y_batch >= 0) & (y_batch < res)
        x_bin = x_batch[in_box].astype(int)
        y_bin = y_batch[in_box].astype(int)

        # `np.add.at` accumulates correctly when several samples share a bin;
        # plain fancy-index `+=` would only count one of them.
        np.add.at(counts, (x_bin, y_bin), 1)
        np.add.at(activations, (slice(None), x_bin, y_bin),
                  g_batch[in_box].T.astype(activations.dtype))

    # Turn sums into means; bins that were never visited stay at zero.
    visited = counts > 0
    activations[:, visited] /= counts[visited]

    if return_raw:
        g = g.reshape([-1, Ng])
        pos = pos.reshape([-1, 2])

    rate_map = activations.reshape(Ng, -1)

    return activations, rate_map, g, pos
def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
    """Compute rate maps, tile them into one image and write it as a PNG.

    The image is written to `<options.save_dir>/<options.run_ID>/<step>.png`.

    Args:
        model: Model forwarded to `compute_ratemaps`.
        trajectory_generator: Trajectory source forwarded to `compute_ratemaps`.
        options: Options object; `sequence_length`, `save_dir` and `run_ID`
            are read here.
        step: Value used (via `str`) as the output file's base name.
        res: Rate-map resolution forwarded to `compute_ratemaps`.
        n_avg: Number of batches to average over; defaults to
            `1000 // options.sequence_length` when falsy.
    """
    n_avg = n_avg or 1000 // options.sequence_length
    # Only the binned activations are needed for plotting.
    activations, _rate_map, _g, _pos = compute_ratemaps(
        model, trajectory_generator, options, res=res, n_avg=n_avg)
    rm_fig = plot_ratemaps(activations, n_plots=len(activations))

    out_dir = options.save_dir + "/" + options.run_ID
    out_path = out_dir + "/" + str(step) + ".png"
    imsave(out_path, rm_fig)
def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
    """Collect activations over 100 batches and plot grid-score autocorrelograms.

    NOTE(review): `scores` and `utils` are not among the imports visible at
    the top of this file -- confirm they are in scope before calling this.
    NOTE(review): `sess.run(..., feed_dict=...)` looks like a session-style
    graph API rather than the `model.g(inputs)` call used elsewhere in this
    file -- confirm this function is still in use.

    Args:
        sess: Object whose `run(fetches, feed_dict=...)` evaluates the model.
        model: Model exposing `target_pos` and `g`.
        save_name: Prefix for the output file name.
        trajectory_generator: Object whose `feed_dict(box_width, box_height)`
            builds the feed for one batch.
        step: Value used (via `str`) in the output file name.
        flags: Options object; `box_width`, `box_height` and `save_dir` are read.
    """
    n_masks = 10
    mask_starts = [0.2] * n_masks
    mask_ends = np.linspace(0.4, 1.0, num=n_masks).tolist()
    scorer = scores.GridScorer(
        20, ((-1.1, 1.1), (-1.1, 1.1)), zip(mask_starts, mask_ends))

    n_batches = 100
    collected = {}
    for _ in range(n_batches):
        batch_feed = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
        fetches = {
            'pos_xy': model.target_pos,
            'bottleneck': model.g,
        }
        batch_out = sess.run(fetches, feed_dict=batch_feed)
        collected = utils.concat_dict(collected, batch_out)

    pdf_name = save_name + '/autocorrs_' + str(step) + '.pdf'
    out_dir = flags.save_dir + '/'
    utils.get_scores_and_plot(
        scorer, collected['pos_xy'], collected['bottleneck'],
        out_dir, pdf_name)
|