|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from __future__ import print_function
|
|
|
| import h5py
|
| import numpy as np
|
| import os
|
| from six.moves import xrange
|
| import tensorflow as tf
|
|
|
| from utils import write_datasets
|
| from synthetic_data_utils import normalize_rates
|
| from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose
|
| from synthetic_data_utils import spikify_data, split_list_by_inds
|
|
|
# Sub-directory name used in the default save location below.
DATA_DIR = "rnn_synth_data_v1.0"

# Command-line configuration (TF1 `tf.app.flags`); values are read via FLAGS.
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
                    "Name of data file for input case.")
# Seed for the RNGs of this script (the input-stream RNG is seeded with
# synth_data_seed + 1).
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
# Trial duration in seconds; the number of time bins is derived as T / dt.
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
# Number of distinct random input streams ("conditions") to simulate.
flags.DEFINE_integer("C", 800, "Number of conditions")
# NOTE(review): presumably has to match the hidden size the restored
# checkpoint was trained with, since the model variables are N-dimensional.
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
# The default (4/5) suggests this is a fraction in [0, 1], not a percentage.
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Percentage of train vs validation trials")
# Each condition's rates are repeated this many times, giving
# nreplications * C trials in total.
flags.DEFINE_integer("nreplications", 5,
                     "Number of spikifications of the same underlying rates.")
# NOTE(review): `tau` is not read anywhere else in this script as shown.
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
# Width of one time bin, in seconds.
flags.DEFINE_float("dt", 0.010, "Time bin")
# Per its help text: the firing rate (spikes/s) that a normalized rate of 1.0
# is mapped to.
flags.DEFINE_float("max_firing_rate", 30.0,
                   "Map 1.0 of RNN to a spikes per second")
# Standard deviation of the Gaussian input samples.
flags.DEFINE_float("u_std", 0.25,
                   "Std dev of input to integration to bound model")
# NOTE(review): despite "directory" in the help text, this value is handed
# directly to `saver.restore`, which expects a checkpoint path prefix
# (like the default ".../trained_itb/model-65000"), not a directory.
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
                    """Path to directory with checkpoints of model
trained on integration to bound task. Currently this
is a placeholder which tells the code to grab the
checkpoint that is provided with the code
(in /trained_itb/..). If you have your own checkpoint
you would like to restore, you would point it to
that path.""")
FLAGS = flags.FLAGS
|
|
|
|
|
class IntegrationToBoundModel:
    """Vanilla tanh RNN with a scalar input per step and a scalar readout.

    The constructor only creates the parameters; the graph is built one time
    step at a time through `call`. In this script the parameters are
    restored from a checkpoint with `saver.restore`, so their random initial
    values are not what generates the data.

    NOTE(review): the variables are created without explicit names, so
    TensorFlow names them by creation order ("Variable", "Variable_1", ...).
    Checkpoint restoring presumably relies on that order -- do not reorder,
    add or remove variables here without regenerating the checkpoint.
    """

    def __init__(self, N):
        """Create the RNN parameters.

        Args:
          N: number of hidden units.
        """
        # Random weights are drawn with std 0.8 / sqrt(N).
        scale = 0.8 / float(N**0.5)
        self.N = N
        # Recurrent weights, [N, N].
        self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
        # Hidden bias, [1, N].
        self.b_1xn = tf.Variable(tf.zeros([1, N]))
        # Input weights: the scalar input is scaled by this [1, N] row.
        self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
        # Readout weights, [N, 1].
        self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
        # Readout bias, [1].
        self.bro_o = tf.Variable(tf.zeros([1]))

    def call(self, h_tm1_bxn, u_bx1):
        """Advance the RNN by one time step.

        Computes h_t = tanh(h_{t-1} Wh + b + u_t * Bu) and z_t = h_t Wro + bro.

        Args:
          h_tm1_bxn: previous hidden state, [batch, N].
          u_bx1: input for this step, [batch, 1].

        Returns:
          Tuple (z_t, h_t_bxn): the readout, [batch, 1], and the new hidden
          state, [batch, N].
        """
        act_t_bxn = tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn + u_bx1 * self.Bu_1xn
        h_t_bxn = tf.nn.tanh(act_t_bxn)
        z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
        return z_t, h_t_bxn
|
|
|
def get_data_batch(batch_size, T, rng, u_std):
    """Sample random input streams and their clipped running-sum targets.

    Args:
      batch_size: number of independent input streams to draw.
      T: number of time steps per stream (an integer count of bins).
      rng: random generator with a ``randn`` method, e.g. a
        ``numpy.random.RandomState``; it is advanced by exactly one
        ``randn(batch_size, T)`` draw.
      u_std: standard deviation of the Gaussian inputs.

    Returns:
      Tuple ``(u_bxt, labels_bxt)`` of ``[batch_size, T]`` float arrays: the
      inputs, and their cumulative sum over time clipped to ``[-1, 1]``.
    """
    u_bxt = rng.randn(batch_size, T) * u_std
    # np.cumsum accumulates left to right, so this is bit-identical to an
    # explicit running-sum loop. The clip is applied to the running sum at
    # each step; the sum itself is not bounded, so a label can move back
    # inside (-1, 1) after touching a bound.
    labels_bxt = np.clip(np.cumsum(u_bxt, axis=1), -1, 1)
    return u_bxt, labels_bxt
|
|
|
|
|
# Two independent RNGs: `rng` is used for the random projection and for spike
# generation, `u_rng` only for the input streams.
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T
C = FLAGS.C
# NOTE(review): presumably must match the hidden size of the restored
# checkpoint, since the model variables are N-dimensional.
N = FLAGS.N
nreplications = FLAGS.nreplications
E = nreplications * C  # total number of trials (conditions x replications)
train_percentage = FLAGS.train_percentage
# Round instead of truncating: T / dt is a floating-point quotient and can
# land just below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), which
# int() alone would truncate, silently dropping a time bin.
ntimesteps = int(round(T / FLAGS.dt))
batch_size = 1  # one trial is simulated per sess.run call
|
|
|
# Graph construction: the RNN parameters, one [batch, 1] input placeholder
# per time step, the zero initial state, and the Saver that will restore the
# pretrained weights (created after the model so it sees its variables).
model = IntegrationToBoundModel(N)
inputs_ph_t = []
for _ in range(ntimesteps):
    inputs_ph_t.append(tf.placeholder(tf.float32, shape=[None, 1]))
state = tf.zeros([batch_size, N])
saver = tf.train.Saver()

# Random N x N projection later applied to the hidden-state trajectories.
P_nxn = rng.randn(N, N) / np.sqrt(N)

# Statically unroll the RNN over every input placeholder, collecting the
# readout and the hidden state produced at each step.
outputs_t, states_t = [], []
for inp in inputs_ph_t:
    output, state = model.call(state, inp)
    outputs_t.append(output)
    states_t.append(state)
|
|
|
with tf.Session() as sess:
    # Resolve the checkpoint: the sentinel "SAMPLE_CHECKPOINT" selects the
    # checkpoint that ships next to this script.
    if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
        dir_path = os.path.dirname(os.path.realpath(__file__))
        model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
    else:
        model_checkpoint_path = FLAGS.checkpoint_path
    try:
        saver.restore(sess, model_checkpoint_path)
    except Exception as err:
        # Previously a bare `except:` + `assert False`: the assert is stripped
        # under `python -O` (the script then continued with uninitialized
        # weights), the underlying cause was discarded, and KeyboardInterrupt
        # was swallowed too. Fail explicitly and keep the cause in the message.
        raise ValueError(
            "No checkpoints to restore from, is the path %s correct? (%s)"
            % (model_checkpoint_path, err))
    else:
        print('Model restored from', model_checkpoint_path)

    # Simulate C conditions. Each condition gets a fresh random input stream;
    # its projected state trajectory is stored nreplications times so that
    # several spike trains can later be drawn from the same rates.
    data_e = []
    u_e = []
    outs_e = []
    for _ in range(C):
        # The clipped running-sum labels are not needed: the stored outputs
        # come from the restored RNN itself.
        u_1xt, _labels_1xt = get_data_batch(batch_size, ntimesteps, u_rng,
                                            FLAGS.u_std)

        # One [batch_size, 1] feed per time-step placeholder.
        feed_dict = {inputs_ph_t[t]: np.reshape(u_1xt[:, t], (batch_size, -1))
                     for t in xrange(ntimesteps)}

        states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
                                               feed_dict=feed_dict)
        # [ntimesteps, batch_size=1, N] -> [N, ntimesteps]
        states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn)))
        outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn))
        r_sxt = np.dot(P_nxn, states_nxt)

        for _ in xrange(nreplications):
            # Store an independent copy per replication. NOTE(review):
            # normalize_rates (not shown here) may modify its input arrays in
            # place; appending the same array object repeatedly would then
            # normalize it once per replication. Copying is harmless if the
            # helper is pure.
            data_e.append(r_sxt.copy())
            u_e.append(u_1xt)
            outs_e.append(outputs_t_bxn)

    truth_data_e = normalize_rates(data_e, E, N)
|
|
|
# Draw spikes from the normalized rates (this consumes `rng`), then decide
# which trials belong to the training and validation sets.
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)

# Partition the per-trial rate and spike lists with the same indices.
truth_split = split_list_by_inds(truth_data_e, train_inds, valid_inds)
spiking_split = split_list_by_inds(spiking_data_e, train_inds, valid_inds)

# Stack each partition into an array via nparray_and_transpose.
data_train_truth, data_valid_truth = [nparray_and_transpose(part)
                                      for part in truth_split]
data_train_spiking, data_valid_spiking = [nparray_and_transpose(part)
                                          for part in spiking_split]

# Inputs are split and stacked the same way as the data.
train_inputs_u, valid_inputs_u = [
    nparray_and_transpose(part)
    for part in split_list_by_inds(u_e, train_inds, valid_inds)]

# RNN outputs (presumably 1-D per trial, as they were squeezed when
# collected) are only stacked, not transposed.
train_outputs_u, valid_outputs_u = [
    np.array(part)
    for part in split_list_by_inds(outs_e, train_inds, valid_inds)]
|
|
|
|
|
# Everything that is written for this dataset: split rates/spikes, inputs,
# RNN outputs and the generation parameters. `conversion_factor` is
# max_firing_rate / (1 / dt), i.e. the per-bin count corresponding to a
# normalized rate of 1.0 (assuming max_firing_rate is in spikes/s).
data = dict(
    train_truth=data_train_truth,
    valid_truth=data_valid_truth,
    train_data=data_train_spiking,
    valid_data=data_valid_spiking,
    train_percentage=train_percentage,
    nreplications=nreplications,
    dt=FLAGS.dt,
    u_std=FLAGS.u_std,
    max_firing_rate=FLAGS.max_firing_rate,
    train_inputs_u=train_inputs_u,
    valid_inputs_u=valid_inputs_u,
    train_outputs_u=train_outputs_u,
    valid_outputs_u=valid_outputs_u,
    conversion_factor=FLAGS.max_firing_rate / (1.0 / FLAGS.dt),
)

# A single dataset keyed by the network size.
dataset_name = 'dataset_N' + str(N)
datasets = {dataset_name: data}

write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
saved_path = os.path.join(FLAGS.save_dir,
                          FLAGS.datafile_name + '_' + dataset_name)
print('Saved to ', saved_path)
|
|
|