# hw/utils.py
# Uploaded by violet1723 using huggingface_hub (commit 00c2650, verified).
import numpy as np
def generate_run_ID(options):
    """Build a run identifier from the most relevant hyperparameters.

    The ID is the underscore-joined sequence of labelled values (sequence
    length, batch size, RNN type, Ng, activation, place-cell rf, DoG,
    periodic, learning rate, weight decay) with every '.' removed, so e.g. a
    learning rate of 0.0001 appears as ``lr_00001``. Remaining parameters can
    be found in the params.npy file.

    Args:
        options: object exposing the hyperparameters as attributes.
            ``RNN_type`` and ``activation`` are inserted as-is and must already
            be strings; every other value is converted with ``str()``.

    Returns:
        str: the run ID.
    """
    pieces = [
        'steps', str(options.sequence_length),
        'batch', str(options.batch_size),
        options.RNN_type,
        str(options.Ng),
        options.activation,
    ]
    labelled = (
        ('rf', options.place_cell_rf),
        ('DoG', options.DoG),
        ('periodic', options.periodic),
        ('lr', options.learning_rate),
        ('weight_decay', options.weight_decay),
    )
    for label, value in labelled:
        pieces.extend((label, str(value)))
    # Dots are stripped so the ID is safe to use as a file/directory name.
    return '_'.join(pieces).replace('.', '')
def get_2d_sort(x1, x2):
    """Order n*n points on a square grid by two coordinates.

    The points are ranked by ``x1`` and split into n consecutive rows (row 0
    holds the n smallest ``x1`` values). Each row is then re-ordered by
    ``x2``. In the resulting n x n layout ``x1`` increases downward (row to
    row) and ``x2`` increases rightward within a row.

    Args:
        x1: 1-D array-like of length n**2 (primary, row-wise key).
        x2: 1-D array-like of the same length (secondary, within-row key).

    Returns:
        1-D integer array of length n**2: indices into ``x1``/``x2`` in
        row-major grid order.

    Raises:
        ValueError: if ``len(x1)`` is not a perfect square.
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    n = int(np.round(np.sqrt(len(x1))))
    if n * n != len(x1):
        raise ValueError(
            f"get_2d_sort expects a perfect-square number of points, got {len(x1)}"
        )
    rows = x1.argsort().reshape(n, n)
    # Sort every row by x2 in one call. This replaces re-gathering the whole
    # x2 array on each of n iterations; the per-row argsort is the same one.
    within_row = np.argsort(x2[rows], axis=1)
    return np.take_along_axis(rows, within_row, axis=1).ravel()
def dft(N, real=False, scale='sqrtn'):
    """Return an N x N discrete Fourier basis matrix.

    Args:
        N: transform length.
        real: if False, return the complex DFT matrix from
            ``scipy.linalg.dft(N, scale)``. If True, return a real-valued
            orthonormal basis whose columns are the cosines for frequencies
            k = 0 .. N//2, followed by the sines for k = (N-1)//2 down to 1.
        scale: normalization forwarded to ``scipy.linalg.dft``. Ignored when
            ``real=True``.

    Returns:
        ndarray of shape (N, N): complex when ``real`` is False, float otherwise.
    """
    if not real:
        # Local import: the module header only imports numpy, and the bare
        # `scipy` name used here previously raised NameError.
        import scipy.linalg
        return scipy.linalg.dft(N, scale)
    samples = np.arange(N)[:, None]
    cosines = np.cos(2*np.pi*np.arange(N//2+1)[None, :]/N*samples)
    sines = np.sin(2*np.pi*np.arange(1, (N-1)//2+1)[None, :]/N*samples)
    if N % 2 == 0:
        # The Nyquist cosine has norm sqrt(N), not sqrt(N/2); pre-scale it so
        # the shared sqrt(N/2) division below makes it unit-norm.
        cosines[:, -1] /= np.sqrt(2)
    F = np.concatenate((cosines, sines[:, ::-1]), 1)
    F[:, 0] /= np.sqrt(N)     # constant (k = 0) column
    F[:, 1:] /= np.sqrt(N/2)  # remaining cosine/sine columns
    return F
def skaggs_power(Jsort):
    """Fraction of the squared spectral power of ``Jsort`` in two basis components.

    ``Jsort`` is projected onto the separable 2-D real Fourier basis built
    from ``dft(n, real=True)`` (outer products of the 1-D basis columns,
    flattened so that 2-D index (c, d) becomes c*n + d). The function returns
    ``(Jtilde[1, 1]**2 + Jtilde[-1, -1]**2) / sum(Jtilde**2)``.

    Args:
        Jsort: (N, N) matrix with N = n**2. Presumably rows/columns index
            units laid out on an n x n grid (e.g. ordered by get_2d_sort) —
            TODO confirm against callers.

    Returns:
        Scalar ratio.

    Raises:
        ValueError: if ``Jsort.shape[0]`` is not a perfect square.
    """
    Jsort = np.asarray(Jsort)
    # N was previously read from an undefined module global. The products
    # below require Jsort to be (N, N), so take it from the shape.
    N = Jsort.shape[0]
    n = int(np.round(np.sqrt(N)))
    if n * n != N:
        raise ValueError(f"skaggs_power expects an (n**2, n**2) matrix, got shape {Jsort.shape}")
    F = dft(n, real=True)
    # F2d[a, b, c, d] = F[a, c] * F[b, d]: separable 2-D basis.
    F2d = F[:, None, :, None] * F[None, :, None, :]
    F2d_unroll = np.reshape(F2d, (N, N))
    # F is real, so the conjugate transpose is just the transpose.
    F2d_inv = F2d_unroll.conj().T
    Jtilde = F2d_inv.dot(Jsort).dot(F2d_unroll)
    return (Jtilde[1, 1]**2 + Jtilde[-1, -1]**2) / (Jtilde**2).sum()
def skaggs_power_2(Jsort):
    """Fraction of power at the lowest nonzero 2-D FFT frequencies of the re-centred ``Jsort`` profile.

    ``Jsort`` (n**4 elements) is reshaped to (n, n, n, n). Each slice
    ``J_square[i, j]`` (an n x n map) is rolled so that grid position (i, j)
    moves to the origin, and the slices are summed. The sum is rolled by n//2
    so the origin sits at the grid centre, and its 2-D FFT is taken (real
    part only, DC term zeroed). The function returns the squared power at the
    six entries [1,1], [0,1], [1,0], [-1,-1], [0,-1], [-1,0], divided by the
    total squared power.

    Presumably ``Jsort[a, b]`` is the weight between grid units a and b, so
    the sum is the average connectivity profile as a function of offset —
    TODO confirm.

    Args:
        Jsort: array with n**4 elements, typically an (n**2, n**2) matrix.

    Returns:
        Scalar ratio.

    Raises:
        ValueError: if ``Jsort`` is empty or its size is not a fourth power.
    """
    Jsort = np.asarray(Jsort)
    # n was previously read from an undefined module global. The reshape to
    # (n, n, n, n) determines it from the element count.
    n = int(np.round(Jsort.size ** 0.25))
    if Jsort.size == 0 or n ** 4 != Jsort.size:
        raise ValueError(
            f"skaggs_power_2 expects n**4 elements (e.g. an (n**2, n**2) matrix), got shape {Jsort.shape}"
        )
    J_square = np.reshape(Jsort, (n, n, n, n))
    # A sum rather than a mean; the constant factor cancels in the ratio below.
    Jmean = np.zeros([n, n])
    for i in range(n):
        for j in range(n):
            Jmean += np.roll(np.roll(J_square[i, j], -i, axis=0), -j, axis=1)
    Jmean = np.roll(np.roll(Jmean, n // 2, axis=0), n // 2, axis=1)
    Jtilde = np.real(np.fft.fft2(Jmean))
    Jtilde[0, 0] = 0  # drop the DC component
    sk_power = Jtilde[1, 1]**2 + Jtilde[0, 1]**2 + Jtilde[1, 0]**2
    sk_power += Jtilde[-1, -1]**2 + Jtilde[0, -1]**2 + Jtilde[-1, 0]**2
    sk_power /= (Jtilde**2).sum()
    return sk_power
def calc_err(gen=None, model=None, place_cells=None):
    """Mean Euclidean position-decoding error on one batch.

    Draws one batch ``inputs, _, pos = next(gen)``, runs ``model(inputs)``,
    decodes the output with ``place_cells.get_nearest_cell_pos`` and returns
    the mean, over all samples, of the Euclidean distance (taken over the
    last axis) between ``pos`` and the decoded position.

    Args:
        gen: iterator yielding ``(inputs, _, pos)`` triples.
        model: callable mapping ``inputs`` to predictions.
        place_cells: object providing ``get_nearest_cell_pos(pred)``.
        Each argument, when omitted, falls back to a module-level global of
        the same name (legacy notebook-style usage).

    Returns:
        NumPy scalar mean distance. NumPy is used because this module does
        not import TensorFlow. Tensors are converted with ``np.asarray``,
        which assumes eager tensors — TODO confirm for graph-mode callers.

    Raises:
        NameError: if a dependency is neither passed nor defined at module level.
    """
    def _resolve(name, value):
        # Explicit argument wins; otherwise fall back to a module global.
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(f"calc_err() needs {name!r}: pass it as an argument") from None

    gen = _resolve('gen', gen)
    model = _resolve('model', model)
    place_cells = _resolve('place_cells', place_cells)

    inputs, _, pos = next(gen)
    pred = model(inputs)
    pred_pos = place_cells.get_nearest_cell_pos(pred)
    distances = np.linalg.norm(np.asarray(pos) - np.asarray(pred_pos), axis=-1)
    return np.mean(distances)
from visualize import compute_ratemaps, plot_ratemaps
def compute_variance(res, n_avg, model=None, data_manager=None, options=None):
    """Per-bin mean normalized mismatch between per-sample activity and rate maps.

    Calls ``compute_ratemaps(model, data_manager, options, res=res, n_avg=n_avg)``,
    then bins each sample position ``pos[i]`` into a ``res x res`` grid covering
    the box ``[-box_width/2, box_width/2] x [-box_height/2, box_height/2]``.
    For every sample whose bin lies inside the grid it accumulates

        ||g[i] - a|| / ||g[i]|| / ||a||,   with  a = activations[:, x, y],

    and finally divides each visited bin by its sample count. Unvisited bins
    stay 0. Despite the name, this is a mean normalized distance, not a
    statistical variance.

    Args:
        res: number of bins per spatial axis.
        n_avg: forwarded to ``compute_ratemaps``.
        model, data_manager, options: forwarded to ``compute_ratemaps``.
            When omitted, module-level globals of the same name are used
            (legacy notebook-style usage). ``options`` may be a mapping
            (``options['box_width']``) or an attribute-style object
            (``options.box_width``).

    Returns:
        ``(res, res)`` float array indexed ``[x_bin, y_bin]``.

    Raises:
        NameError: if a dependency is neither passed nor defined at module level.
    """
    def _resolve(name, value):
        # Explicit argument wins; otherwise fall back to a module global.
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(
                f"compute_variance() needs {name!r}: pass it as an argument"
            ) from None

    model = _resolve('model', model)
    data_manager = _resolve('data_manager', data_manager)
    options = _resolve('options', options)

    activations, rate_map, g, pos = compute_ratemaps(
        model, data_manager, options, res=res, n_avg=n_avg)

    try:
        box_width, box_height = options['box_width'], options['box_height']
    except TypeError:
        # Attribute-style options object (the style generate_run_ID uses).
        box_width, box_height = options.box_width, options.box_height

    counts = np.zeros([res, res])
    variance = np.zeros([res, res])
    # Map coordinates in [-w/2, w/2] onto continuous bin coordinates [0, res].
    x_all = (pos[:, 0] + box_width / 2) / box_width * res
    y_all = (pos[:, 1] + box_height / 2) / box_height * res
    for i in range(len(g)):
        # floor, not int(): int() truncates toward zero and would fold
        # positions just outside the box (-1 < coord < 0) into bin 0.
        x = int(np.floor(x_all[i]))
        y = int(np.floor(y_all[i]))
        if 0 <= x < res and 0 <= y < res:
            a = activations[:, x, y]
            counts[x, y] += 1
            variance[x, y] += np.linalg.norm(g[i] - a) / np.linalg.norm(g[i]) / np.linalg.norm(a)
    # Average only bins that received samples; empty bins stay 0.
    np.divide(variance, counts, out=variance, where=counts > 0)
    return variance
def load_trained_weights(model, trainer, weight_dir):
    """Load weights stored as a ``.npy`` file (for github) into ``model``.

    Args:
        model: object with a ``set_weights`` method; it receives the array
            returned by ``np.load``.
        trainer: unused; kept so existing call sites keep working.
        weight_dir: path (or file-like object) passed to ``np.load``.
            Despite the name, this is the weight file itself, not a directory.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
    # objects, so only load weight files from a trusted source.
    loaded = np.load(weight_dir, allow_pickle=True)
    model.set_weights(loaded)
    print('Loaded trained weights.')