File size: 4,072 Bytes
00c2650 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | import numpy as np
def generate_run_ID(options):
    """Build a run identifier string from the most relevant training options.

    Labelled option values are joined with underscores, and every '.' is then
    stripped so that floats such as 0.12 become '012'. Options that are not
    encoded here are documented as being kept in a params.npy file.

    Args:
        options: object exposing ``sequence_length``, ``batch_size``,
            ``RNN_type``, ``Ng``, ``activation``, ``place_cell_rf``, ``DoG``,
            ``periodic``, ``learning_rate`` and ``weight_decay`` attributes.
            ``RNN_type`` and ``activation`` are joined as-is (not passed
            through ``str``), so they must already be strings.

    Returns:
        str: underscore-separated run ID with all dots removed.
    """
    def tagged(tag, value):
        # A label followed by the stringified option value.
        return [tag, str(value)]

    pieces = (
        tagged('steps', options.sequence_length)
        + tagged('batch', options.batch_size)
        + [options.RNN_type, str(options.Ng), options.activation]
        + tagged('rf', options.place_cell_rf)
        + tagged('DoG', options.DoG)
        + tagged('periodic', options.periodic)
        + tagged('lr', options.learning_rate)
        + tagged('weight_decay', options.weight_decay)
    )
    return '_'.join(pieces).replace('.', '')
def get_2d_sort(x1, x2):
    """Return an ordering that lays points out on a square grid.

    The points are first sorted by ``x1`` and split into ``n`` consecutive
    rows of length ``n`` (``n = round(sqrt(len(x1)))``). Each row is then
    sorted by ``x2``. Reading the result row by row therefore gives indices
    whose ``x1`` increases downward (row to row) and whose ``x2`` increases
    rightward within each row.

    Args:
        x1: 1-D array-like of length n**2; sort key across rows.
        x2: 1-D array-like of the same length; sort key within each row.

    Returns:
        np.ndarray: flat integer index array of length n**2.

    Raises:
        ValueError: if ``len(x1)`` is not a perfect square.
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    n = int(np.round(np.sqrt(len(x1))))
    if n * n != len(x1):
        raise ValueError(
            'get_2d_sort needs a perfect-square number of points, got {}'.format(len(x1)))
    # Indices sorted by x1, chopped into n rows of n.
    rows = x1.argsort().reshape(n, n)
    # Sort every row by its own x2 values in one vectorized step instead of
    # regathering the whole array once per row.
    within_row = x2[rows].argsort(axis=1)
    return np.take_along_axis(rows, within_row, axis=1).ravel()
def dft(N, real=False, scale='sqrtn'):
    """Return an N x N discrete Fourier transform matrix.

    Args:
        N: transform size.
        real: if False, return the complex DFT matrix from
            ``scipy.linalg.dft``. If True, return a real N x N matrix whose
            columns are:
            - the constant vector;
            - cosines for frequencies 1..N//2;
            - sines for frequencies (N-1)//2 down to 1.
            Every column is scaled to unit norm, so the matrix is orthogonal.
            For even N, the Nyquist cosine gets an extra 1/sqrt(2) factor for
            this purpose.
        scale: normalization forwarded to ``scipy.linalg.dft``; ignored when
            ``real`` is True.

    Returns:
        np.ndarray of shape (N, N).
    """
    if not real:
        # Import locally so this branch works without a module-level scipy
        # import. ``import scipy`` alone does not guarantee ``scipy.linalg``
        # is loaded.
        import scipy.linalg
        return scipy.linalg.dft(N, scale)
    # Rows index sample position, columns index frequency.
    cosines = np.cos(2*np.pi*np.arange(N//2+1)[None,:]/N*np.arange(N)[:,None])
    sines = np.sin(2*np.pi*np.arange(1,(N-1)//2+1)[None,:]/N*np.arange(N)[:,None])
    if N % 2 == 0:
        # The Nyquist cosine has squared norm N (not N/2), like the DC column.
        cosines[:,-1] /= np.sqrt(2)
    F = np.concatenate((cosines, sines[:,::-1]), 1)
    F[:,0] /= np.sqrt(N)
    F[:,1:] /= np.sqrt(N/2)
    return F
def skaggs_power(Jsort):
    """Fraction of Jsort's squared magnitude in two entries of its 2-D Fourier form.

    ``Jsort`` is an N x N matrix over points on an n x n grid (N = n**2).
    It is expressed in the separable real Fourier basis built from
    ``dft(n, real=True)``: ``Jtilde = B^T Jsort B``, where the columns of B
    are the 2-D basis functions. The function returns
    ``(Jtilde[1, 1]**2 + Jtilde[-1, -1]**2) / sum(Jtilde**2)``.

    Args:
        Jsort: square N x N array-like, N a perfect square.

    Returns:
        Scalar ratio.

    Raises:
        ValueError: if N is not a perfect square.
    """
    Jsort = np.asarray(Jsort)
    # Matrix size, taken from the input rather than a module-level global.
    N = Jsort.shape[0]
    n = int(np.round(np.sqrt(N)))
    if n * n != N:
        raise ValueError(
            'skaggs_power expects an N x N matrix with N a perfect square, got N={}'.format(N))
    F = dft(n, real=True)
    # Separable 2-D basis: F2d[a, b, k, l] = F[a, k] * F[b, l].
    F2d = F[:, None, :, None] * F[None, :, None, :]
    # Rows index grid location (a, b); columns index frequency pair (k, l).
    F2d_unroll = np.reshape(F2d, (N, N))
    F2d_inv = F2d_unroll.conj().T
    Jtilde = F2d_inv.dot(Jsort).dot(F2d_unroll)
    return (Jtilde[1, 1]**2 + Jtilde[-1, -1]**2) / (Jtilde**2).sum()
def skaggs_power_2(Jsort):
    """Fraction of the recentred, summed connectivity's power at the lowest frequencies.

    ``Jsort`` is reshaped to (n, n, n, n), so entry ``[i, j]`` is an n x n map
    (row ``i*n + j`` of an n**2 x n**2 matrix). The function then:
    1. Rolls each map by (-i, -j), moving its own grid location to the origin.
    2. Sums the rolled maps.
    3. Shifts the sum by n//2 along both axes.
    4. Applies fft2 and keeps the real part.
    With the DC entry zeroed, it returns the share of the total squared value
    held by entries (1,1), (0,1), (1,0), (-1,-1), (0,-1) and (-1,0).

    Args:
        Jsort: array-like with n**4 elements (e.g. an n**2 x n**2 matrix).

    Returns:
        Scalar ratio.

    Raises:
        ValueError: if the number of elements in ``Jsort`` is not a fourth power.
    """
    Jsort = np.asarray(Jsort)
    # Grid side length, recovered from the input size (n**4 elements).
    n = int(np.round(np.sqrt(np.sqrt(Jsort.size))))
    if n ** 4 != Jsort.size:
        raise ValueError(
            'skaggs_power_2 expects n**4 elements, got {}'.format(Jsort.size))
    J_square = np.reshape(Jsort, (n, n, n, n))
    Jmean = np.zeros([n, n])
    for i in range(n):
        for j in range(n):
            # Recentre map (i, j) so that location (i, j) lands at (0, 0).
            Jmean += np.roll(np.roll(J_square[i, j], -i, axis=0), -j, axis=1)
    Jmean = np.roll(np.roll(Jmean, n // 2, axis=0), n // 2, axis=1)
    Jtilde = np.real(np.fft.fft2(Jmean))
    Jtilde[0, 0] = 0  # discard the DC component
    sk_power = Jtilde[1, 1]**2 + Jtilde[0, 1]**2 + Jtilde[1, 0]**2
    sk_power += Jtilde[-1, -1]**2 + Jtilde[0, -1]**2 + Jtilde[-1, 0]**2
    sk_power /= (Jtilde**2).sum()
    return sk_power
def calc_err():
    """Mean Euclidean distance between true and model-derived positions for one batch.

    Draws one ``(inputs, _, pos)`` triple from ``gen`` and runs ``model`` on
    the inputs. The output is converted to positions with
    ``place_cells.get_nearest_cell_pos``, and the function returns the mean
    (via ``tf.reduce_mean``) of the per-sample Euclidean distances over the
    last axis.

    NOTE(review): depends on globals ``gen``, ``model``, ``place_cells`` and
    ``tf`` that are not defined in the visible source — confirm they exist in
    the calling namespace.
    """
    batch_inputs, _, true_pos = next(gen)
    decoded_pos = place_cells.get_nearest_cell_pos(model(batch_inputs))
    sq_dist = tf.reduce_sum((true_pos - decoded_pos)**2, axis=-1)
    return tf.reduce_mean(tf.sqrt(sq_dist))
from visualize import compute_ratemaps, plot_ratemaps
def compute_variance(res, n_avg):
    """Per-bin mean normalized mismatch between individual samples and rate maps.

    Calls ``compute_ratemaps`` to obtain three arrays:
    - ``activations``, indexed as ``activations[:, x, y]``;
    - per-sample vectors ``g``;
    - sample positions ``pos``.
    Each position is mapped onto a res x res grid spanning the box given by
    ``options['box_width']`` and ``options['box_height']``.

    For each sample inside the grid, its bin accumulates
    ``||g_i - a|| / ||g_i|| / ||a||``, where ``a`` is that bin's activation
    vector. Each non-empty bin is then divided by its sample count.

    NOTE(review): relies on globals ``model``, ``data_manager``, ``options``
    and ``tqdm`` not defined in the visible source. Despite its name, the
    result is a mean normalized distance, not a statistical variance.

    Args:
        res: grid resolution (bins per side).
        n_avg: forwarded to ``compute_ratemaps``.

    Returns:
        np.ndarray of shape (res, res); bins with no samples stay 0.
    """
    activations, _rate_map, g, pos = compute_ratemaps(model, data_manager, options, res=res, n_avg=n_avg)
    counts = np.zeros([res, res])
    variance = np.zeros([res, res])
    width = options['box_width']
    height = options['box_height']
    # Continuous bin coordinates; int() below truncates them to bin indices.
    x_bins = (pos[:, 0] + width / 2) / width * res
    y_bins = (pos[:, 1] + height / 2) / height * res
    for idx in tqdm(range(len(g))):
        bx = int(x_bins[idx])
        by = int(y_bins[idx])
        if not (0 <= bx < res and 0 <= by < res):
            continue
        counts[bx, by] += 1
        sample = g[idx]
        rate_vec = activations[:, bx, by]
        variance[bx, by] += np.linalg.norm(sample - rate_vec) / np.linalg.norm(sample) / np.linalg.norm(rate_vec)
    occupied = counts > 0
    variance[occupied] /= counts[occupied]
    return variance
def load_trained_weights(model, trainer, weight_dir):
    """Load weights saved as a .npy file (the format used for the GitHub release).

    Args:
        model: object with a ``set_weights`` method; it receives the array
            returned by ``np.load``.
        trainer: unused; kept so existing callers keep working.
        weight_dir: passed straight to ``np.load``; presumably the path of the
            .npy file rather than a directory — confirm with callers.

    NOTE(review): ``allow_pickle=True`` lets ``np.load`` unpickle arbitrary
    objects; only load weight files from trusted sources.
    """
    loaded = np.load(weight_dir, allow_pickle=True)
    model.set_weights(loaded)
    print('Loaded trained weights.')
|