|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| LFADS - Latent Factor Analysis via Dynamical Systems.
|
|
|
| LFADS is an unsupervised method to decompose time series data into
|
| various factors, such as an initial condition, a generative
|
| dynamical system, control inputs to that generator, and a low
|
| dimensional description of the observed data, called the factors.
|
| Additionally, the observations have a noise model (in this case
|
| Poisson), so a denoised version of the observations is also created
|
| (e.g. underlying rates of a Poisson distribution given the observed
|
| event counts).
|
|
|
| The main data structure being passed around is a dataset. This is a dictionary
|
| of data dictionaries.
|
|
|
| DATASET: The top level dictionary simply maps a name to a data dictionary (string -> dictionary).
|
| The nested dictionary is the DATA DICTIONARY, which has the following keys:
|
| 'train_data' and 'valid_data', whose values are the corresponding training
|
| and validation data with shape
|
| ExTxD, E - # examples, T - # time steps, D - # dimensions in data.
|
| The data dictionary also has a few more keys:
|
| 'train_ext_input' and 'valid_ext_input': if there are known external inputs
|
| to the system being modeled, these take on dimensions:
|
| ExTxI, E - # examples, T - # time steps, I = # dimensions in input.
|
| 'alignment_matrix_cxf' - If you are using data from multiple days, it's possible
|
| that one can align the channels (see manuscript). If so each dataset will
|
| contain this matrix, which will be used for both the input adapter and the
|
| output adapter for each dataset. These matrices, if provided, must be of
|
| size [data_dim x factors] where data_dim is the number of neurons recorded
|
| on that day, and factors is chosen and set through the '--factors' flag.
|
| 'alignment_bias_c' - See alignment_matrix_cxf. This bias will be used as
| the offset for the alignment transformation. It will *subtract* off the
| bias from the data, so PCA-style inits can align factors across sessions.
|
|
|
|
|
| If one runs LFADS on data where the true rates are known for some trials
| (say simulated test data, as in the example shipped with the paper), then
| one can add three more fields for plotting purposes: 'train_truth',
| 'valid_truth', and 'conversion_factor'. The first two have the same
| dimensions as 'train_data' and 'valid_data' but represent the underlying
| rates of the observations. Finally, if one needs to convert scale for
| plotting the true underlying firing rates, there is the 'conversion_factor'
| key.
|
| """
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
|
|
| import numpy as np
|
| import os
|
| import tensorflow as tf
|
| from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput
|
| from distributions import diag_gaussian_log_likelihood
|
| from distributions import KLCost_GaussianGaussian, Poisson
|
| from distributions import LearnableAutoRegressive1Prior
|
| from distributions import KLCost_GaussianGaussianProcessSampled
|
|
|
| from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data
|
| from utils import log_sum_exp, flatten
|
| from plot_lfads import plot_lfads
|
|
|
|
|
| class GRU(object):
|
| """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
|
|
|
| """
|
| def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0,
|
| clip_value=np.inf, collections=None):
|
| """Create a GRU object.
|
|
|
| Args:
|
| num_units: Number of units in the GRU.
|
| forget_bias (optional): Bias added inside the update-gate sigmoid; a hack to help learning.
|
| weight_scale (optional): Weights are scaled by ws/sqrt(#inputs), with
|
| ws being the weight scale.
|
| clip_value (optional): If the recurrent values grow above this value,
|
| clip them.
|
| collections (optional): List of additional collections variables should
|
| belong to.
|
| """
|
| self._num_units = num_units
|
| self._forget_bias = forget_bias
|
| self._weight_scale = weight_scale
|
| self._clip_value = clip_value
|
| self._collections = collections
|
|
|
| @property
|
| def state_size(self):
|
| return self._num_units
|
|
|
| @property
|
| def output_size(self):
|
| return self._num_units
|
|
|
| @property
|
| def state_multiplier(self):
|
| return 1
|
|
|
| def output_from_state(self, state):
|
| """Return the output portion of the state."""
|
| return state
|
|
|
| def __call__(self, inputs, state, scope=None):
|
| """Gated recurrent unit (GRU) function.
|
|
|
| Args:
|
| inputs: A 2D batch x input_dim tensor of inputs.
|
| state: The previous state from the last time step.
|
| scope (optional): TF variable scope for defined GRU variables.
|
|
|
| Returns:
|
| A tuple (state, state), where state is the newly computed state at time t.
|
| It is returned twice to respect an interface that works for LSTMs.
|
| """
|
|
|
| x = inputs
|
| h = state
|
| if inputs is not None:
|
| xh = tf.concat(axis=1, values=[x, h])
|
| else:
|
| xh = h
|
|
|
| with tf.variable_scope(scope or type(self).__name__):
|
| with tf.variable_scope("Gates"):
|
|
|
| r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh,
|
| 2 * self._num_units,
|
| alpha=self._weight_scale,
|
| name="xh_2_ru",
|
| collections=self._collections))
|
| r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
|
| with tf.variable_scope("Candidate"):
|
| xrh = tf.concat(axis=1, values=[x, r * h])
|
| c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c",
|
| collections=self._collections))
|
| new_h = u * h + (1 - u) * c
|
| new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
|
|
|
| return new_h, new_h
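| # The update above, in equation form (x_t = inputs, h_{t-1} = state,
| # b = forget_bias, [.,.] = concatenation):
| #   r_t = sigmoid(W_r [x_t, h_{t-1}])
| #   u_t = sigmoid(W_u [x_t, h_{t-1}] + b)
| #   c_t = tanh(W_c [x_t, r_t * h_{t-1}])
| #   h_t = clip(u_t * h_{t-1} + (1 - u_t) * c_t)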
|
|
|
|
|
| class GenGRU(object):
|
| """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).
|
|
|
| This version is specialized for the generator, but isn't as fast, so
| we keep both implementations. Splitting the weights allows L2
| regularization on the recurrent weights alone. It also means the
| 1/sqrt(#inputs) scaling in the linear helper routine gives the input
| weights a relatively large magnitude when there are fewer inputs than
| recurrent state dimensions.
|
|
|
| """
|
| def __init__(self, num_units, forget_bias=1.0,
|
| input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf,
|
| input_collections=None, recurrent_collections=None):
|
| """Create a GRU object.
|
|
|
| Args:
|
| num_units: Number of units in the GRU.
|
| forget_bias (optional): Bias added inside the update-gate sigmoid; a hack to help learning.
|
| input_weight_scale (optional): Weights are scaled ws/sqrt(#inputs), with
|
| ws being the weight scale.
|
| rec_weight_scale (optional): Weights are scaled ws/sqrt(#inputs),
|
| with ws being the weight scale.
|
| clip_value (optional): If the recurrent values grow above this value,
|
| clip them.
|
| input_collections (optional): List of additional collections variables
|
| that input->rec weights should belong to.
|
| recurrent_collections (optional): List of additional collections variables
|
| that rec->rec weights should belong to.
|
| """
|
| self._num_units = num_units
|
| self._forget_bias = forget_bias
|
| self._input_weight_scale = input_weight_scale
|
| self._rec_weight_scale = rec_weight_scale
|
| self._clip_value = clip_value
|
| self._input_collections = input_collections
|
| self._rec_collections = recurrent_collections
|
|
|
| @property
|
| def state_size(self):
|
| return self._num_units
|
|
|
| @property
|
| def output_size(self):
|
| return self._num_units
|
|
|
| @property
|
| def state_multiplier(self):
|
| return 1
|
|
|
| def output_from_state(self, state):
|
| """Return the output portion of the state."""
|
| return state
|
|
|
| def __call__(self, inputs, state, scope=None):
|
| """Gated recurrent unit (GRU) function.
|
|
|
| Args:
|
| inputs: A 2D batch x input_dim tensor of inputs.
|
| state: The previous state from the last time step.
|
| scope (optional): TF variable scope for defined GRU variables.
|
|
|
| Returns:
|
| A tuple (state, state), where state is the newly computed state at time t.
|
| It is returned twice to respect an interface that works for LSTMs.
|
| """
|
|
|
| x = inputs
|
| h = state
|
| with tf.variable_scope(scope or type(self).__name__):
|
| with tf.variable_scope("Gates"):
|
|
|
| r_x = u_x = 0.0
|
| if x is not None:
|
| r_x, u_x = tf.split(axis=1, num_or_size_splits=2, value=linear(x,
|
| 2 * self._num_units,
|
| alpha=self._input_weight_scale,
|
| do_bias=False,
|
| name="x_2_ru",
|
| normalized=False,
|
| collections=self._input_collections))
|
|
|
| r_h, u_h = tf.split(axis=1, num_or_size_splits=2, value=linear(h,
|
| 2 * self._num_units,
|
| do_bias=True,
|
| alpha=self._rec_weight_scale,
|
| name="h_2_ru",
|
| collections=self._rec_collections))
|
| r = r_x + r_h
|
| u = u_x + u_h
|
| r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
|
|
|
| with tf.variable_scope("Candidate"):
|
| c_x = 0.0
|
| if x is not None:
|
| c_x = linear(x, self._num_units, name="x_2_c", do_bias=False,
|
| alpha=self._input_weight_scale,
|
| normalized=False,
|
| collections=self._input_collections)
|
| c_rh = linear(r*h, self._num_units, name="rh_2_c", do_bias=True,
|
| alpha=self._rec_weight_scale,
|
| collections=self._rec_collections)
|
| c = tf.tanh(c_x + c_rh)
|
|
|
| new_h = u * h + (1 - u) * c
|
| new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)
|
|
|
| return new_h, new_h
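| # GenGRU implements the same update as GRU above, but with each weight
| # matrix split into an input part (x -> gates, no bias) and a recurrent
| # part (h -> gates, with bias), so the recurrent weights can be collected
| # for L2 regularization and the two parts can be scaled independently.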
|
|
|
|
|
| class LFADS(object):
|
| """LFADS - Latent Factor Analysis via Dynamical Systems.
|
|
|
| LFADS is an unsupervised method to decompose time series data into
|
| various factors, such as an initial condition, a generative
|
| dynamical system, inferred inputs to that generator, and a low
|
| dimensional description of the observed data, called the factors.
|
| Additionally, the observations have a noise model (in this case
|
| Poisson), so a denoised version of the observations is also created
|
| (e.g. underlying rates of a Poisson distribution given the observed
|
| event counts).
|
| """
|
|
|
| def __init__(self, hps, kind="train", datasets=None):
|
| """Create an LFADS model.
|
|
|
| train - a model for training; sampling from the posteriors is used.
| posterior_sample_and_average - sample from the posterior. This is used
| for evaluating the expected value of the outputs of LFADS, given a
| specific input, by averaging over multiple samples from the approximate
| posterior. It is also used for the lower bound on the negative
| log-likelihood via the IWAE (Importance Weighted Autoencoder) bound.
| This is the denoising operation.
| posterior_push_mean - deterministically push the posterior means
| through the model instead of sampling.
| prior_sample - a model for generation; sampling from the priors is used.
|
|
|
| Args:
|
| hps: The dictionary of hyperparameters.
|
| kind: The type of model to build (see above).
|
| datasets: A dictionary of named data_dictionaries, see top of lfads.py
|
| """
|
| print("Building graph...")
|
| all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
|
| 'prior_sample']
|
| assert kind in all_kinds, 'Wrong kind'
|
| if hps.feedback_factors_or_rates == "rates":
|
| assert len(hps.dataset_names) == 1, \
|
| "Multiple datasets not supported for rate feedback."
|
| num_steps = hps.num_steps
|
| ic_dim = hps.ic_dim
|
| co_dim = hps.co_dim
|
| ext_input_dim = hps.ext_input_dim
|
| cell_class = GRU
|
| gen_cell_class = GenGRU
|
|
|
| def makelambda(v):
|
| return lambda: v
|
|
|
|
|
|
|
| self.dataName = tf.placeholder(tf.string, shape=())
|
|
|
|
|
|
|
| if hps.output_dist == 'poisson':
|
|
|
| assert np.issubdtype(
|
| datasets[hps.dataset_names[0]]['train_data'].dtype, int), \
|
| "Data dtype must be int for poisson output distribution"
|
| data_dtype = tf.int32
|
| elif hps.output_dist == 'gaussian':
|
| assert np.issubdtype(
|
| datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
|
| "Data dtype must be float for gaussian output dsitribution"
|
| data_dtype = tf.float32
|
| else:
|
| assert False, "NIY"
|
| self.dataset_ph = dataset_ph = tf.placeholder(data_dtype,
|
| [None, num_steps, None],
|
| name="data")
|
| self.train_step = tf.get_variable("global_step", [], tf.int64,
|
| tf.zeros_initializer(),
|
| trainable=False)
|
| self.hps = hps
|
| ndatasets = hps.ndatasets
|
| factors_dim = hps.factors_dim
|
| self.preds = preds = [None] * ndatasets
|
| self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets
|
| self.fns_in_fac_bs = fns_in_fac_bs = [None] * ndatasets
|
| self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets
|
| self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets
|
| self.datasetNames = dataset_names = hps.dataset_names
|
| self.ext_inputs = ext_inputs = None
|
|
|
| if len(dataset_names) == 1:
|
| if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys():
|
| used_in_factors_dim = factors_dim
|
| in_identity_if_poss = False
|
| else:
|
| used_in_factors_dim = hps.dataset_dims[dataset_names[0]]
|
| in_identity_if_poss = True
|
| else:
|
| used_in_factors_dim = factors_dim
|
| in_identity_if_poss = False
|
|
|
| for d, name in enumerate(dataset_names):
|
| data_dim = hps.dataset_dims[name]
|
| in_mat_cxf = None
|
| in_bias_1xf = None
|
| align_bias_1xc = None
|
|
|
| if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
|
| dataset = datasets[name]
|
| if hps.do_train_readin:
|
| print("Initializing trainable readin matrix with alignment matrix" \
|
| " provided for dataset:", name)
|
| else:
|
| print("Setting non-trainable readin matrix to alignment matrix" \
|
| " provided for dataset:", name)
|
| in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
|
| if in_mat_cxf.shape != (data_dim, factors_dim):
|
| raise ValueError("""Alignment matrix must have dimensions %d x %d
|
| (data_dim x factors_dim), but currently has %d x %d."""%
|
| (data_dim, factors_dim, in_mat_cxf.shape[0],
|
| in_mat_cxf.shape[1]))
|
| if datasets and 'alignment_bias_c' in datasets[name].keys():
|
| dataset = datasets[name]
|
| if hps.do_train_readin:
|
| print("Initializing trainable readin bias with alignment bias " \
|
| "provided for dataset:", name)
|
| else:
|
| print("Setting non-trainable readin bias to alignment bias " \
|
| "provided for dataset:", name)
|
| align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
|
| align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
|
| if align_bias_1xc.shape[1] != data_dim:
|
| raise ValueError("""Alignment bias must have dimensions %d
|
| (data_dim), but currently has %d."""%
|
| (data_dim, align_bias_1xc.shape[1]))
|
| if in_mat_cxf is not None and align_bias_1xc is not None:
|
|
|
|
|
|
|
| in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
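| # The read-in applies (x - b) W = x W - (b W), so the effective input
| # bias computed here is -b W: the alignment bias is subtracted from the
| # data before projecting onto the factors.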
|
|
|
| if hps.do_train_readin:
|
|
|
|
|
|
|
| collections_readin=['IO_transformations']
|
| else:
|
| collections_readin=None
|
|
|
| in_fac_lin = init_linear(data_dim, used_in_factors_dim,
|
| do_bias=True,
|
| mat_init_value=in_mat_cxf,
|
| bias_init_value=in_bias_1xf,
|
| identity_if_possible=in_identity_if_poss,
|
| normalized=False, name="x_2_infac_"+name,
|
| collections=collections_readin,
|
| trainable=hps.do_train_readin)
|
| in_fac_W, in_fac_b = in_fac_lin
|
| fns_in_fac_Ws[d] = makelambda(in_fac_W)
|
| fns_in_fac_bs[d] = makelambda(in_fac_b)
|
|
|
| with tf.variable_scope("glm"):
|
| out_identity_if_poss = False
|
| if len(dataset_names) == 1 and \
|
| factors_dim == hps.dataset_dims[dataset_names[0]]:
|
| out_identity_if_poss = True
|
| for d, name in enumerate(dataset_names):
|
| data_dim = hps.dataset_dims[name]
|
| in_mat_cxf = None
|
| if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
|
| dataset = datasets[name]
|
| in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
|
|
|
| if datasets and 'alignment_bias_c' in datasets[name].keys():
|
| dataset = datasets[name]
|
| align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
|
| align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
|
|
|
| out_mat_fxc = None
|
| out_bias_1xc = None
|
| if in_mat_cxf is not None:
|
| out_mat_fxc = in_mat_cxf.T
|
| if align_bias_1xc is not None:
|
| out_bias_1xc = align_bias_1xc
|
|
|
| if hps.output_dist == 'poisson':
|
| out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
|
| mat_init_value=out_mat_fxc,
|
| bias_init_value=out_bias_1xc,
|
| identity_if_possible=out_identity_if_poss,
|
| normalized=False,
|
| name="fac_2_logrates_"+name,
|
| collections=['IO_transformations'])
|
| out_fac_W, out_fac_b = out_fac_lin
|
|
|
| elif hps.output_dist == 'gaussian':
|
| out_fac_lin_mean = \
|
| init_linear(factors_dim, data_dim, do_bias=True,
|
| mat_init_value=out_mat_fxc,
|
| bias_init_value=out_bias_1xc,
|
| normalized=False,
|
| name="fac_2_means_"+name,
|
| collections=['IO_transformations'])
|
| out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
|
|
|
| mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
|
| bias_init_value = np.ones([1, data_dim]).astype(np.float32)
|
| out_fac_lin_logvar = \
|
| init_linear(factors_dim, data_dim, do_bias=True,
|
| mat_init_value=mat_init_value,
|
| bias_init_value=bias_init_value,
|
| normalized=False,
|
| name="fac_2_logvars_"+name,
|
| collections=['IO_transformations'])
|
|
| out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar
|
| out_fac_W = tf.concat(
|
| axis=1, values=[out_fac_W_mean, out_fac_W_logvar])
|
| out_fac_b = tf.concat(
|
| axis=1, values=[out_fac_b_mean, out_fac_b_logvar])
|
| else:
|
| assert False, "NIY"
|
|
|
| preds[d] = tf.equal(tf.constant(name), self.dataName)
|
| data_dim = hps.dataset_dims[name]
|
| fns_out_fac_Ws[d] = makelambda(out_fac_W)
|
| fns_out_fac_bs[d] = makelambda(out_fac_b)
|
|
|
| pf_pairs_in_fac_Ws = zip(preds, fns_in_fac_Ws)
|
| pf_pairs_in_fac_bs = zip(preds, fns_in_fac_bs)
|
| pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws)
|
| pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs)
|
|
|
| this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
|
| this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
|
| this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
|
| this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True)
|
|
|
|
|
| if hps.ext_input_dim > 0:
|
| self.ext_input = tf.placeholder(tf.float32,
|
| [None, num_steps, ext_input_dim],
|
| name="ext_input")
|
| else:
|
| self.ext_input = None
|
| ext_input_bxtxi = self.ext_input
|
|
|
| self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob")
|
| self.batch_size = batch_size = int(hps.batch_size)
|
| self.learning_rate = tf.Variable(float(hps.learning_rate_init),
|
| trainable=False, name="learning_rate")
|
| self.learning_rate_decay_op = self.learning_rate.assign(
|
| self.learning_rate * hps.learning_rate_decay_factor)
|
|
|
|
|
| dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob)
|
| if hps.ext_input_dim > 0:
|
| ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob)
|
| else:
|
| ext_input_do_bxtxi = None
|
|
|
|
|
| def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse,
|
| num_steps_to_encode):
|
| """Encode data for LFADS
|
| Args:
|
| dataset_bxtxd - the data to encode, as a 3-tensor, with dims
| batch x time x data dims.
|
| enc_cell: encoder cell
|
| name: name of encoder
|
| forward_or_reverse: string, encode in forward or reverse direction
|
| num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode
|
| Returns:
|
| encoded data as a list with num_steps_to_encode items, in order
|
| """
|
| if forward_or_reverse == "forward":
|
| dstr = "_fwd"
|
| time_fwd_or_rev = range(num_steps_to_encode)
|
| else:
|
| dstr = "_rev"
|
| time_fwd_or_rev = reversed(range(num_steps_to_encode))
|
|
|
| with tf.variable_scope(name+"_enc"+dstr, reuse=False):
|
| enc_state = tf.tile(
|
| tf.Variable(tf.zeros([1, enc_cell.state_size]),
|
| name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1]))
|
| enc_state.set_shape([None, enc_cell.state_size])
|
|
|
| enc_outs = [None] * num_steps_to_encode
|
| for i, t in enumerate(time_fwd_or_rev):
|
| with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None):
|
| dataset_t_bxd = dataset_bxtxd[:,t,:]
|
| in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b
|
| in_fac_t_bxf.set_shape([None, used_in_factors_dim])
|
| if ext_input_dim > 0 and not hps.inject_ext_input_to_gen:
|
| ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
|
| enc_input_t_bxfpe = tf.concat(
|
| axis=1, values=[in_fac_t_bxf, ext_input_t_bxi])
|
| else:
|
| enc_input_t_bxfpe = in_fac_t_bxf
|
| enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state)
|
| enc_outs[t] = enc_out
|
|
|
| return enc_outs
|
|
|
|
|
|
|
| self.ic_enc_fwd = [None] * num_steps
|
| self.ic_enc_rev = [None] * num_steps
|
| if ic_dim > 0:
|
| enc_ic_cell = cell_class(hps.ic_enc_dim,
|
| weight_scale=hps.cell_weight_scale,
|
| clip_value=hps.cell_clip_value)
|
| ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell,
|
| "ic", "forward",
|
| hps.num_steps_for_gen_ic)
|
| ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell,
|
| "ic", "reverse",
|
| hps.num_steps_for_gen_ic)
|
| self.ic_enc_fwd = ic_enc_fwd
|
| self.ic_enc_rev = ic_enc_rev
|
|
|
|
|
|
|
| self.ci_enc_fwd = [None] * num_steps
|
| self.ci_enc_rev = [None] * num_steps
|
| if co_dim > 0:
|
| enc_ci_cell = cell_class(hps.ci_enc_dim,
|
| weight_scale=hps.cell_weight_scale,
|
| clip_value=hps.cell_clip_value)
|
| ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell,
|
| "ci", "forward",
|
| hps.num_steps)
|
| if hps.do_causal_controller:
|
| ci_enc_rev = None
|
| else:
|
| ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell,
|
| "ci", "reverse",
|
| hps.num_steps)
|
| self.ci_enc_fwd = ci_enc_fwd
|
| self.ci_enc_rev = ci_enc_rev
|
|
|
|
|
|
|
|
|
| with tf.variable_scope("z", reuse=False):
|
| self.prior_zs_g0 = None
|
| self.posterior_zs_g0 = None
|
| self.g0s_val = None
|
| if ic_dim > 0:
|
| self.prior_zs_g0 = \
|
| LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0",
|
| mean_init=0.0,
|
| var_min=hps.ic_prior_var_min,
|
| var_init=hps.ic_prior_var_scale,
|
| var_max=hps.ic_prior_var_max)
|
| ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]])
|
| ic_enc = tf.nn.dropout(ic_enc, keep_prob)
|
| self.posterior_zs_g0 = \
|
| DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
|
| var_min=hps.ic_post_var_min)
|
| if kind in ["train", "posterior_sample_and_average",
|
| "posterior_push_mean"]:
|
| zs_g0 = self.posterior_zs_g0
|
| else:
|
| zs_g0 = self.prior_zs_g0
|
| if kind in ["train", "posterior_sample_and_average", "prior_sample"]:
|
| self.g0s_val = zs_g0.sample
|
| else:
|
| self.g0s_val = zs_g0.mean
|
|
|
|
|
| self.prior_zs_co = prior_zs_co = [None] * num_steps
|
| self.posterior_zs_co = posterior_zs_co = [None] * num_steps
|
| self.zs_co = zs_co = [None] * num_steps
|
| self.prior_zs_ar_con = None
|
| if co_dim > 0:
|
|
|
| autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)]
|
| noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)]
|
| self.prior_zs_ar_con = prior_zs_ar_con = \
|
| LearnableAutoRegressive1Prior(batch_size, hps.co_dim,
|
| autocorrelation_taus,
|
| noise_variances,
|
| hps.do_train_prior_ar_atau,
|
| hps.do_train_prior_ar_nvar,
|
| num_steps, "u_prior_ar1")
|
|
|
|
|
|
|
| self.controller_outputs = u_t = [None] * num_steps
|
| self.con_ics = con_state = None
|
| self.con_states = con_states = [None] * num_steps
|
| self.con_outs = con_outs = [None] * num_steps
|
| self.gen_inputs = gen_inputs = [None] * num_steps
|
| if co_dim > 0:
|
|
|
|
|
| con_cell = gen_cell_class(hps.con_dim,
|
| input_weight_scale=hps.cell_weight_scale,
|
| rec_weight_scale=hps.cell_weight_scale,
|
| clip_value=hps.cell_clip_value,
|
| recurrent_collections=['l2_con_reg'])
|
| with tf.variable_scope("con", reuse=False):
|
| self.con_ics = tf.tile(
|
| tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
|
| name="c0"),
|
| tf.stack([batch_size, 1]))
|
| self.con_ics.set_shape([None, con_cell.state_size])
|
| con_states[-1] = self.con_ics
|
|
|
| gen_cell = gen_cell_class(hps.gen_dim,
|
| input_weight_scale=hps.gen_cell_input_weight_scale,
|
| rec_weight_scale=hps.gen_cell_rec_weight_scale,
|
| clip_value=hps.cell_clip_value,
|
| recurrent_collections=['l2_gen_reg'])
|
| with tf.variable_scope("gen", reuse=False):
|
| if ic_dim == 0:
|
| self.gen_ics = tf.tile(
|
| tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"),
|
| tf.stack([batch_size, 1]))
|
| else:
|
| self.gen_ics = linear(self.g0s_val, gen_cell.state_size,
|
| identity_if_possible=True,
|
| name="g0_2_gen_ic")
|
|
|
| self.gen_states = gen_states = [None] * num_steps
|
| self.gen_outs = gen_outs = [None] * num_steps
|
| gen_states[-1] = self.gen_ics
|
| gen_outs[-1] = gen_cell.output_from_state(gen_states[-1])
|
| self.factors = factors = [None] * num_steps
|
| factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False,
|
| normalized=True, name="gen_2_fac")
|
|
|
| self.rates = rates = [None] * num_steps
|
|
|
| with tf.variable_scope("glm", reuse=False):
|
| if hps.output_dist == 'poisson':
|
| log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
|
| log_rates_t0.set_shape([None, None])
|
| rates[-1] = tf.exp(log_rates_t0)
|
| rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
|
| elif hps.output_dist == 'gaussian':
|
| mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b
|
| mean_n_logvars.set_shape([None, None])
|
| means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
|
| value=mean_n_logvars)
|
| rates[-1] = means_t_bxd
|
| else:
|
| assert False, "NIY"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| self.output_dist_params = dist_params = [None] * num_steps
|
| self.log_p_xgz_b = log_p_xgz_b = 0.0
|
| for t in range(num_steps):
|
|
|
| if co_dim > 0:
|
|
|
| tlag = t - hps.controller_input_lag
|
| if tlag < 0:
|
| con_in_f_t = tf.zeros_like(ci_enc_fwd[0])
|
| else:
|
| con_in_f_t = ci_enc_fwd[tlag]
|
| if hps.do_causal_controller:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| con_in_list_t = [con_in_f_t]
|
| else:
|
| tlag_rev = t + hps.controller_input_lag
|
| if tlag_rev >= num_steps:
|
|
|
| con_in_r_t = tf.zeros_like(ci_enc_rev[0])
|
| else:
|
| con_in_r_t = ci_enc_rev[tlag_rev]
|
| con_in_list_t = [con_in_f_t, con_in_r_t]
|
|
|
| if hps.do_feed_factors_to_controller:
|
| if hps.feedback_factors_or_rates == "factors":
|
| con_in_list_t.append(factors[t-1])
|
| elif hps.feedback_factors_or_rates == "rates":
|
| con_in_list_t.append(rates[t-1])
|
| else:
|
| assert False, "NIY"
|
|
|
| con_in_t = tf.concat(axis=1, values=con_in_list_t)
|
| con_in_t = tf.nn.dropout(con_in_t, keep_prob)
|
| with tf.variable_scope("con", reuse=True if t > 0 else None):
|
| con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1])
|
| posterior_zs_co[t] = \
|
| DiagonalGaussianFromInput(con_outs[t], co_dim,
|
| name="con_to_post_co")
|
| if kind == "train":
|
| u_t[t] = posterior_zs_co[t].sample
|
| elif kind == "posterior_sample_and_average":
|
| u_t[t] = posterior_zs_co[t].sample
|
| elif kind == "posterior_push_mean":
|
| u_t[t] = posterior_zs_co[t].mean
|
| else:
|
| u_t[t] = prior_zs_ar_con.samples_t[t]
|
|
|
|
|
| if ext_input_dim > 0 and hps.inject_ext_input_to_gen:
|
| ext_input_t_bxi = ext_input_do_bxtxi[:,t,:]
|
| if co_dim > 0:
|
| gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi])
|
| else:
|
| gen_inputs[t] = ext_input_t_bxi
|
| else:
|
| gen_inputs[t] = u_t[t]
|
|
|
|
|
| data_t_bxd = dataset_ph[:,t,:]
|
| with tf.variable_scope("gen", reuse=True if t > 0 else None):
|
| gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1])
|
| gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob)
|
| with tf.variable_scope("gen", reuse=True):
|
| factors[t] = linear(gen_outs[t], factors_dim, do_bias=False,
|
| normalized=True, name="gen_2_fac")
|
| with tf.variable_scope("glm", reuse=True if t > 0 else None):
|
| if hps.output_dist == 'poisson':
|
| log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
|
| log_rates_t.set_shape([None, None])
|
| rates[t] = dist_params[t] = tf.exp(log_rates_t)
|
| rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
|
| loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
|
|
|
| elif hps.output_dist == 'gaussian':
|
| mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b
|
| mean_n_logvars.set_shape([None, None])
|
| means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
|
| value=mean_n_logvars)
|
| rates[t] = means_t_bxd
|
| dist_params[t] = tf.concat(
| axis=1, values=[means_t_bxd, tf.exp(logvars_t_bxd)])
|
| loglikelihood_t = \
|
| diag_gaussian_log_likelihood(data_t_bxd,
|
| means_t_bxd, logvars_t_bxd)
|
| else:
|
| assert False, "NIY"
|
|
|
| log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1])
|
|
|
|
|
| self.corr_cost = tf.constant(0.0)
|
| if hps.co_mean_corr_scale > 0.0:
|
| all_sum_corr = []
|
| for i in range(hps.co_dim):
|
| for j in range(i+1, hps.co_dim):
|
| sum_corr_ij = tf.constant(0.0)
|
| for t in range(num_steps):
|
| u_mean_t = posterior_zs_co[t].mean
|
| sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j]
|
| all_sum_corr.append(0.5 * tf.square(sum_corr_ij))
|
| self.corr_cost = tf.reduce_mean(all_sum_corr)
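| # This penalty discourages correlated controller outputs: for each pair
| # (i, j) of control dims it accumulates sum_t E[u_i(t)] * E[u_j(t)] and
| # adds 0.5 * (that sum)^2, then averages over all pairs.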
|
|
|
|
|
|
|
|
|
| kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32)
|
| kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32)
|
| self.kl_cost = tf.constant(0.0)
|
| self.recon_cost = tf.constant(0.0)
|
| self.nll_bound_vae = tf.constant(0.0)
|
| self.nll_bound_iwae = tf.constant(0.0)
|
| if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
|
| kl_cost_g0_b = 0.0
|
| kl_cost_co_b = 0.0
|
| if ic_dim > 0:
|
| g0_priors = [self.prior_zs_g0]
|
| g0_posts = [self.posterior_zs_g0]
|
| kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b
|
| kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b
|
| if co_dim > 0:
|
| kl_cost_co_b = \
|
| KLCost_GaussianGaussianProcessSampled(
|
| posterior_zs_co, prior_zs_ar_con).kl_cost_b
|
| kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b
|
|
|
|
|
|
|
|
|
| self.recon_cost = - tf.reduce_mean(log_p_xgz_b)
|
| self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b)
|
|
|
| lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b
|
|
|
|
|
| self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b)
|
|
|
|
|
| k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32)
|
| iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b)
|
| self.nll_bound_iwae = -iwae_lb_on_ll
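| # The IWAE bound with k samples:
| #   log (1/k) sum_i exp(lb_i) = -log(k) + log_sum_exp(lb_i),
| # where lb_i is the per-sample ELBO computed above. It is tightest when
| # the batch is filled with copies of a single example.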
|
|
|
|
|
| self.l2_cost = tf.constant(0.0)
|
| if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0:
|
| l2_costs = []
|
| l2_numels = []
|
| l2_reg_var_lists = [tf.get_collection('l2_gen_reg'),
|
| tf.get_collection('l2_con_reg')]
|
| l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale]
|
| for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales):
|
| for v in l2_reg_vars:
|
| numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v)))
|
| numel_f = tf.cast(numel, tf.float32)
|
| l2_numels.append(numel_f)
|
| v_l2 = tf.reduce_sum(v*v)
|
| l2_costs.append(0.5 * l2_scale * v_l2)
|
| self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels)
|
|
|
|
|
|
|
|
|
|
|
| self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0)
|
| self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0)
|
| kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32)
|
| l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32)
|
| kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32)
|
| l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32)
|
| self.kl_weight = kl_weight = \
|
| tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0)
|
| self.l2_weight = l2_weight = \
|
| tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0)
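| # Example of the linear ramp above: with hypothetical settings
| # kl_start_step=1000 and kl_increase_steps=2000, kl_weight is 0.0 until
| # step 1000, 0.5 at step 2000, and saturates at 1.0 from step 3000 on.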
|
|
|
| self.timed_kl_cost = kl_weight * self.kl_cost
|
| self.timed_l2_cost = l2_weight * self.l2_cost
|
| self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost
|
| self.cost = self.recon_cost + self.timed_kl_cost + \
|
| self.timed_l2_cost + self.weight_corr_cost
|
|
|
| if kind != "train":
|
|
|
| self.seso_saver = tf.train.Saver(tf.global_variables(),
|
| max_to_keep=hps.max_ckpt_to_keep)
|
|
|
| self.lve_saver = tf.train.Saver(tf.global_variables(),
|
| max_to_keep=hps.max_ckpt_to_keep_lve)
|
|
|
| return
|
|
|
|
|
|
|
| if self.hps.do_train_io_only:
|
| self.train_vars = tvars = \
|
| tf.get_collection('IO_transformations',
|
| scope=tf.get_variable_scope().name)
|
|
|
| elif self.hps.do_train_encoder_only:
|
| tvars1 = \
|
| tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
| scope='LFADS/ic_enc_*')
|
| tvars2 = \
|
| tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
| scope='LFADS/z/ic_enc_*')
|
|
|
| self.train_vars = tvars = tvars1 + tvars2
|
|
|
| else:
|
| self.train_vars = tvars = \
|
| tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
| scope=tf.get_variable_scope().name)
|
| print("done.")
|
| print("Model Variables (to be optimized): ")
|
| total_params = 0
|
| for i in range(len(tvars)):
|
| shape = tvars[i].get_shape().as_list()
|
| print(" ", i, tvars[i].name, shape)
|
| total_params += np.prod(shape)
|
| print("Total model parameters: ", total_params)
|
|
|
| grads = tf.gradients(self.cost, tvars)
|
| grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm)
|
| opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999,
|
| epsilon=1e-01)
|
| self.grads = grads
|
| self.grad_global_norm = grad_global_norm
|
| self.train_op = opt.apply_gradients(
|
| zip(grads, tvars), global_step=self.train_step)
|
|
|
| self.seso_saver = tf.train.Saver(tf.global_variables(),
|
| max_to_keep=hps.max_ckpt_to_keep)
|
|
|
|
|
| self.lve_saver = tf.train.Saver(tf.global_variables(),
|
| max_to_keep=hps.max_ckpt_to_keep_lve)
|
|
|
|
|
|
|
| self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3],
|
| name='image_tensor')
|
| self.example_summ = tf.summary.image("LFADS example", self.example_image,
|
| collections=["example_summaries"])
|
|
|
|
|
| self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
|
| self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight)
|
| self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight)
|
| self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost)
|
| self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm",
|
| self.grad_global_norm)
|
| if hps.co_dim > 0:
|
| self.atau_summ = [None] * hps.co_dim
|
| self.pvar_summ = [None] * hps.co_dim
|
| for c in range(hps.co_dim):
|
| self.atau_summ[c] = \
|
| tf.summary.scalar("AR Autocorrelation taus " + str(c),
|
| tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c]))
|
| self.pvar_summ[c] = \
|
| tf.summary.scalar("AR Variances " + str(c),
|
| tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c]))
|
|
|
|
|
|
|
|
|
|
|
| kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph')
|
| self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph,
|
| collections=["train_summaries"])
|
| self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph,
|
| collections=["valid_summaries"])
|
| l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph')
|
| self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph,
|
| collections=["train_summaries"])
|
|
|
| recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph')
|
| self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)",
|
| recon_cost_ph,
|
| collections=["train_summaries"])
|
| self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)",
|
| recon_cost_ph,
|
| collections=["valid_summaries"])
|
|
|
| total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph')
|
| self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph,
|
| collections=["train_summaries"])
|
| self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph,
|
| collections=["valid_summaries"])
|
|
|
| self.kl_cost_ph = kl_cost_ph
|
| self.l2_cost_ph = l2_cost_ph
|
| self.recon_cost_ph = recon_cost_ph
|
| self.total_cost_ph = total_cost_ph
|
|
|
|
|
| self.merged_examples = tf.summary.merge_all(key="example_summaries")
|
| self.merged_generic = tf.summary.merge_all()
|
| self.merged_train = tf.summary.merge_all(key="train_summaries")
|
| self.merged_valid = tf.summary.merge_all(key="valid_summaries")
|
|
|
| session = tf.get_default_session()
|
| self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
|
| self.writer = tf.summary.FileWriter(self.logfile)
|
|
|
| def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
|
| keep_prob=None):
|
| """Build the feed dictionary, handles cases where there is no value defined.
|
|
|
| Args:
|
| train_name: The key into the datasets, to set the tf.case statement for
|
| the proper readin / readout matrices.
|
| data_bxtxd: The data tensor.
|
| ext_input_bxtxi (optional): The external input tensor.
|
| keep_prob: The dropout keep probability.
|
|
|
| Returns:
|
| The feed dictionary with TF tensors as keys and data as values, for use
|
| with tf.Session.run()
|
|
|
| """
|
| feed_dict = {}
|
| B, T, _ = data_bxtxd.shape
|
| feed_dict[self.dataName] = train_name
|
| feed_dict[self.dataset_ph] = data_bxtxd
|
|
|
| if self.ext_input is not None and ext_input_bxtxi is not None:
|
| feed_dict[self.ext_input] = ext_input_bxtxi
|
|
|
| if keep_prob is None:
|
| feed_dict[self.keep_prob] = self.hps.keep_prob
|
| else:
|
| feed_dict[self.keep_prob] = keep_prob
|
|
|
| return feed_dict
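| # Example usage (hypothetical dataset name and shapes):
| #   fd = model.build_feed_dict('my_day1',
| #                              np.zeros((hps.batch_size, T, D)),
| #                              keep_prob=1.0)
| #   session.run(model.train_op, feed_dict=fd)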
|
|
|
| @staticmethod
|
| def get_batch(data_extxd, ext_input_extxi=None, batch_size=None,
|
| example_idxs=None):
|
| """Get a batch of data, either randomly chosen, or specified directly.
|
|
|
| Args:
|
| data_extxd: The data to model, numpy tensors with shape:
|
| # examples x # time steps x # dimensions
|
| ext_input_extxi (optional): The external inputs, numpy tensor with shape:
|
| # examples x # time steps x # external input dimensions
|
| batch_size: The size of the batch to return.
|
| example_idxs (optional): The example indices used to select examples.
|
|
|
| Returns:
|
| A tuple with two parts:
|
| 1. Batched data numpy tensor with shape:
|
| batch_size x # time steps x # dimensions
|
| 2. Batched external input numpy tensor with shape:
|
| batch_size x # time steps x # external input dims
|
| """
|
| assert batch_size is not None or example_idxs is not None, "Problems"
|
| E, T, D = data_extxd.shape
|
| if example_idxs is None:
|
| example_idxs = np.random.choice(E, batch_size)
|
|
|
| ext_input_bxtxi = None
|
| if ext_input_extxi is not None:
|
| ext_input_bxtxi = ext_input_extxi[example_idxs,:,:]
|
|
|
| return data_extxd[example_idxs,:,:], ext_input_bxtxi
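| # Example (hypothetical shapes): draw a random batch of 128 examples, or
| # select specific examples deterministically:
| #   data_b, ext_b = LFADS.get_batch(data_extxd, batch_size=128)
| #   data_b, ext_b = LFADS.get_batch(data_extxd,
| #                                   example_idxs=np.arange(128))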
|
|
|
| @staticmethod
|
| def example_idxs_mod_batch_size(nexamples, batch_size):
|
| """Given a number of examples, E, and a batch_size, B, generate indices
|
| [0, 1, 2, ... B-1;
|
| [B, B+1, ... 2*B-1;
|
| ...
|
| ]
|
| returning those indices as a 2-dim tensor shaped like E/B x B. Note that
|
| shape is only correct if E % B == 0. If not, then an extra row is generated
|
| so that the remainder of examples is included. The extra examples are
|
| explicitly to to the zero index (see randomize_example_idxs_mod_batch_size)
|
| for randomized behavior.
|
|
|
| Args:
|
| nexamples: The number of examples to batch up.
|
| batch_size: The size of the batch.
|
| Returns:
|
| 2-dim tensor as described above.
|
| """
|
| bmrem = batch_size - (nexamples % batch_size)
|
| bmrem_examples = []
|
| if bmrem < batch_size:
|
|
|
| ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32)
|
| bmrem_examples = np.sort(ridxs)
|
| example_idxs = list(range(nexamples)) + list(bmrem_examples)
|
| example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size])
|
| return example_idxs_e_x_edivb, bmrem
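| # Worked example: nexamples=10, batch_size=4 gives bmrem=2, so the
| # indices come back shaped 3 x 4, e.g. [[0,1,2,3], [4,5,6,7], [8,9,r0,r1]],
| # where r0 < r1 are two randomly chosen (sorted) indices from 0..9.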
|
|
|
| @staticmethod
|
| def randomize_example_idxs_mod_batch_size(nexamples, batch_size):
|
| """Indices 1:nexamples, randomized, in 2D form of
|
| shape = (nexamples / batch_size) x batch_size. The remainder
|
| is managed by drawing randomly from 1:nexamples.
|
|
|
| Args:
|
| nexamples: Number of examples to randomize.
|
| batch_size: Number of elements in batch.
|
|
|
| Returns:
|
| The randomized, properly shaped indices.
|
| """
|
| assert nexamples > batch_size, "Problems"
|
| bmrem = batch_size - nexamples % batch_size
|
| bmrem_examples = []
|
| if bmrem < batch_size:
|
| bmrem_examples = np.random.choice(range(nexamples),
|
| size=bmrem, replace=False)
|
| example_idxs = list(range(nexamples)) + list(bmrem_examples)
|
| mixed_example_idxs = np.random.permutation(example_idxs)
|
| example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size])
|
| return example_idxs_e_x_edivb, bmrem
|
|
|
| def shuffle_spikes_in_time(self, data_bxtxd):
|
| """Shuffle the spikes in the temporal dimension. This is useful to
|
| help the LFADS system avoid overfitting to individual spikes or fast
|
| oscillations found in the data that are irrelevant to behavior. A
|
| pure 'tabula rasa' approach would avoid this, but LFADS is sensitive
|
| enough to pick up dynamics that you may not want.
|
|
|
| Args:
|
| data_bxtxd: Numpy array of spike count data to be shuffled.
|
| Returns:
|
| S_bxtxd, a numpy array with the same dimensions and contents as
|
| data_bxtxd, but shuffled appropriately.
|
|
|
| """
|
|
|
| B, T, N = data_bxtxd.shape
|
| w = self.hps.temporal_spike_jitter_width
|
|
|
| if w == 0:
|
| return data_bxtxd
|
|
|
| max_counts = np.max(data_bxtxd)
|
| S_bxtxd = np.zeros([B,T,N])
|
|
|
|
|
|
|
| for mc in range(1,max_counts+1):
|
| idxs = np.nonzero(data_bxtxd >= mc)
|
|
|
| data_ones = np.zeros_like(data_bxtxd)
|
| data_ones[data_bxtxd >= mc] = 1
|
|
|
| nfound = len(idxs[0])
|
| shuffles_incrs_in_time = np.random.randint(-w, w, size=nfound)
|
|
|
| shuffle_tidxs = idxs[1].copy()
|
| shuffle_tidxs += shuffles_incrs_in_time
|
|
|
|
|
| shuffle_tidxs[shuffle_tidxs < 0] = -shuffle_tidxs[shuffle_tidxs < 0]
|
| shuffle_tidxs[shuffle_tidxs > T-1] = \
|
| (T-1)-(shuffle_tidxs[shuffle_tidxs > T-1] -(T-1))
|
|
|
| for iii in zip(idxs[0], shuffle_tidxs, idxs[2]):
|
| S_bxtxd[iii] += 1
|
|
|
| return S_bxtxd
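| # Sketch of the scheme above: a count of c at time t is decomposed into
| # c single events (one per count level mc), and each event is re-placed
| # independently at t + jitter, with jitter drawn from {-w, ..., w-1} and
| # out-of-range times reflected back off the edges of [0, T-1].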
|
|
|
| def shuffle_and_flatten_datasets(self, datasets, kind='train'):
|
| """Since LFADS supports multiple datasets in the same dynamical model,
|
| we have to be careful to use all the data in a single training epoch. But
|
| since the datasets may have different data dimensionality, we cannot batch
|
| examples from data dictionaries together. Instead, we generate random
|
| batches within each data dictionary, and then randomize these batches
|
| while holding onto the dataname, so that when it's time to feed
|
| the graph, the correct in/out matrices can be selected, per batch.
|
|
|
| Args:
|
| datasets: A dict of data dicts. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| kind: 'train' or 'valid'
|
|
|
| Returns:
|
| A flat list, in which each element is a pair ('name', indices).
|
| """
|
| batch_size = self.hps.batch_size
|
| ndatasets = len(datasets)
|
| random_example_idxs = {}
|
| epoch_idxs = {}
|
| all_name_example_idx_pairs = []
|
| kind_data = kind + '_data'
|
| for name, data_dict in datasets.items():
|
| nexamples, ntime, data_dim = data_dict[kind_data].shape
|
| epoch_idxs[name] = 0
|
| random_example_idxs, _ = \
|
| self.randomize_example_idxs_mod_batch_size(nexamples, batch_size)
|
|
|
| epoch_size = random_example_idxs.shape[0]
|
| names = [name] * epoch_size
|
| all_name_example_idx_pairs += zip(names, random_example_idxs)
|
|
|
| np.random.shuffle(all_name_example_idx_pairs)
|
|
|
| return all_name_example_idx_pairs
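| # The returned epoch for two hypothetical datasets looks like:
| #   [('day2', array([5, 1, 9, ...])), ('day1', array([3, 7, 0, ...])), ...]
| # i.e. one entry per batch, over all datasets, in shuffled order.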
|
|
|
| def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True):
|
| """Train the model through the entire dataset once.
|
|
|
| Args:
|
| datasets: A dict of data dicts. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| batch_size (optional): The batch_size to use.
|
| do_save_ckpt (optional): Should the routine save a checkpoint on this
|
| training epoch?
|
|
|
| Returns:
|
| A tuple with 6 float values:
|
| (total cost of the epoch, epoch reconstruction cost,
|
| epoch kl cost, KL weight used this training epoch,
|
| total l2 cost on generator, and the corresponding weight).
|
| """
|
| ops_to_eval = [self.cost, self.recon_cost,
|
| self.kl_cost, self.kl_weight,
|
| self.l2_cost, self.l2_weight,
|
| self.train_op]
|
| collected_op_values = self.run_epoch(datasets, ops_to_eval, kind="train")
|
|
|
| total_cost = total_recon_cost = total_kl_cost = 0.0
|
|
|
| epoch_size = len(collected_op_values)
|
| for op_values in collected_op_values:
|
| total_cost += op_values[0]
|
| total_recon_cost += op_values[1]
|
| total_kl_cost += op_values[2]
|
|
|
| kl_weight = collected_op_values[-1][3]
|
| l2_cost = collected_op_values[-1][4]
|
| l2_weight = collected_op_values[-1][5]
|
|
|
| epoch_total_cost = total_cost / epoch_size
|
| epoch_recon_cost = total_recon_cost / epoch_size
|
| epoch_kl_cost = total_kl_cost / epoch_size
|
|
|
| if do_save_ckpt:
|
| session = tf.get_default_session()
|
| checkpoint_path = os.path.join(self.hps.lfads_save_dir,
|
| self.hps.checkpoint_name + '.ckpt')
|
| self.seso_saver.save(session, checkpoint_path,
|
| global_step=self.train_step)
|
|
|
| return epoch_total_cost, epoch_recon_cost, epoch_kl_cost, \
|
| kl_weight, l2_cost, l2_weight
|
|
|
|
|
| def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None,
|
| do_collect=True, keep_prob=None):
|
| """Run the model through the entire dataset once.
|
|
|
| Args:
|
| datasets: A dict of data dicts. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| ops_to_eval: A list of tensorflow operations that will be evaluated in
|
| the tf.session.run() call.
|
| batch_size (optional): The batch_size to use.
|
| do_collect (optional): Should the routine collect all session.run
|
| output as a list, and return it?
|
| keep_prob (optional): The dropout keep probability.
|
|
|
| Returns:
|
| A list of lists, the internal list is the return for the ops for each
|
| session.run() call. The outer list collects over the epoch.
|
| """
|
| hps = self.hps
|
| all_name_example_idx_pairs = \
|
| self.shuffle_and_flatten_datasets(datasets, kind)
|
|
|
| kind_data = kind + '_data'
|
| kind_ext_input = kind + '_ext_input'
|
|
|
| total_cost = total_recon_cost = total_kl_cost = 0.0
|
| session = tf.get_default_session()
|
| epoch_size = len(all_name_example_idx_pairs)
|
| evaled_ops_list = []
|
| for name, example_idxs in all_name_example_idx_pairs:
|
| data_dict = datasets[name]
|
| data_extxd = data_dict[kind_data]
|
| if hps.output_dist == 'poisson' and hps.temporal_spike_jitter_width > 0:
|
| data_extxd = self.shuffle_spikes_in_time(data_extxd)
|
|
|
| ext_input_extxi = data_dict[kind_ext_input]
|
| data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi,
|
| example_idxs=example_idxs)
|
|
|
| feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi,
|
| keep_prob=keep_prob)
|
| evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict)
|
| if do_collect:
|
| evaled_ops_list.append(evaled_ops_np)
|
|
|
| return evaled_ops_list
|
|
|
| def summarize_all(self, datasets, summary_values):
|
| """Plot and summarize stuff in tensorboard.
|
|
|
| Note that everything done in the current function is otherwise done on
|
| a single, randomly selected dataset (except for summary_values, which are
|
| passed in).
|
|
|
| Args:
|
| datasets: The dictionary of datasets used in the study.
|
| summary_values: These summary values are created from the training loop,
|
| and so summarize the entire set of datasets.
|
| """
|
| hps = self.hps
|
| tr_kl_cost = summary_values['tr_kl_cost']
|
| tr_recon_cost = summary_values['tr_recon_cost']
|
| tr_total_cost = summary_values['tr_total_cost']
|
| kl_weight = summary_values['kl_weight']
|
| l2_weight = summary_values['l2_weight']
|
| l2_cost = summary_values['l2_cost']
|
| has_any_valid_set = summary_values['has_any_valid_set']
|
| i = summary_values['nepochs']
|
|
|
| session = tf.get_default_session()
|
| train_summ, train_step = session.run([self.merged_train,
|
| self.train_step],
|
| feed_dict={self.l2_cost_ph:l2_cost,
|
| self.kl_cost_ph:tr_kl_cost,
|
| self.recon_cost_ph:tr_recon_cost,
|
| self.total_cost_ph:tr_total_cost})
|
| self.writer.add_summary(train_summ, train_step)
|
| if has_any_valid_set:
|
| ev_kl_cost = summary_values['ev_kl_cost']
|
| ev_recon_cost = summary_values['ev_recon_cost']
|
| ev_total_cost = summary_values['ev_total_cost']
|
| eval_summ = session.run(self.merged_valid,
|
| feed_dict={self.kl_cost_ph:ev_kl_cost,
|
| self.recon_cost_ph:ev_recon_cost,
|
| self.total_cost_ph:ev_total_cost})
|
| self.writer.add_summary(eval_summ, train_step)
|
| print("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f\
|
| recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f,\
|
| kl weight: %.2f, l2 weight: %.2f" % \
|
| (i, train_step, tr_total_cost, ev_total_cost,
|
| tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
|
| l2_cost, kl_weight, l2_weight))
|
|
|
| csv_outstr = "epoch,%d, step,%d, total,%.2f,%.2f, \
|
| recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, \
|
| klweight,%.2f, l2weight,%.2f\n"% \
|
| (i, train_step, tr_total_cost, ev_total_cost,
|
| tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
|
| l2_cost, kl_weight, l2_weight)
|
|
|
| else:
|
| print("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f,\
|
| l2: %.5f, kl weight: %.2f, l2 weight: %.2f" % \
|
| (i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost,
|
| l2_cost, kl_weight, l2_weight))
|
| csv_outstr = "epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, \
|
| l2,%.5f, klweight,%.2f, l2weight,%.2f\n"% \
|
| (i, train_step, tr_total_cost, tr_recon_cost,
|
| tr_kl_cost, l2_cost, kl_weight, l2_weight)
|
|
|
| if self.hps.csv_log:
|
| csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv')
|
| with open(csv_file, "a") as myfile:
|
| myfile.write(csv_outstr)
|
|
|
|
|
| def plot_single_example(self, datasets):
|
| """Plot an image relating to a randomly chosen, specific example. We use
|
| posterior sample and average by taking one example, and filling a whole
|
| batch with that example, sample from the posterior, and then average the
|
| quantities.
|
|
|
| """
|
| hps = self.hps
|
| all_data_names = list(datasets.keys())
| data_name = np.random.permutation(all_data_names)[0]
|
| data_dict = datasets[data_name]
|
| has_valid_set = data_dict['valid_data'] is not None
|
| cf = 1.0
|
|
|
|
|
| E, _, _ = data_dict['train_data'].shape
|
| eidx = np.random.choice(E)
|
| example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
|
|
|
| train_data_bxtxd, train_ext_input_bxtxi = \
|
| self.get_batch(data_dict['train_data'], data_dict['train_ext_input'],
|
| example_idxs=example_idxs)
|
|
|
| truth_train_data_bxtxd = None
|
| if 'train_truth' in data_dict and data_dict['train_truth'] is not None:
|
| truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'],
|
| example_idxs=example_idxs)
|
| cf = data_dict['conversion_factor']
|
|
|
|
|
| train_model_values = self.eval_model_runs_batch(data_name,
|
| train_data_bxtxd,
|
| train_ext_input_bxtxi,
|
| do_average_batch=False)
|
|
|
| train_step = train_model_values['train_steps']
|
| feed_dict = self.build_feed_dict(data_name, train_data_bxtxd,
|
| train_ext_input_bxtxi, keep_prob=1.0)
|
|
|
| session = tf.get_default_session()
|
| generic_summ = session.run(self.merged_generic, feed_dict=feed_dict)
|
| self.writer.add_summary(generic_summ, train_step)
|
|
|
| valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None
|
| truth_valid_data_bxtxd = None
|
| if has_valid_set:
|
| E, _, _ = data_dict['valid_data'].shape
|
| eidx = np.random.choice(E)
|
| example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
|
| valid_data_bxtxd, valid_ext_input_bxtxi = \
|
| self.get_batch(data_dict['valid_data'],
|
| data_dict['valid_ext_input'],
|
| example_idxs=example_idxs)
|
| if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None:
|
| truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'],
|
| example_idxs=example_idxs)
|
| else:
|
| truth_valid_data_bxtxd = None
|
|
|
|
|
| valid_model_values = self.eval_model_runs_batch(data_name,
|
| valid_data_bxtxd,
|
| valid_ext_input_bxtxi,
|
| do_average_batch=False)
|
|
|
| example_image = plot_lfads(train_bxtxd=train_data_bxtxd,
|
| train_model_vals=train_model_values,
|
| train_ext_input_bxtxi=train_ext_input_bxtxi,
|
| train_truth_bxtxd=truth_train_data_bxtxd,
|
| valid_bxtxd=valid_data_bxtxd,
|
| valid_model_vals=valid_model_values,
|
| valid_ext_input_bxtxi=valid_ext_input_bxtxi,
|
| valid_truth_bxtxd=truth_valid_data_bxtxd,
|
| bidx=None, cf=cf, output_dist=hps.output_dist)
|
| example_image = np.expand_dims(example_image, axis=0)
|
| example_summ = session.run(self.merged_examples,
|
| feed_dict={self.example_image : example_image})
|
| self.writer.add_summary(example_summ)
|
|
|
| def train_model(self, datasets):
|
| """Train the model, print per-epoch information, and save checkpoints.
|
|
|
| Loop over training epochs. The function that actually does the
|
| training is train_epoch. This function iterates over the training
|
| data, one epoch at a time. The learning rate schedule is such
|
| that it will stay the same until the cost goes up in comparison to
|
| the last few values, then it will drop.
|
|
|
| Args:
|
| datasets: A dict of data dicts. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| """
|
| hps = self.hps
|
| has_any_valid_set = False
|
| for data_dict in datasets.values():
|
| if data_dict['valid_data'] is not None:
|
| has_any_valid_set = True
|
| break
|
|
|
| session = tf.get_default_session()
|
| lr = session.run(self.learning_rate)
|
| lr_stop = hps.learning_rate_stop
|
| i = -1
|
| train_costs = []
|
| valid_costs = []
|
| ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0
|
| lowest_ev_cost = np.Inf
|
| while True:
|
| i += 1
|
| do_save_ckpt = (i % 10 == 0)
|
| tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \
|
| self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)
|
|
|
|
|
|
|
|
|
| if has_any_valid_set:
|
| ev_total_cost, ev_recon_cost, ev_kl_cost = \
|
| self.eval_cost_epoch(datasets, kind='valid')
|
| valid_costs.append(ev_total_cost)
|
|
|
|
|
|
|
| n_lve = 1
|
| run_avg_lve = np.mean(valid_costs[-n_lve:])
|
|
|
|
|
|
|
|
|
|
|
|
|
| if kl_weight >= 1.0 and \
|
| (l2_weight >= 1.0 or \
|
| (self.hps.l2_gen_scale == 0.0 and self.hps.l2_con_scale == 0.0)) \
|
| and (len(valid_costs) > n_lve and run_avg_lve < lowest_ev_cost):
|
|
|
| lowest_ev_cost = run_avg_lve
|
| checkpoint_path = os.path.join(self.hps.lfads_save_dir,
|
| self.hps.checkpoint_name + '_lve.ckpt')
|
| self.lve_saver.save(session, checkpoint_path,
|
| global_step=self.train_step,
|
| latest_filename='checkpoint_lve')
|
|
|
|
|
| values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set,
|
| 'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost,
|
| 'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost,
|
| 'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost,
|
| 'l2_weight':l2_weight, 'kl_weight':kl_weight,
|
| 'l2_cost':l2_cost}
|
| self.summarize_all(datasets, values)
|
| self.plot_single_example(datasets)
|
|
|
|
|
| train_res = tr_total_cost
|
| n_lr = hps.learning_rate_n_to_compare
|
| if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]):
|
| _ = session.run(self.learning_rate_decay_op)
|
| lr = session.run(self.learning_rate)
|
| print(" Decreasing learning rate to %f." % lr)
|
|
|
| train_costs.append(np.inf)
|
| else:
|
| train_costs.append(train_res)
|
|
|
| if lr < lr_stop:
|
| print("Stopping optimization based on learning rate criteria.")
|
| break
|
|
|
| def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None,
|
| batch_size=None):
|
| """Evaluate the cost of the epoch.
|
|
|
| Args:
| datasets: A dict of data dicts. The dataset dict is simply a
| name(string)-> data dictionary mapping (See top of lfads.py).
| kind: 'train' or 'valid', selecting which split to evaluate.
|
|
|
| Returns:
|
| a 3 tuple of costs:
|
| (epoch total cost, epoch reconstruction cost, epoch KL cost)
|
| """
|
| ops_to_eval = [self.cost, self.recon_cost, self.kl_cost]
|
| collected_op_values = self.run_epoch(datasets, ops_to_eval, kind=kind,
|
| keep_prob=1.0)
|
|
|
| total_cost = total_recon_cost = total_kl_cost = 0.0
|
|
|
| epoch_size = len(collected_op_values)
|
| for op_values in collected_op_values:
|
| total_cost += op_values[0]
|
| total_recon_cost += op_values[1]
|
| total_kl_cost += op_values[2]
|
|
|
| epoch_total_cost = total_cost / epoch_size
|
| epoch_recon_cost = total_recon_cost / epoch_size
|
| epoch_kl_cost = total_kl_cost / epoch_size
|
|
|
| return epoch_total_cost, epoch_recon_cost, epoch_kl_cost
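|
| # A minimal usage sketch (assumes a default session and a built model;
|
| # `model` and `datasets` are placeholder names, with datasets in the
|
| # format described at the top of lfads.py):
|
| #
|
| #   total, recon, kl = model.eval_cost_epoch(datasets, kind='valid')
|
| #   print("valid: total %f, recon %f, kl %f" % (total, recon, kl))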
|
|
|
| def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
|
| do_eval_cost=False, do_average_batch=False):
|
| """Returns all the goodies for the entire model, per batch.
|
|
|
| data_bxtxd and ext_input_bxtxi may have fewer than batch_size examples
|
| along the first dimension, in which case this function handles the
|
| padding and truncation automatically.
|
|
|
| Args:
|
| data_name: The name of the data dict, to select which in/out matrices
|
| to use.
|
| data_bxtxd: Numpy array training data with shape:
|
| batch_size x # time steps x # dimensions
|
| ext_input_bxtxi: Numpy array training external input with shape:
|
| batch_size x # time steps x # external input dims
|
| do_eval_cost (optional): If true, evaluate the IWAE (Importance
|
| Weighted Autoencoder) log likelihood bound instead of the VAE version.
|
| do_average_batch (optional): average over the batch, useful for getting
|
| good IWAE costs, and model outputs for a single data point.
|
|
|
| Returns:
|
| A dictionary with the outputs of the model decoder, namely:
|
| prior g0 mean, prior g0 variance, approx. posterior mean, approx.
|
| posterior variance, the generator initial conditions, the control inputs (if
|
| enabled), the state of the generator, the factors, and the rates.
|
| """
|
| session = tf.get_default_session()
|
|
|
|
|
| hps = self.hps
|
| batch_size = hps.batch_size
|
| E, _, _ = data_bxtxd.shape
|
| if E < hps.batch_size:
|
| data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)),
|
| mode='constant', constant_values=0)
|
| if ext_input_bxtxi is not None:
|
| ext_input_bxtxi = np.pad(ext_input_bxtxi,
|
| ((0, hps.batch_size-E), (0, 0), (0, 0)),
|
| mode='constant', constant_values=0)
|
|
|
| feed_dict = self.build_feed_dict(data_name, data_bxtxd,
|
| ext_input_bxtxi, keep_prob=1.0)
|
|
|
|
|
|
|
| tf_vals = [self.gen_ics, self.gen_states, self.factors,
|
| self.output_dist_params]
|
| tf_vals.append(self.cost)
|
| tf_vals.append(self.nll_bound_vae)
|
| tf_vals.append(self.nll_bound_iwae)
|
| tf_vals.append(self.train_step)
|
| if self.hps.ic_dim > 0:
|
| tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
|
| self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
|
| if self.hps.co_dim > 0:
|
| tf_vals.append(self.controller_outputs)
|
| tf_vals_flat, fidxs = flatten(tf_vals)
|
|
|
| np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
|
|
|
| ff = 0
|
| gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| if self.hps.ic_dim > 0:
|
| prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| if self.hps.co_dim > 0:
|
| controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
|
|
|
|
| gen_ics = gen_ics[0]
|
| costs = costs[0]
|
| nll_bound_vaes = nll_bound_vaes[0]
|
| nll_bound_iwaes = nll_bound_iwaes[0]
|
| train_steps = train_steps[0]
|
|
|
|
|
| gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
|
| factors = list_t_bxn_to_tensor_bxtxn(factors)
|
| out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
|
| if self.hps.ic_dim > 0:
|
|
|
| prior_g0_mean = prior_g0_mean[0]
|
| prior_g0_logvar = prior_g0_logvar[0]
|
| post_g0_mean = post_g0_mean[0]
|
| post_g0_logvar = post_g0_logvar[0]
|
| if self.hps.co_dim > 0:
|
| controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)
|
|
|
|
|
| if E < hps.batch_size:
|
| idx = np.arange(E)
|
| gen_ics = gen_ics[idx, :]
|
| gen_states = gen_states[idx, :]
|
| factors = factors[idx, :, :]
|
| out_dist_params = out_dist_params[idx, :, :]
|
| if self.hps.ic_dim > 0:
|
| prior_g0_mean = prior_g0_mean[idx, :]
|
| prior_g0_logvar = prior_g0_logvar[idx, :]
|
| post_g0_mean = post_g0_mean[idx, :]
|
| post_g0_logvar = post_g0_logvar[idx, :]
|
| if self.hps.co_dim > 0:
|
| controller_outputs = controller_outputs[idx, :, :]
|
|
|
| if do_average_batch:
|
| gen_ics = np.mean(gen_ics, axis=0)
|
| gen_states = np.mean(gen_states, axis=0)
|
| factors = np.mean(factors, axis=0)
|
| out_dist_params = np.mean(out_dist_params, axis=0)
|
| if self.hps.ic_dim > 0:
|
| prior_g0_mean = np.mean(prior_g0_mean, axis=0)
|
| prior_g0_logvar = np.mean(prior_g0_logvar, axis=0)
|
| post_g0_mean = np.mean(post_g0_mean, axis=0)
|
| post_g0_logvar = np.mean(post_g0_logvar, axis=0)
|
| if self.hps.co_dim > 0:
|
| controller_outputs = np.mean(controller_outputs, axis=0)
|
|
|
| model_vals = {}
|
| model_vals['gen_ics'] = gen_ics
|
| model_vals['gen_states'] = gen_states
|
| model_vals['factors'] = factors
|
| model_vals['output_dist_params'] = out_dist_params
|
| model_vals['costs'] = costs
|
| model_vals['nll_bound_vaes'] = nll_bound_vaes
|
| model_vals['nll_bound_iwaes'] = nll_bound_iwaes
|
| model_vals['train_steps'] = train_steps
|
| if self.hps.ic_dim > 0:
|
| model_vals['prior_g0_mean'] = prior_g0_mean
|
| model_vals['prior_g0_logvar'] = prior_g0_logvar
|
| model_vals['post_g0_mean'] = post_g0_mean
|
| model_vals['post_g0_logvar'] = post_g0_logvar
|
| if self.hps.co_dim > 0:
|
| model_vals['controller_outputs'] = controller_outputs
|
|
|
| return model_vals
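|
| # A hedged usage sketch; `model`, `data_bxtxd` and the dataset name
|
| # 'dataset1' are placeholders:
|
| #
|
| #   vals = model.eval_model_runs_batch('dataset1', data_bxtxd)
|
| #   rates = vals['output_dist_params']  # B x T x D for the Poisson case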
|
|
|
| def eval_model_runs_avg_epoch(self, data_name, data_extxd,
|
| ext_input_extxi=None):
|
| """Returns all the expected value for goodies for the entire model.
|
|
|
| The expected value is taken over hidden (z) variables, namely the initial
|
| conditions and the control inputs. The expected value is approximate, and
|
| is computed by drawing batch_size samples for every example.
|
|
|
| Args:
|
| data_name: The name of the data dict, to select which in/out matrices
|
| to use.
|
| data_extxd: Numpy array training data with shape:
|
| # examples x # time steps x # dimensions
|
| ext_input_extxi (optional): Numpy array training external input with
|
| shape: # examples x # time steps x # external input dims
|
|
|
| Returns:
|
| A dictionary with the averaged outputs of the model decoder, namely:
|
| prior g0 mean, prior g0 variance, approx. posterior mean, approx.
|
| posterior variance, the generator initial conditions, the control inputs (if
|
| enabled), the state of the generator, the factors, and the output
|
| distribution parameters, e.g. (rates or mean and variances).
|
| """
|
| hps = self.hps
|
| batch_size = hps.batch_size
|
| E, T, D = data_extxd.shape
|
| E_to_process = hps.ps_nexamples_to_process
|
| if E_to_process > E:
|
| E_to_process = E
|
|
|
| if hps.ic_dim > 0:
|
| prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
|
| prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
|
| post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
|
| post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
|
|
|
| if hps.co_dim > 0:
|
| controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
|
| gen_ics = np.zeros([E_to_process, hps.gen_dim])
|
| gen_states = np.zeros([E_to_process, T, hps.gen_dim])
|
| factors = np.zeros([E_to_process, T, hps.factors_dim])
|
|
|
| if hps.output_dist == 'poisson':
|
| out_dist_params = np.zeros([E_to_process, T, D])
|
| elif hps.output_dist == 'gaussian':
|
| out_dist_params = np.zeros([E_to_process, T, D+D])
|
| else:
|
| assert False, "NIY"
|
|
|
| costs = np.zeros(E_to_process)
|
| nll_bound_vaes = np.zeros(E_to_process)
|
| nll_bound_iwaes = np.zeros(E_to_process)
|
| train_steps = np.zeros(E_to_process)
|
| for es_idx in range(E_to_process):
|
| print("Running %d of %d." % (es_idx+1, E_to_process))
|
| example_idxs = es_idx * np.ones(batch_size, dtype=np.int32)
|
| data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
|
| ext_input_extxi,
|
| batch_size=batch_size,
|
| example_idxs=example_idxs)
|
| model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
|
| ext_input_bxtxi,
|
| do_eval_cost=True,
|
| do_average_batch=True)
|
|
|
| if self.hps.ic_dim > 0:
|
| prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
|
| prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
|
| post_g0_mean[es_idx,:] = model_values['post_g0_mean']
|
| post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
|
| gen_ics[es_idx,:] = model_values['gen_ics']
|
|
|
| if self.hps.co_dim > 0:
|
| controller_outputs[es_idx,:,:] = model_values['controller_outputs']
|
| gen_states[es_idx,:,:] = model_values['gen_states']
|
| factors[es_idx,:,:] = model_values['factors']
|
| out_dist_params[es_idx,:,:] = model_values['output_dist_params']
|
| costs[es_idx] = model_values['costs']
|
| nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
|
| nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
|
| train_steps[es_idx] = model_values['train_steps']
|
| print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \
|
| % (nll_bound_vaes[es_idx], nll_bound_iwaes[es_idx]))
|
|
|
| model_runs = {}
|
| if self.hps.ic_dim > 0:
|
| model_runs['prior_g0_mean'] = prior_g0_mean
|
| model_runs['prior_g0_logvar'] = prior_g0_logvar
|
| model_runs['post_g0_mean'] = post_g0_mean
|
| model_runs['post_g0_logvar'] = post_g0_logvar
|
| model_runs['gen_ics'] = gen_ics
|
|
|
| if self.hps.co_dim > 0:
|
| model_runs['controller_outputs'] = controller_outputs
|
| model_runs['gen_states'] = gen_states
|
| model_runs['factors'] = factors
|
| model_runs['output_dist_params'] = out_dist_params
|
| model_runs['costs'] = costs
|
| model_runs['nll_bound_vaes'] = nll_bound_vaes
|
| model_runs['nll_bound_iwaes'] = nll_bound_iwaes
|
| model_runs['train_steps'] = train_steps
|
| return model_runs
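|
| # Posterior sample-and-average in one call (a sketch; `model`, `data_extxd`
|
| # and the dataset name are placeholders):
|
| #
|
| #   runs = model.eval_model_runs_avg_epoch('dataset1', data_extxd)
|
| #   avg_rates = runs['output_dist_params']  # E x T x D, each trial averaged
|
| #                                           # over batch_size posterior samples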
|
|
|
| def eval_model_runs_push_mean(self, data_name, data_extxd,
|
| ext_input_extxi=None):
|
| """Returns values of interest for the model by pushing the means through
|
|
|
| The mean values for both initial conditions and the control inputs are
|
| pushed through the model instead of sampling (as is done in
|
| eval_model_runs_avg_epoch).
|
| This is a quick, approximate alternative to sampling from the posterior
|
| many times and then averaging those values of interest.
|
|
|
| Internally, a total of batch_size trials are run through the model at once.
|
|
|
| Args:
|
| data_name: The name of the data dict, to select which in/out matrices
|
| to use.
|
| data_extxd: Numpy array training data with shape:
|
| # examples x # time steps x # dimensions
|
| ext_input_extxi (optional): Numpy array training external input with
|
| shape: # examples x # time steps x # external input dims
|
|
|
| Returns:
|
| A dictionary with the estimated outputs of the model decoder, namely:
|
| prior g0 mean, prior g0 variance, approx. posterior mean, approx.
|
| posterior variance, the generator initial conditions, the control inputs (if
|
| enabled), the state of the generator, the factors, and the output
|
| distribution parameters, e.g. (rates or mean and variances).
|
| """
|
| hps = self.hps
|
| batch_size = hps.batch_size
|
| E, T, D = data_extxd.shape
|
| E_to_process = hps.ps_nexamples_to_process
|
| if E_to_process > E:
|
| print("Setting number of posterior samples to process to : ", E)
|
| E_to_process = E
|
|
|
| if hps.ic_dim > 0:
|
| prior_g0_mean = np.zeros([E_to_process, hps.ic_dim])
|
| prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
|
| post_g0_mean = np.zeros([E_to_process, hps.ic_dim])
|
| post_g0_logvar = np.zeros([E_to_process, hps.ic_dim])
|
|
|
| if hps.co_dim > 0:
|
| controller_outputs = np.zeros([E_to_process, T, hps.co_dim])
|
| gen_ics = np.zeros([E_to_process, hps.gen_dim])
|
| gen_states = np.zeros([E_to_process, T, hps.gen_dim])
|
| factors = np.zeros([E_to_process, T, hps.factors_dim])
|
|
|
| if hps.output_dist == 'poisson':
|
| out_dist_params = np.zeros([E_to_process, T, D])
|
| elif hps.output_dist == 'gaussian':
|
| out_dist_params = np.zeros([E_to_process, T, D+D])
|
| else:
|
| assert False, "NIY"
|
|
|
| costs = np.zeros(E_to_process)
|
| nll_bound_vaes = np.zeros(E_to_process)
|
| nll_bound_iwaes = np.zeros(E_to_process)
|
| train_steps = np.zeros(E_to_process)
|
|
|
|
|
|
|
|
|
| def trial_batches(N, per):
|
| for i in range(0, N, per):
|
| yield np.arange(i, min(i+per, N), dtype=np.int32)
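|
| # For example, trial_batches(5, 2) yields [0, 1], [2, 3], then [4]; a short
|
| # final batch is padded back up to batch_size inside eval_model_runs_batch.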
|
|
|
| for batch_idx, es_idx in enumerate(trial_batches(E_to_process,
|
| hps.batch_size)):
|
| print("Running trial batch %d with %d trials" % (batch_idx+1,
|
| len(es_idx)))
|
| data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd,
|
| ext_input_extxi,
|
| batch_size=batch_size,
|
| example_idxs=es_idx)
|
| model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
|
| ext_input_bxtxi,
|
| do_eval_cost=True,
|
| do_average_batch=False)
|
|
|
| if self.hps.ic_dim > 0:
|
| prior_g0_mean[es_idx,:] = model_values['prior_g0_mean']
|
| prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar']
|
| post_g0_mean[es_idx,:] = model_values['post_g0_mean']
|
| post_g0_logvar[es_idx,:] = model_values['post_g0_logvar']
|
| gen_ics[es_idx,:] = model_values['gen_ics']
|
|
|
| if self.hps.co_dim > 0:
|
| controller_outputs[es_idx,:,:] = model_values['controller_outputs']
|
| gen_states[es_idx,:,:] = model_values['gen_states']
|
| factors[es_idx,:,:] = model_values['factors']
|
| out_dist_params[es_idx,:,:] = model_values['output_dist_params']
|
|
|
|
|
|
|
|
|
| costs[es_idx] = model_values['costs']
|
| nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
|
| nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
|
|
|
| train_steps[es_idx] = model_values['train_steps']
|
|
|
| model_runs = {}
|
| if self.hps.ic_dim > 0:
|
| model_runs['prior_g0_mean'] = prior_g0_mean
|
| model_runs['prior_g0_logvar'] = prior_g0_logvar
|
| model_runs['post_g0_mean'] = post_g0_mean
|
| model_runs['post_g0_logvar'] = post_g0_logvar
|
| model_runs['gen_ics'] = gen_ics
|
|
|
| if self.hps.co_dim > 0:
|
| model_runs['controller_outputs'] = controller_outputs
|
| model_runs['gen_states'] = gen_states
|
| model_runs['factors'] = factors
|
| model_runs['output_dist_params'] = out_dist_params
|
|
|
|
|
|
|
| model_runs['costs'] = costs
|
| model_runs['nll_bound_vaes'] = nll_bound_vaes
|
| model_runs['nll_bound_iwaes'] = nll_bound_iwaes
|
| model_runs['train_steps'] = train_steps
|
| return model_runs
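|
| # The two evaluation modes side by side (a sketch; names are placeholders):
|
| #
|
| #   avg = model.eval_model_runs_avg_epoch('dataset1', data_extxd)   # sampled, slower
|
| #   quick = model.eval_model_runs_push_mean('dataset1', data_extxd) # means, faster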
|
|
|
| def write_model_runs(self, datasets, output_fname=None, push_mean=False):
|
| """Run the model on the data in data_dict, and save the computed values.
|
|
|
| LFADS generates a number of outputs for each example, and these are all
|
| saved. They are:
|
| The mean and variance of the prior of g0.
|
| The mean and variance of approximate posterior of g0.
|
| The control inputs (if enabled).
|
| The initial conditions, g0, for all examples.
|
| The generator states for all time.
|
| The factors for all time.
|
| The output distribution parameters (e.g. rates) for all time.
|
|
|
| Args:
|
| datasets: A dictionary of named data_dictionaries, see top of lfads.py
|
| output_fname: A file name stem for the output files.
|
| push_mean: If False (default), generates batch_size samples for each trial
|
| and averages the results. If True, runs each trial once without noise,
|
| pushing the posterior mean initial conditions and control inputs through
|
| the trained model. False is used for posterior_sample_and_average, True
|
| is used for posterior_push_mean.
|
| """
|
| hps = self.hps
|
| kind = hps.kind
|
|
|
| for data_name, data_dict in datasets.items():
|
| data_tuple = [('train', data_dict['train_data'],
|
| data_dict['train_ext_input']),
|
| ('valid', data_dict['valid_data'],
|
| data_dict['valid_ext_input'])]
|
| for data_kind, data_extxd, ext_input_extxi in data_tuple:
|
| if not output_fname:
|
| fname = "model_runs_" + data_name + '_' + data_kind + '_' + kind
|
| else:
|
| fname = output_fname + data_name + '_' + data_kind + '_' + kind
|
|
|
| print("Writing data for %s data and kind %s." % (data_name, data_kind))
|
| if push_mean:
|
| model_runs = self.eval_model_runs_push_mean(data_name, data_extxd,
|
| ext_input_extxi)
|
| else:
|
| model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd,
|
| ext_input_extxi)
|
| full_fname = os.path.join(hps.lfads_save_dir, fname)
|
| write_data(full_fname, model_runs, compression='gzip')
|
| print("Done.")
|
|
|
| def write_model_samples(self, dataset_name, output_fname=None):
|
| """Use the prior distribution to generate batch_size number of samples
|
| from the model.
|
|
|
| LFADS generates a number of outputs for each sample, and these are all
|
| saved. They are:
|
| The mean and variance of the prior of g0.
|
| The control inputs (if enabled).
|
| The initial conditions, g0, for all examples.
|
| The generator states for all time.
|
| The factors for all time.
|
| The output distribution parameters (e.g. rates) for all time.
|
|
|
| Args:
|
| dataset_name: The name of the dataset to grab the factors -> rates
|
| alignment matrices from.
|
| output_fname: The name of the file in which to save the generated
|
| samples.
|
| """
|
| hps = self.hps
|
| batch_size = hps.batch_size
|
|
|
| print("Generating %d samples" % (batch_size))
|
| tf_vals = [self.factors, self.gen_states, self.gen_ics,
|
| self.cost, self.output_dist_params]
|
| if hps.ic_dim > 0:
|
| tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar]
|
| if hps.co_dim > 0:
|
| tf_vals += [self.prior_zs_ar_con.samples_t]
|
| tf_vals_flat, fidxs = flatten(tf_vals)
|
|
|
| session = tf.get_default_session()
|
| feed_dict = {}
|
| feed_dict[self.dataName] = dataset_name
|
| feed_dict[self.keep_prob] = 1.0
|
|
|
| np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)
|
|
|
| ff = 0
|
| factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| output_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| if hps.ic_dim > 0:
|
| prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
| if hps.co_dim > 0:
|
| prior_zs_ar_con = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
|
|
|
|
|
| gen_ics = gen_ics[0]
|
| costs = costs[0]
|
|
|
|
|
| gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
|
| factors = list_t_bxn_to_tensor_bxtxn(factors)
|
| output_dist_params = list_t_bxn_to_tensor_bxtxn(output_dist_params)
|
| if hps.ic_dim > 0:
|
| prior_g0_mean = prior_g0_mean[0]
|
| prior_g0_logvar = prior_g0_logvar[0]
|
| if hps.co_dim > 0:
|
| prior_zs_ar_con = list_t_bxn_to_tensor_bxtxn(prior_zs_ar_con)
|
|
|
| model_vals = {}
|
| model_vals['gen_ics'] = gen_ics
|
| model_vals['gen_states'] = gen_states
|
| model_vals['factors'] = factors
|
| model_vals['output_dist_params'] = output_dist_params
|
| model_vals['costs'] = costs.reshape(1)
|
| if hps.ic_dim > 0:
|
| model_vals['prior_g0_mean'] = prior_g0_mean
|
| model_vals['prior_g0_logvar'] = prior_g0_logvar
|
| if hps.co_dim > 0:
|
| model_vals['prior_zs_ar_con'] = prior_zs_ar_con
|
|
|
| full_fname = os.path.join(hps.lfads_save_dir, output_fname)
|
| write_data(full_fname, model_vals, compression='gzip')
|
| print("Done.")
|
|
|
| @staticmethod
|
| def eval_model_parameters(use_nested=True, include_strs=None):
|
| """Evaluate and return all of the TF variables in the model.
|
|
|
| Args:
|
| use_nested (optional): For returning values, use a nested dictionary based
|
| on variable scoping, or return all variables in a flat dictionary.
|
| include_strs (optional): A list of strings to use as a filter, to reduce the
|
| number of variables returned. A variable name must contain at least one
|
| string in include_strs as a sub-string in order to be returned.
|
|
|
| Returns:
|
| The parameters of the model. This can be in a flat
|
| dictionary, or a nested dictionary, where the nesting is by variable
|
| scope.
|
| """
|
| all_tf_vars = tf.global_variables()
|
| session = tf.get_default_session()
|
| all_tf_vars_eval = session.run(all_tf_vars)
|
| vars_dict = {}
|
| strs = ["LFADS"]
|
| if include_strs:
|
| strs += include_strs
|
|
|
| for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
|
| if any(s in var.name for s in strs):
|
| if not isinstance(var_eval, np.ndarray):
|
| print(var.name, """ is not numpy array, saving as numpy array
|
| with value: """, var_eval, type(var_eval))
|
| e = np.array(var_eval)
|
| print(e, type(e))
|
| else:
|
| e = var_eval
|
| vars_dict[var.name] = e
|
|
|
| if not use_nested:
|
| return vars_dict
|
|
|
| var_names = vars_dict.keys()
|
| nested_vars_dict = {}
|
| current_dict = nested_vars_dict
|
| for v, var_name in enumerate(var_names):
|
| var_split_name_list = var_name.split('/')
|
| split_name_list_len = len(var_split_name_list)
|
| current_dict = nested_vars_dict
|
| for p, part in enumerate(var_split_name_list):
|
| if p < split_name_list_len - 1:
|
| if part in current_dict:
|
| current_dict = current_dict[part]
|
| else:
|
| current_dict[part] = {}
|
| current_dict = current_dict[part]
|
| else:
|
| current_dict[part] = vars_dict[var_name]
|
|
|
| return nested_vars_dict
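|
| # Nesting illustration (the variable name is hypothetical): a TF variable
|
| # named 'LFADS/gen/W:0' would be reachable as
|
| #
|
| #   params = model.eval_model_parameters(use_nested=True)
|
| #   W = params['LFADS']['gen']['W:0']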
|
|
|
| @staticmethod
|
| def spikify_rates(rates_bxtxd):
|
| """Randomly spikify underlying rates according a Poisson distribution
|
|
|
| Args:
|
| rates_bxtxd: A numpy tensor of Poisson rates with shape:
|
| batch_size x # time steps x # dimensions
|
| Returns:
|
| A numpy array with the same shape as rates_bxtxd, but with the event
|
| counts.
|
| """
|
|
|
| B,T,N = rates_bxtxd.shape
|
| assert all([B > 0, N > 0]), "problems"
|
|
|
|
|
| spikes_bxtxd = np.zeros([B,T,N], dtype=np.int32)
|
| for b in range(B):
|
| for t in range(T):
|
| for n in range(N):
|
| rate = rates_bxtxd[b,t,n]
|
| count = np.random.poisson(rate)
|
| spikes_bxtxd[b,t,n] = count
|
|
|
| return spikes_bxtxd
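|
| # An equivalent vectorized sketch: numpy's poisson sampler accepts an array
|
| # of rates, so the triple loop above could be replaced by a single draw:
|
| #
|
| #   spikes_bxtxd = np.random.poisson(rates_bxtxd).astype(np.int32)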
|
|
|