| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| from lfads import LFADS
|
| import numpy as np
|
| import os
|
| import tensorflow as tf
|
| import re
|
| import utils
|
| import sys
|
| MAX_INT = sys.maxsize
|
|
|
|
|
|
|
|
|
|
|
| CHECKPOINT_PB_LOAD_NAME = "checkpoint"
|
| CHECKPOINT_NAME = "lfads_vae"
|
| CSV_LOG = "fitlog"
|
| OUTPUT_FILENAME_STEM = ""
|
| DEVICE = "gpu:0"
|
| MAX_CKPT_TO_KEEP = 5
|
| MAX_CKPT_TO_KEEP_LVE = 5
|
| PS_NEXAMPLES_TO_PROCESS = MAX_INT
|
| EXT_INPUT_DIM = 0
|
| IC_DIM = 64
|
| FACTORS_DIM = 50
|
| IC_ENC_DIM = 128
|
| GEN_DIM = 200
|
| GEN_CELL_INPUT_WEIGHT_SCALE = 1.0
|
| GEN_CELL_REC_WEIGHT_SCALE = 1.0
|
| CELL_WEIGHT_SCALE = 1.0
|
| BATCH_SIZE = 128
|
| LEARNING_RATE_INIT = 0.01
|
| LEARNING_RATE_DECAY_FACTOR = 0.95
|
| LEARNING_RATE_STOP = 0.00001
|
| LEARNING_RATE_N_TO_COMPARE = 6
|
| INJECT_EXT_INPUT_TO_GEN = False
|
| DO_TRAIN_IO_ONLY = False
|
| DO_TRAIN_ENCODER_ONLY = False
|
| DO_RESET_LEARNING_RATE = False
|
| FEEDBACK_FACTORS_OR_RATES = "factors"
|
| DO_TRAIN_READIN = True
|
|
|
|
|
| MAX_GRAD_NORM = 200.0
|
| CELL_CLIP_VALUE = 5.0
|
| KEEP_PROB = 0.95
|
| TEMPORAL_SPIKE_JITTER_WIDTH = 0
|
| OUTPUT_DISTRIBUTION = 'poisson'
|
| NUM_STEPS_FOR_GEN_IC = MAX_INT
|
|
|
| DATA_DIR = "/tmp/rnn_synth_data_v1.0/"
|
| DATA_FILENAME_STEM = "chaotic_rnn_inputs_g1p5"
|
| LFADS_SAVE_DIR = "/tmp/lfads_chaotic_rnn_inputs_g1p5/"
|
| CO_DIM = 1
|
| DO_CAUSAL_CONTROLLER = False
|
| DO_FEED_FACTORS_TO_CONTROLLER = True
|
| CONTROLLER_INPUT_LAG = 1
|
| PRIOR_AR_AUTOCORRELATION = 10.0
|
| PRIOR_AR_PROCESS_VAR = 0.1
|
| DO_TRAIN_PRIOR_AR_ATAU = True
|
| DO_TRAIN_PRIOR_AR_NVAR = True
|
| CI_ENC_DIM = 128
|
| CON_DIM = 128
|
| CO_PRIOR_VAR_SCALE = 0.1
|
| KL_INCREASE_STEPS = 2000
|
| L2_INCREASE_STEPS = 2000
|
| L2_GEN_SCALE = 2000.0
|
| L2_CON_SCALE = 0.0
|
|
|
| CO_MEAN_CORR_SCALE = 0.0
|
| KL_IC_WEIGHT = 1.0
|
| KL_CO_WEIGHT = 1.0
|
| KL_START_STEP = 0
|
| L2_START_STEP = 0
|
| IC_PRIOR_VAR_MIN = 0.1
|
| IC_PRIOR_VAR_SCALE = 0.1
|
| IC_PRIOR_VAR_MAX = 0.1
|
| IC_POST_VAR_MIN = 0.0001
|
|
|
| flags = tf.app.flags
|
| flags.DEFINE_string("kind", "train",
|
| "Type of model to build {train, \
|
| posterior_sample_and_average, \
|
| posterior_push_mean, \
|
| prior_sample, write_model_params")
|
| flags.DEFINE_string("output_dist", OUTPUT_DISTRIBUTION,
|
| "Type of output distribution, 'poisson' or 'gaussian'")
|
| flags.DEFINE_boolean("allow_gpu_growth", False,
|
| "If true, only allocate amount of memory needed for \
|
| Session. Otherwise, use full GPU memory.")
|
|
|
|
|
| flags.DEFINE_string("data_dir", DATA_DIR, "Data for training")
|
| flags.DEFINE_string("data_filename_stem", DATA_FILENAME_STEM,
|
| "Filename stem for data dictionaries.")
|
| flags.DEFINE_string("lfads_save_dir", LFADS_SAVE_DIR, "model save dir")
|
| flags.DEFINE_string("checkpoint_pb_load_name", CHECKPOINT_PB_LOAD_NAME,
|
| "Name of checkpoint files, use 'checkpoint_lve' for best \
|
| error")
|
| flags.DEFINE_string("checkpoint_name", CHECKPOINT_NAME,
|
| "Name of checkpoint files (.ckpt appended)")
|
| flags.DEFINE_string("output_filename_stem", OUTPUT_FILENAME_STEM,
|
| "Name of output file (postfix will be added)")
|
| flags.DEFINE_string("device", DEVICE,
|
| "Which device to use (default: \"gpu:0\", can also be \
|
| \"cpu:0\", \"gpu:1\", etc)")
|
| flags.DEFINE_string("csv_log", CSV_LOG,
|
| "Name of file to keep running log of fit likelihoods, \
|
| etc (.csv appended)")
|
| flags.DEFINE_integer("max_ckpt_to_keep", MAX_CKPT_TO_KEEP,
|
| "Max # of checkpoints to keep (rolling)")
|
| flags.DEFINE_integer("ps_nexamples_to_process", PS_NEXAMPLES_TO_PROCESS,
|
| "Number of examples to process for posterior sample and \
|
| average (not number of samples to average over).")
|
| flags.DEFINE_integer("max_ckpt_to_keep_lve", MAX_CKPT_TO_KEEP_LVE,
|
| "Max # of checkpoints to keep for lowest validation error \
|
| models (rolling)")
|
| flags.DEFINE_integer("ext_input_dim", EXT_INPUT_DIM, "Dimension of external \
|
| inputs")
|
| flags.DEFINE_integer("num_steps_for_gen_ic", NUM_STEPS_FOR_GEN_IC,
|
| "Number of steps to train the generator initial conditon.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_boolean("inject_ext_input_to_gen",
|
| INJECT_EXT_INPUT_TO_GEN,
|
| "Should observed inputs be input to model via encoders, \
|
| or injected directly into generator?")
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_float("cell_weight_scale", CELL_WEIGHT_SCALE,
|
| "Input scaling for input weights in generator.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_integer("ic_dim", IC_DIM, "Dimension of h0")
|
|
|
|
|
|
|
| flags.DEFINE_integer("factors_dim", FACTORS_DIM,
|
| "Number of factors from generator")
|
| flags.DEFINE_integer("ic_enc_dim", IC_ENC_DIM,
|
| "Cell hidden size, encoder of h0")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_integer("gen_dim", GEN_DIM,
|
| "Cell hidden size, generator.")
|
|
|
|
|
|
|
| flags.DEFINE_float("gen_cell_input_weight_scale", GEN_CELL_INPUT_WEIGHT_SCALE,
|
| "Input scaling for input weights in generator.")
|
| flags.DEFINE_float("gen_cell_rec_weight_scale", GEN_CELL_REC_WEIGHT_SCALE,
|
| "Input scaling for rec weights in generator.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_float("ic_prior_var_min", IC_PRIOR_VAR_MIN,
|
| "Minimum variance in posterior h0 codes.")
|
| flags.DEFINE_float("ic_prior_var_scale", IC_PRIOR_VAR_SCALE,
|
| "Variance of ic prior distribution")
|
| flags.DEFINE_float("ic_prior_var_max", IC_PRIOR_VAR_MAX,
|
| "Maximum variance of IC prior distribution.")
|
|
|
|
|
| flags.DEFINE_float("ic_post_var_min", IC_POST_VAR_MIN,
|
| "Minimum variance of IC posterior distribution.")
|
| flags.DEFINE_float("co_prior_var_scale", CO_PRIOR_VAR_SCALE,
|
| "Variance of control input prior distribution.")
|
|
|
|
|
| flags.DEFINE_float("prior_ar_atau", PRIOR_AR_AUTOCORRELATION,
|
| "Initial autocorrelation of AR(1) priors.")
|
| flags.DEFINE_float("prior_ar_nvar", PRIOR_AR_PROCESS_VAR,
|
| "Initial noise variance for AR(1) priors.")
|
| flags.DEFINE_boolean("do_train_prior_ar_atau", DO_TRAIN_PRIOR_AR_ATAU,
|
| "Is the value for atau an init, or the constant value?")
|
| flags.DEFINE_boolean("do_train_prior_ar_nvar", DO_TRAIN_PRIOR_AR_NVAR,
|
| "Is the value for noise variance an init, or the constant \
|
| value?")
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_integer("co_dim", CO_DIM,
|
| "Number of control net outputs (>0 builds that graph).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_boolean("do_causal_controller",
|
| DO_CAUSAL_CONTROLLER,
|
| "Restrict the controller create only causal inferred \
|
| inputs?")
|
|
|
|
|
|
|
| flags.DEFINE_boolean("do_feed_factors_to_controller",
|
| DO_FEED_FACTORS_TO_CONTROLLER,
|
| "Should factors[t-1] be input to controller at time t?")
|
| flags.DEFINE_string("feedback_factors_or_rates", FEEDBACK_FACTORS_OR_RATES,
|
| "Feedback the factors or the rates to the controller? \
|
| Acceptable values: 'factors' or 'rates'.")
|
| flags.DEFINE_integer("controller_input_lag", CONTROLLER_INPUT_LAG,
|
| "Time lag on the encoding to controller t-lag for \
|
| forward, t+lag for reverse.")
|
|
|
| flags.DEFINE_integer("ci_enc_dim", CI_ENC_DIM,
|
| "Cell hidden size, encoder of control inputs")
|
| flags.DEFINE_integer("con_dim", CON_DIM,
|
| "Cell hidden size, controller")
|
|
|
|
|
|
|
| flags.DEFINE_integer("batch_size", BATCH_SIZE,
|
| "Batch size to use during training.")
|
| flags.DEFINE_float("learning_rate_init", LEARNING_RATE_INIT,
|
| "Learning rate initial value")
|
| flags.DEFINE_float("learning_rate_decay_factor", LEARNING_RATE_DECAY_FACTOR,
|
| "Learning rate decay, decay by this fraction every so \
|
| often.")
|
| flags.DEFINE_float("learning_rate_stop", LEARNING_RATE_STOP,
|
| "The lr is adaptively reduced, stop training at this value.")
|
|
|
|
|
|
|
|
|
| flags.DEFINE_integer("learning_rate_n_to_compare", LEARNING_RATE_N_TO_COMPARE,
|
| "Number of previous costs current cost has to be worse \
|
| than, to lower learning rate.")
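
# A sketch of the adaptive learning rate schedule these flags imply (the
# actual logic lives in the LFADS class; the names here are illustrative):
#
#   if all(cost > c for c in recent_costs[-learning_rate_n_to_compare:]):
#     lr *= learning_rate_decay_factor    # decay when cost stops improving
#   if lr < learning_rate_stop:
#     ...                                 # training halts at this floor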
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_float("max_grad_norm", MAX_GRAD_NORM,
|
| "Max norm of gradient before clipping.")
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_float("cell_clip_value", CELL_CLIP_VALUE,
|
| "Max value recurrent cell can take before being clipped.")
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_boolean("do_train_io_only", DO_TRAIN_IO_ONLY,
|
| "Train only the input (readin) and output (readout) \
|
| affine functions.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_boolean("do_train_encoder_only", DO_TRAIN_ENCODER_ONLY,
|
| "Train only the encoder weights.")
|
|
|
| flags.DEFINE_boolean("do_reset_learning_rate", DO_RESET_LEARNING_RATE,
|
| "Reset the learning rate to initial value.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_boolean("do_train_readin", DO_TRAIN_READIN, "Whether to train the \
|
| readin matrices and bias vectors. False leaves them fixed \
|
| at their initial values specified by the alignment \
|
| matrices and vectors.")
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_float("keep_prob", KEEP_PROB, "Dropout keep probability.")
|
|
|
|
|
|
|
| flags.DEFINE_integer("temporal_spike_jitter_width",
|
| TEMPORAL_SPIKE_JITTER_WIDTH,
|
| "Shuffle spikes around this window.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_float("l2_gen_scale", L2_GEN_SCALE,
|
| "L2 regularization cost for the generator only.")
|
| flags.DEFINE_float("l2_con_scale", L2_CON_SCALE,
|
| "L2 regularization cost for the controller only.")
|
| flags.DEFINE_float("co_mean_corr_scale", CO_MEAN_CORR_SCALE,
|
| "Cost of correlation (thru time)in the means of \
|
| controller output.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_float("kl_ic_weight", KL_IC_WEIGHT,
|
| "Strength of KL weight on initial conditions KL penatly.")
|
| flags.DEFINE_float("kl_co_weight", KL_CO_WEIGHT,
|
| "Strength of KL weight on controller output KL penalty.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| flags.DEFINE_integer("kl_start_step", KL_START_STEP,
|
| "Start increasing weight after this many steps.")
|
|
|
| flags.DEFINE_integer("kl_increase_steps", KL_INCREASE_STEPS,
|
| "Increase weight of kl cost to avoid local minimum.")
|
|
|
|
|
| flags.DEFINE_integer("l2_start_step", L2_START_STEP,
|
| "Start increasing l2 weight after this many steps.")
|
| flags.DEFINE_integer("l2_increase_steps", L2_INCREASE_STEPS,
|
| "Increase weight of l2 cost to avoid local minimum.")
|
|
|
| FLAGS = flags.FLAGS
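
# Example invocation (illustrative; the script name, paths, and flag values
# below are placeholders, not requirements):
#
#   python run_lfads.py --kind=train \
#     --data_dir=/tmp/rnn_synth_data_v1.0/ \
#     --data_filename_stem=chaotic_rnn_inputs_g1p5 \
#     --lfads_save_dir=/tmp/lfads_chaotic_rnn_inputs_g1p5/ \
#     --co_dim=1 --batch_size=128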
|
|
|
|
|
| def build_model(hps, kind="train", datasets=None):
|
| """Builds a model from either random initialization, or saved parameters.
|
|
|
| Args:
|
| hps: The hyper parameters for the model.
|
| kind: (optional) The kind of model to build. Training vs inference require
|
| different graphs.
|
| datasets: The datasets structure (see top of lfads.py).
|
|
|
| Returns:
|
| an LFADS model.
|
| """
|
|
|
| build_kind = kind
|
| if build_kind == "write_model_params":
|
| build_kind = "train"
|
| with tf.variable_scope("LFADS", reuse=None):
|
| model = LFADS(hps, kind=build_kind, datasets=datasets)
|
|
|
| if not os.path.exists(hps.lfads_save_dir):
|
| print("Save directory %s does not exist, creating it." % hps.lfads_save_dir)
|
| os.makedirs(hps.lfads_save_dir)
|
|
|
| cp_pb_ln = hps.checkpoint_pb_load_name
|
| cp_pb_ln = 'checkpoint' if cp_pb_ln == "" else cp_pb_ln
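
  # 'checkpoint' is the rolling checkpoint written throughout training;
  # 'checkpoint_lve' tracks the model with the lowest validation error.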
|
| if cp_pb_ln == 'checkpoint':
|
| print("Loading latest training checkpoint in: ", hps.lfads_save_dir)
|
| saver = model.seso_saver
|
| elif cp_pb_ln == 'checkpoint_lve':
|
| print("Loading lowest validation checkpoint in: ", hps.lfads_save_dir)
|
| saver = model.lve_saver
|
| else:
|
| print("Loading checkpoint: ", cp_pb_ln, ", in: ", hps.lfads_save_dir)
|
| saver = model.seso_saver
|
|
|
| ckpt = tf.train.get_checkpoint_state(hps.lfads_save_dir,
|
| latest_filename=cp_pb_ln)
|
|
|
| session = tf.get_default_session()
|
| print("ckpt: ", ckpt)
|
| if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
|
| print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
|
| saver.restore(session, ckpt.model_checkpoint_path)
|
| else:
|
| print("Created model with fresh parameters.")
|
| if kind in ["posterior_sample_and_average", "posterior_push_mean",
|
| "prior_sample", "write_model_params"]:
|
| print("Possible error!!! You are running ", kind, " on a newly \
|
| initialized model!")
|
|
|
| print("Are you sure you sure a checkpoint in ", hps.lfads_save_dir,
|
| " exists?")
|
|
|
| tf.global_variables_initializer().run()
|
|
|
| if ckpt:
|
| train_step_str = re.search('-[0-9]+$', ckpt.model_checkpoint_path).group()
|
| else:
|
| train_step_str = '-0'
|
|
|
| fname = 'hyperparameters' + train_step_str + '.txt'
|
| hp_fname = os.path.join(hps.lfads_save_dir, fname)
|
| hps_for_saving = jsonify_dict(hps)
|
| utils.write_data(hp_fname, hps_for_saving, use_json=True)
|
|
|
| return model
|
|
|
|
|
| def jsonify_dict(d):
|
| """Turns python booleans into strings so hps dict can be written in json.
|
| Creates a shallow-copied dictionary first, then accomplishes string
|
| conversion.
|
|
|
| Args:
|
| d: hyperparameter dictionary
|
|
|
  Returns: hyperparameter dictionary with bools as strings
|
| """
|
|
|
| d2 = d.copy()
|
| def jsonify_bool(boolean_value):
|
| if boolean_value:
|
| return "true"
|
| else:
|
| return "false"
|
|
|
| for key in d2.keys():
|
| if isinstance(d2[key], bool):
|
| d2[key] = jsonify_bool(d2[key])
|
| return d2
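
# A minimal illustration of jsonify_dict (hypothetical values):
#
#   >>> jsonify_dict({'do_train_readin': True, 'keep_prob': 0.95})
#   {'do_train_readin': 'true', 'keep_prob': 0.95}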
|
|
|
|
|
| def build_hyperparameter_dict(flags):
|
| """Simple script for saving hyper parameters. Under the hood the
|
| flags structure isn't a dictionary, so it has to be simplified since we
|
| want to be able to view file as text.
|
|
|
| Args:
|
| flags: From tf.app.flags
|
|
|
| Returns:
|
| dictionary of hyper parameters (ignoring other flag types).
|
| """
|
| d = {}
|
|
|
| d['output_dist'] = flags.output_dist
|
| d['data_dir'] = flags.data_dir
|
| d['lfads_save_dir'] = flags.lfads_save_dir
|
| d['checkpoint_pb_load_name'] = flags.checkpoint_pb_load_name
|
| d['checkpoint_name'] = flags.checkpoint_name
|
| d['output_filename_stem'] = flags.output_filename_stem
|
| d['max_ckpt_to_keep'] = flags.max_ckpt_to_keep
|
| d['max_ckpt_to_keep_lve'] = flags.max_ckpt_to_keep_lve
|
| d['ps_nexamples_to_process'] = flags.ps_nexamples_to_process
|
| d['ext_input_dim'] = flags.ext_input_dim
|
| d['data_filename_stem'] = flags.data_filename_stem
|
| d['device'] = flags.device
|
| d['csv_log'] = flags.csv_log
|
| d['num_steps_for_gen_ic'] = flags.num_steps_for_gen_ic
|
| d['inject_ext_input_to_gen'] = flags.inject_ext_input_to_gen
|
|
|
| d['cell_weight_scale'] = flags.cell_weight_scale
|
|
|
| d['ic_dim'] = flags.ic_dim
|
| d['factors_dim'] = flags.factors_dim
|
| d['ic_enc_dim'] = flags.ic_enc_dim
|
| d['gen_dim'] = flags.gen_dim
|
| d['gen_cell_input_weight_scale'] = flags.gen_cell_input_weight_scale
|
| d['gen_cell_rec_weight_scale'] = flags.gen_cell_rec_weight_scale
|
|
|
| d['ic_prior_var_min'] = flags.ic_prior_var_min
|
| d['ic_prior_var_scale'] = flags.ic_prior_var_scale
|
| d['ic_prior_var_max'] = flags.ic_prior_var_max
|
| d['ic_post_var_min'] = flags.ic_post_var_min
|
| d['co_prior_var_scale'] = flags.co_prior_var_scale
|
| d['prior_ar_atau'] = flags.prior_ar_atau
|
| d['prior_ar_nvar'] = flags.prior_ar_nvar
|
| d['do_train_prior_ar_atau'] = flags.do_train_prior_ar_atau
|
| d['do_train_prior_ar_nvar'] = flags.do_train_prior_ar_nvar
|
|
|
| d['do_causal_controller'] = flags.do_causal_controller
|
| d['controller_input_lag'] = flags.controller_input_lag
|
| d['do_feed_factors_to_controller'] = flags.do_feed_factors_to_controller
|
| d['feedback_factors_or_rates'] = flags.feedback_factors_or_rates
|
| d['co_dim'] = flags.co_dim
|
| d['ci_enc_dim'] = flags.ci_enc_dim
|
| d['con_dim'] = flags.con_dim
|
| d['co_mean_corr_scale'] = flags.co_mean_corr_scale
|
|
|
| d['batch_size'] = flags.batch_size
|
| d['learning_rate_init'] = flags.learning_rate_init
|
| d['learning_rate_decay_factor'] = flags.learning_rate_decay_factor
|
| d['learning_rate_stop'] = flags.learning_rate_stop
|
| d['learning_rate_n_to_compare'] = flags.learning_rate_n_to_compare
|
| d['max_grad_norm'] = flags.max_grad_norm
|
| d['cell_clip_value'] = flags.cell_clip_value
|
| d['do_train_io_only'] = flags.do_train_io_only
|
| d['do_train_encoder_only'] = flags.do_train_encoder_only
|
| d['do_reset_learning_rate'] = flags.do_reset_learning_rate
|
| d['do_train_readin'] = flags.do_train_readin
|
|
|
|
|
| d['keep_prob'] = flags.keep_prob
|
| d['temporal_spike_jitter_width'] = flags.temporal_spike_jitter_width
|
| d['l2_gen_scale'] = flags.l2_gen_scale
|
| d['l2_con_scale'] = flags.l2_con_scale
|
|
|
| d['kl_ic_weight'] = flags.kl_ic_weight
|
| d['kl_co_weight'] = flags.kl_co_weight
|
| d['kl_start_step'] = flags.kl_start_step
|
| d['kl_increase_steps'] = flags.kl_increase_steps
|
| d['l2_start_step'] = flags.l2_start_step
|
| d['l2_increase_steps'] = flags.l2_increase_steps
|
  d['_clip_value'] = 80  # bounds the tf.exp to avoid INF
|
|
|
| return d
|
|
|
|
|
| class hps_dict_to_obj(dict):
|
| """Helper class allowing us to access hps dictionary more easily."""
|
|
|
| def __getattr__(self, key):
|
| if key in self:
|
| return self[key]
|
| else:
|
| assert False, ("%s does not exist." % key)
|
| def __setattr__(self, key, value):
|
| self[key] = value
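
# A minimal illustration of hps_dict_to_obj (hypothetical values): attribute
# reads and writes go straight to the underlying dict.
#
#   hps = hps_dict_to_obj({'batch_size': 128})
#   hps.batch_size          # -> 128
#   hps.keep_prob = 0.95    # same as hps['keep_prob'] = 0.95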
|
|
|
|
|
| def train(hps, datasets):
|
| """Train the LFADS model.
|
|
|
| Args:
|
| hps: The dictionary of hyperparameters.
|
| datasets: A dictionary of data dictionaries. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| """
|
| model = build_model(hps, kind="train", datasets=datasets)
|
| if hps.do_reset_learning_rate:
|
| sess = tf.get_default_session()
|
| sess.run(model.learning_rate.initializer)
|
|
|
| model.train_model(datasets)
|
|
|
|
|
| def write_model_runs(hps, datasets, output_fname=None, push_mean=False):
|
| """Run the model on the data in data_dict, and save the computed values.
|
|
|
  LFADS generates a number of outputs for each example, and these are all
  saved. They are:
|
| The mean and variance of the prior of g0.
|
| The mean and variance of approximate posterior of g0.
|
| The control inputs (if enabled)
|
| The initial conditions, g0, for all examples.
|
| The generator states for all time.
|
| The factors for all time.
|
| The rates for all time.
|
|
|
| Args:
|
| hps: The dictionary of hyperparameters.
|
| datasets: A dictionary of data dictionaries. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| output_fname (optional): output filename stem to write the model runs.
|
| push_mean: if False (default), generates batch_size samples for each trial
|
| and averages the results. if True, runs each trial once without noise,
|
| pushing the posterior mean initial conditions and control inputs through
|
| the trained model. False is used for posterior_sample_and_average, True
|
| is used for posterior_push_mean.
|
| """
|
| model = build_model(hps, kind=hps.kind, datasets=datasets)
|
| model.write_model_runs(datasets, output_fname, push_mean)
|
|
|
|
|
| def write_model_samples(hps, datasets, dataset_name=None, output_fname=None):
|
| """Use the prior distribution to generate samples from the model.
|
| Generates batch_size number of samples (set through FLAGS).
|
|
|
  LFADS generates a number of outputs for each example, and these are all
  saved. They are:
|
| The mean and variance of the prior of g0.
|
| The control inputs (if enabled)
|
| The initial conditions, g0, for all examples.
|
| The generator states for all time.
|
| The factors for all time.
|
| The output distribution parameters (e.g. rates) for all time.
|
|
|
| Args:
|
| hps: The dictionary of hyperparameters.
|
| datasets: A dictionary of data dictionaries. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| dataset_name: The name of the dataset to grab the factors -> rates
|
| alignment matrices from. Only a concern with models trained on
|
| multi-session data. By default, uses the first dataset in the data dict.
|
| output_fname: The name prefix of the file in which to save the generated
|
| samples.
|
| """
|
| if not output_fname:
|
| output_fname = "model_runs_" + hps.kind
|
| else:
|
| output_fname = output_fname + "model_runs_" + hps.kind
|
| if not dataset_name:
|
    dataset_name = list(datasets.keys())[0]
|
| else:
|
| if dataset_name not in datasets.keys():
|
| raise ValueError("Invalid dataset name '%s'."%(dataset_name))
|
| model = build_model(hps, kind=hps.kind, datasets=datasets)
|
| model.write_model_samples(dataset_name, output_fname)
|
|
|
|
|
| def write_model_parameters(hps, output_fname=None, datasets=None):
|
| """Save all the model parameters
|
|
|
| Save all the parameters to hps.lfads_save_dir.
|
|
|
| Args:
|
| hps: The dictionary of hyperparameters.
|
| output_fname: The prefix of the file in which to save the generated
|
| samples.
|
| datasets: A dictionary of data dictionaries. The dataset dict is simply a
|
| name(string)-> data dictionary mapping (See top of lfads.py).
|
| """
|
| if not output_fname:
|
| output_fname = "model_params"
|
| else:
|
| output_fname = output_fname + "_model_params"
|
| fname = os.path.join(hps.lfads_save_dir, output_fname)
|
| print("Writing model parameters to: ", fname)
|
|
|
| model = build_model(hps, kind="write_model_params", datasets=datasets)
|
| model_params = model.eval_model_parameters(use_nested=False,
|
| include_strs="LFADS")
|
| utils.write_data(fname, model_params, compression=None)
|
| print("Done.")
|
|
|
|
|
| def clean_data_dict(data_dict):
|
| """Add some key/value pairs to the data dict, if they are missing.
|
| Args:
|
| data_dict - dictionary containing data for LFADS
|
| Returns:
|
| data_dict with some keys filled in, if they are absent.
|
| """
|
|
|
| keys = ['train_truth', 'train_ext_input', 'valid_data',
|
| 'valid_truth', 'valid_ext_input', 'valid_train']
|
| for k in keys:
|
| if k not in data_dict:
|
| data_dict[k] = None
|
|
|
| return data_dict
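
# For example (illustrative), a data dict missing the optional keys comes back
# with them filled in as None:
#
#   >>> clean_data_dict({'train_data': []})['train_ext_input'] is None
#   True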
|
|
|
|
|
| def load_datasets(data_dir, data_filename_stem):
|
| """Load the datasets from a specified directory.
|
|
|
| Example files look like
|
| >data_dir/my_dataset_first_day
|
| >data_dir/my_dataset_second_day
|
|
|
  If the my_dataset (filename) stem is in the directory, the read routine will
  try to load it.  The datasets dictionary will then look like
    dataset['first_day'] -> (first day data dictionary)
    dataset['second_day'] -> (second day data dictionary)
|
|
|
| Args:
|
| data_dir: The directory from which to load the datasets.
|
| data_filename_stem: The stem of the filename for the datasets.
|
|
|
| Returns:
|
| datasets: a dataset dictionary, with one name->data dictionary pair for
|
| each dataset file.
|
| """
|
| print("Reading data from ", data_dir)
|
| datasets = utils.read_datasets(data_dir, data_filename_stem)
|
| for k, data_dict in datasets.items():
|
| datasets[k] = clean_data_dict(data_dict)
|
|
|
| train_total_size = len(data_dict['train_data'])
|
| if train_total_size == 0:
|
| print("Did not load training set.")
|
| else:
|
| print("Found training set with number examples: ", train_total_size)
|
|
|
| valid_total_size = len(data_dict['valid_data'])
|
| if valid_total_size == 0:
|
| print("Did not load validation set.")
|
| else:
|
| print("Found validation set with number examples: ", valid_total_size)
|
|
|
| return datasets
|
|
|
|
|
| def main(_):
|
| """Get this whole shindig off the ground."""
|
| d = build_hyperparameter_dict(FLAGS)
|
| hps = hps_dict_to_obj(d)
|
| kind = FLAGS.kind
|
|
|
|
|
| train_set = valid_set = None
|
| if kind in ["train", "posterior_sample_and_average", "posterior_push_mean",
|
| "prior_sample", "write_model_params"]:
|
| datasets = load_datasets(hps.data_dir, hps.data_filename_stem)
|
| else:
|
| raise ValueError('Kind {} is not supported.'.format(kind))
|
|
|
|
|
| hps.kind = kind
|
| hps.dataset_names = []
|
| hps.dataset_dims = {}
|
| for key in datasets:
|
| hps.dataset_names.append(key)
|
| hps.dataset_dims[key] = datasets[key]['data_dim']
|
|
|
|
|
|
|
  hps.num_steps = list(datasets.values())[0]['num_steps']
|
| hps.ndatasets = len(hps.dataset_names)
|
|
|
| if hps.num_steps_for_gen_ic > hps.num_steps:
|
| hps.num_steps_for_gen_ic = hps.num_steps
|
|
|
|
|
| config = tf.ConfigProto(allow_soft_placement=True,
|
| log_device_placement=False)
|
| if FLAGS.allow_gpu_growth:
|
| config.gpu_options.allow_growth = True
|
| sess = tf.Session(config=config)
|
| with sess.as_default():
|
| with tf.device(hps.device):
|
| if kind == "train":
|
| train(hps, datasets)
|
| elif kind == "posterior_sample_and_average":
|
| write_model_runs(hps, datasets, hps.output_filename_stem,
|
| push_mean=False)
|
| elif kind == "posterior_push_mean":
|
| write_model_runs(hps, datasets, hps.output_filename_stem,
|
| push_mean=True)
|
| elif kind == "prior_sample":
|
| write_model_samples(hps, datasets, hps.output_filename_stem)
|
| elif kind == "write_model_params":
|
| write_model_parameters(hps, hps.output_filename_stem, datasets)
|
| else:
|
| assert False, ("Kind %s is not implemented. " % kind)
|
|
|
|
|
| if __name__ == "__main__":
|
| tf.app.run()
|
|
|