import numpy as np
import scipy.linalg
|
|
|
|
def generate_run_ID(options):
    '''
    Build a unique run identifier from the most relevant
    hyperparameters. The remaining parameters can be found
    in the params.npy file.
    '''
    fields = (
        'steps', str(options.sequence_length),
        'batch', str(options.batch_size),
        options.RNN_type,
        str(options.Ng),
        options.activation,
        'rf', str(options.place_cell_rf),
        'DoG', str(options.DoG),
        'periodic', str(options.periodic),
        'lr', str(options.learning_rate),
        'weight_decay', str(options.weight_decay),
    )
    # Join with underscores and strip dots so the ID is filesystem-safe.
    return '_'.join(fields).replace('.', '')
|
|
|
|
def get_2d_sort(x1, x2):
    """
    Reshapes x1 and x2 into square arrays, and then sorts
    them such that x1 increases downward and x2 increases
    rightward. Returns the order.
    """
    side = int(np.round(np.sqrt(len(x1))))
    # The global sort by x1 decides which indices land in each row.
    order = np.argsort(x1).reshape(side, side)
    # Each row is then independently re-sorted by x2. Permuting row k does
    # not change which indices later rows contain, so recomputing the x2
    # gather inside the loop still yields the same per-row ordering.
    for row in range(side):
        by_x2 = np.argsort(x2[order.ravel()].reshape(side, side)[row])
        order[row] = order[row][by_x2]
    return order.ravel()
|
|
|
|
def dft(N, real=False, scale='sqrtn'):
    """
    Return an N x N discrete Fourier transform matrix.

    With real=False this defers to scipy.linalg.dft, forwarding `scale`.
    With real=True an orthonormal real DFT basis is built instead —
    cosine columns (frequencies 0..N//2) followed by sine columns in
    reversed frequency order — and `scale` is ignored.
    """
    if real:
        samples = np.arange(N)[:, None]
        cos_freqs = np.arange(N // 2 + 1)[None, :]
        sin_freqs = np.arange(1, (N - 1) // 2 + 1)[None, :]
        cosines = np.cos(2 * np.pi * cos_freqs / N * samples)
        sines = np.sin(2 * np.pi * sin_freqs / N * samples)
        if N % 2 == 0:
            # The Nyquist column needs half weight to keep the basis orthonormal.
            cosines[:, -1] /= np.sqrt(2)
        F = np.concatenate((cosines, sines[:, ::-1]), 1)
        F[:, 0] /= np.sqrt(N)       # DC column
        F[:, 1:] /= np.sqrt(N / 2)  # oscillating columns
        return F
    return scipy.linalg.dft(N, scale)
|
|
|
|
def skaggs_power(Jsort):
    '''
    Fraction of a sorted connectivity matrix's spectral power carried by
    the (1, 1) and (-1, -1) frequency components of the 2D real DFT.

    Args:
        Jsort: square (N, N) array, sorted so that its rows/columns form
            an n x n spatial grid with n = sqrt(N). N is assumed to be a
            perfect square (TODO confirm with callers).

    Returns:
        Scalar ratio: power at the two lowest diagonal frequencies
        divided by total squared power of the transformed matrix.
    '''
    # Fix: N was previously read from an undefined global (NameError at
    # runtime); it is fully determined by the input's shape.
    N = Jsort.shape[0]
    F = dft(int(np.sqrt(N)), real=True)
    # Separable 2D transform: outer product of the 1D real-DFT basis
    # with itself, unrolled to an (N, N) change-of-basis matrix.
    F2d = F[:, None, :, None] * F[None, :, None, :]
    F2d_unroll = np.reshape(F2d, (N, N))

    # F is real and orthonormal, so the conjugate transpose is the inverse.
    F2d_inv = F2d_unroll.conj().T
    Jtilde = F2d_inv.dot(Jsort).dot(F2d_unroll)

    return (Jtilde[1, 1]**2 + Jtilde[-1, -1]**2) / (Jtilde**2).sum()
|
|
|
|
def skaggs_power_2(Jsort):
    '''
    Fraction of a sorted connectivity matrix's power carried by the six
    lowest non-DC spatial frequencies of its translation-averaged kernel.

    Args:
        Jsort: square (n^2, n^2) array, sorted so each row/column indexes
            an n x n spatial grid.

    Returns:
        Scalar ratio: squared power at frequencies (0,1), (1,0), (1,1)
        and their negatives, divided by total squared power (DC removed).
    '''
    # Fix: n was previously read from an undefined global (NameError at
    # runtime); it is fully determined by the input's shape.
    n = int(np.round(np.sqrt(Jsort.shape[0])))
    J_square = np.reshape(Jsort, (n, n, n, n))

    # Average the outgoing kernel over all source positions by rolling
    # each kernel so its source sits at the origin. (Sum, not mean — the
    # missing 1/n^2 factor cancels in the final power ratio.)
    Jmean = np.zeros([n, n])
    for i in range(n):
        for j in range(n):
            Jmean += np.roll(np.roll(J_square[i, j], -i, axis=0), -j, axis=1)

    # Center the kernel before transforming.
    Jmean = np.roll(np.roll(Jmean, n // 2, axis=0), n // 2, axis=1)
    Jtilde = np.real(np.fft.fft2(Jmean))

    # Discard the DC component so it cannot dominate the denominator.
    Jtilde[0, 0] = 0
    sk_power = Jtilde[1, 1]**2 + Jtilde[0, 1]**2 + Jtilde[1, 0]**2
    sk_power += Jtilde[-1, -1]**2 + Jtilde[0, -1]**2 + Jtilde[-1, 0]**2
    sk_power /= (Jtilde**2).sum()

    return sk_power
|
|
|
|
def calc_err():
    '''
    Estimate the mean position-decoding error for one batch.

    NOTE(review): notebook-style helper — it relies on module-level
    globals that are NOT defined in this file: `gen` (batch generator),
    `model` (callable network), `place_cells` (exposes
    get_nearest_cell_pos), and `tf` (TensorFlow is never imported here).
    Confirm these names exist at module scope before calling.
    '''
    inputs, _, pos = next(gen)
    pred = model(inputs)
    # Decode predicted place-cell activity back to a 2D position.
    pred_pos = place_cells.get_nearest_cell_pos(pred)
    # Mean Euclidean distance between true and decoded positions.
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum((pos - pred_pos)**2, axis=-1)))
|
|
| from visualize import compute_ratemaps, plot_ratemaps |
|
|
|
|
def compute_variance(res, n_avg):
    '''
    Bin trajectory positions onto a res x res grid and, for each bin,
    average a normalized mismatch between per-visit activity g[i] and
    the mean ratemap activation at that bin.

    NOTE(review): relies on module-level globals not defined in this
    file: `model`, `data_manager`, `options`, and `tqdm` (never imported
    here). `options` is dict-indexed here but accessed via attributes in
    generate_run_ID — confirm which interface is intended.

    Args:
        res: side length of the square spatial grid.
        n_avg: number of batches to average (forwarded to compute_ratemaps).

    Returns:
        (res, res) array of per-bin averaged mismatch values; bins that
        were never visited remain 0.
    '''
    activations, rate_map, g, pos = compute_ratemaps(model, data_manager, options, res=res, n_avg=n_avg)

    # Per-bin visit counts and accumulated mismatch.
    counts = np.zeros([res,res])
    variance = np.zeros([res,res])

    # Map continuous positions in [-box/2, box/2] onto grid indices [0, res).
    x_all = (pos[:,0] + options['box_width']/2) / options['box_width'] * res
    y_all = (pos[:,1] + options['box_height']/2) / options['box_height'] * res
    for i in tqdm(range(len(g))):
        x = int(x_all[i])
        y = int(y_all[i])
        if x >=0 and x < res and y >=0 and y < res:
            counts[x, y] += 1
            # NOTE(review): despite the name, this accumulates
            # ||g - mean|| / (||g|| * ||mean||), not a statistical variance.
            variance[x, y] += np.linalg.norm(g[i] - activations[:, x, y]) / np.linalg.norm(g[i]) / np.linalg.norm(activations[:,x,y])

    # Convert accumulated sums into per-bin means for visited bins.
    for x in range(res):
        for y in range(res):
            if counts[x, y] > 0:
                variance[x, y] /= counts[x, y]

    return variance
|
|
|
|
def load_trained_weights(model, trainer, weight_dir):
    '''
    Restore model weights stored as a .npy file (for github).

    Args:
        model: object exposing set_weights(); receives the loaded weights.
        trainer: unused here; kept for call-site compatibility.
        weight_dir: path to the .npy weight file.

    NOTE(security): allow_pickle=True deserializes arbitrary pickled
    objects — only load weight files from trusted sources.
    '''
    stored_weights = np.load(weight_dir, allow_pickle=True)
    model.set_weights(stored_weights)
    print('Loaded trained weights.')
|
|
|
|