|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from __future__ import print_function
|
|
|
| import h5py
|
| import numpy as np
|
| import os
|
| import tensorflow as tf
|
|
|
| from utils import write_datasets
|
| from synthetic_data_utils import add_alignment_projections, generate_data
|
| from synthetic_data_utils import generate_rnn, get_train_n_valid_inds
|
| from synthetic_data_utils import nparray_and_transpose
|
| from synthetic_data_utils import spikify_data, gaussify_data, split_list_by_inds
|
| import matplotlib
|
| import matplotlib.pyplot as plt
|
| import scipy.signal
|
|
|
# Use nearest-neighbor interpolation so rate/spike rasters are shown
# without smoothing artifacts.
matplotlib.rcParams['image.interpolation'] = 'nearest'
DATA_DIR = "rnn_synth_data_v1.0"

# Command-line flags controlling the synthetic data generation.
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "thits_data",
                    "Name of data file for input case.")
flags.DEFINE_string("noise_type", "poisson", "Noise type for data.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 100, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
flags.DEFINE_integer("S", 50, "Number of sampled units from RNN")
flags.DEFINE_integer("npcs", 10, "Number of PCS for multi-session case.")
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Percentage of train vs validation trials")
flags.DEFINE_integer("nreplications", 40,
                     "Number of noise replications of the same underlying rates.")
# g > 1.0 puts the randomly-connected RNN in the chaotic regime.
flags.DEFINE_float("g", 1.5, "Complexity of dynamics")
# Fixed unclosed parenthesis in the help string below.
flags.DEFINE_float("x0_std", 1.0,
                   "Volume from which to pull initial conditions (affects diversity of dynamics).")
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("input_magnitude", 20.0,
                   "For the input case, what is the value of the input?")
flags.DEFINE_float("max_firing_rate", 30.0, "Map 1.0 of RNN to a spikes per second")
FLAGS = flags.FLAGS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Derive the experiment dimensions from the flags and build the teacher RNN.
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
T = FLAGS.T                      # trial duration in seconds
C = FLAGS.C                      # number of conditions
N = FLAGS.N                      # RNN size
S = FLAGS.S                      # sampled units per dataset
input_magnitude = FLAGS.input_magnitude
nreplications = FLAGS.nreplications
E = nreplications * C            # total number of trials ("examples")

# Floor division: under Python 3, N / S is a float and range(ndatasets)
# below would raise TypeError; // behaves identically under Python 2.
ndatasets = N // S               # number of non-overlapping sampled datasets
train_percentage = FLAGS.train_percentage
ntime_steps = int(T / FLAGS.dt)

rnn = generate_rnn(rng, N, FLAGS.g, FLAGS.tau, FLAGS.dt, FLAGS.max_firing_rate)

# Sanity check: with the default N=50, the seeded weight matrix should be
# reproducible; a mismatch suggests the RNG seeding changed.
if N == 50:
  assert abs(rnn['W'][0, 0] - 0.06239899) < 1e-8, 'Error in random seed?'
# The train/valid split assigns whole replications, so the product must be
# (numerically) an integer.
rem_check = nreplications * train_percentage
assert abs(rem_check - int(rem_check)) < 1e-8, \
    'Train percentage * nreplications should be integral number.'
|
|
|
|
|
|
|
|
|
|
|
# Draw one initial state per condition and replicate it across the noise
# replications; every replicated trial gets its condition's integer label.
x0s = []
condition_labels = []
condition_number = 0
for _ in range(C):
  x0 = FLAGS.x0_std * rng.randn(N, 1)         # one (N, 1) initial condition
  x0s.append(np.tile(x0, nreplications))      # repeat as columns: (N, nreplications)
  condition_labels.extend([condition_number] * nreplications)
  condition_number += 1
# Stack conditions side by side: x0s is (N, C * nreplications).
x0s = np.concatenate(x0s, axis=1)
|
|
|
|
|
# Generate one dataset per non-overlapping group of S sampled units.
datasets = {}
for n in range(ndatasets):
  print(n+1, " of ", ndatasets)

  # Dataset names encode N and S; the per-dataset suffix is only needed
  # when each dataset samples a strict subset of the units.
  dataset_name = 'dataset_N' + str(N) + '_S' + str(S)
  if S < N:
    dataset_name += '_n' + str(n+1)

  # P_sxn is an (S, N) selection matrix picking S units out of N; rolling
  # by S per dataset index selects successive non-overlapping groups.
  P_sxn = np.eye(S, N)
  for m in range(n):
    P_sxn = np.roll(P_sxn, S, axis=1)

  if input_magnitude > 0.0:
    # One random input pulse time per trial, drawn from the middle half of
    # the trial ([ntime_steps/4, 3*ntime_steps/4)).
    input_times = rng.choice(int(ntime_steps/2), size=[E]) + int(ntime_steps/4)
  else:
    # NOTE(review): with no input, input_times is None yet is still passed
    # to split_list_by_inds below — presumably that helper tolerates None;
    # verify before running with input_magnitude <= 0.
    input_times = None

  # NOTE: x0s is reassigned each iteration with the value returned by
  # generate_data, so later datasets reuse whatever it returns here.
  rates, x0s, inputs = \
      generate_data(rnn, T=T, E=E, x0s=x0s, P_sxn=P_sxn,
                    input_magnitude=input_magnitude,
                    input_times=input_times)

  # Corrupt the underlying rates with the requested observation noise.
  if FLAGS.noise_type == "poisson":
    noisy_data = spikify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
  elif FLAGS.noise_type == "gaussian":
    noisy_data = gaussify_data(rates, rng, rnn['dt'], rnn['max_firing_rate'])
  else:
    raise ValueError("Only noise types supported are poisson or gaussian")

  # Split trials into train/validation, keeping replications together
  # (the integrality of nreplications * train_percentage was asserted above).
  train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                  nreplications)

  # Split everything trial-wise into train/valid partitions.
  rates_train, rates_valid = \
      split_list_by_inds(rates, train_inds, valid_inds)
  noisy_data_train, noisy_data_valid = \
      split_list_by_inds(noisy_data, train_inds, valid_inds)
  input_train, inputs_valid = \
      split_list_by_inds(inputs, train_inds, valid_inds)
  condition_labels_train, condition_labels_valid = \
      split_list_by_inds(condition_labels, train_inds, valid_inds)
  input_times_train, input_times_valid = \
      split_list_by_inds(input_times, train_inds, valid_inds)

  # Convert the per-trial lists to arrays (transposed to the layout the
  # downstream consumers expect — see nparray_and_transpose).
  rates_train = nparray_and_transpose(rates_train)
  rates_valid = nparray_and_transpose(rates_valid)
  noisy_data_train = nparray_and_transpose(noisy_data_train)
  noisy_data_valid = nparray_and_transpose(noisy_data_valid)
  input_train = nparray_and_transpose(input_train)
  inputs_valid = nparray_and_transpose(inputs_valid)

  # Package ground truth, noisy observations, inputs, labels, and the
  # generation metadata needed to interpret them.
  data = {'train_truth': rates_train,
          'valid_truth': rates_valid,
          'input_train_truth' : input_train,
          'input_valid_truth' : inputs_valid,
          'train_data' : noisy_data_train,
          'valid_data' : noisy_data_valid,
          'train_percentage' : train_percentage,
          'nreplications' : nreplications,
          'dt' : rnn['dt'],
          'input_magnitude' : input_magnitude,
          'input_times_train' : input_times_train,
          'input_times_valid' : input_times_valid,
          'P_sxn' : P_sxn,
          'condition_labels_train' : condition_labels_train,
          'condition_labels_valid' : condition_labels_valid,
          'conversion_factor': 1.0 / rnn['conversion_factor']}
  datasets[dataset_name] = data
|
|
|
# Multi-session case: each dataset observed only a subset of the units, so
# add cross-dataset alignment projections (npcs principal components).
if N > S:
  datasets = add_alignment_projections(datasets, npcs=FLAGS.npcs)

# Persist all generated datasets to disk.
write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
|
|
|