|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import matplotlib
|
| matplotlib.use('Agg')
|
| from matplotlib import pyplot as plt
|
| import numpy as np
|
| import tensorflow as tf
|
|
|
| def _plot_item(W, name, full_name, nspaces):
|
| plt.figure()
|
| if W.shape == ():
|
| print(name, ": ", W)
|
| elif W.shape[0] == 1:
|
| plt.stem(W.T)
|
| plt.title(full_name)
|
| elif W.shape[1] == 1:
|
| plt.stem(W)
|
| plt.title(full_name)
|
| else:
|
| plt.imshow(np.abs(W), interpolation='nearest', cmap='jet');
|
| plt.colorbar()
|
| plt.title(full_name)
|
|
|
|
|
def all_plot(d, full_name="", exclude="", nspaces=0):
    """Recursively plot all the LFADS model parameters in a nested dictionary.

    Every leaf value is passed to ``_plot_item`` under a slash-separated path
    built from the keys leading to it (e.g. ``/gen/W``).

    Args:
      d: (Possibly nested) dict mapping names to sub-dicts or parameter
        values.
      full_name: Path prefix of ``d`` itself; "" at the top level.
      exclude: If non-empty, leaves whose full path contains this substring
        are skipped.
      nspaces: Nesting depth, increased by 4 per level and forwarded to
        ``_plot_item``.
    """
    # .items() works on both Python 2 and 3; .iteritems() is Python-2 only.
    for k, v in d.items():
        this_name = full_name + "/" + k
        if isinstance(v, dict):
            all_plot(v, full_name=this_name, exclude=exclude,
                     nspaces=nspaces + 4)
        elif not exclude or exclude not in this_name:
            _plot_item(v, name=k, full_name=this_name, nspaces=nspaces + 4)
|
|
|
|
|
|
|
def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
                     color='r', title=None):
    """Draw stacked traces for the channels of a [batch, time, channel] array.

    Channel ``i`` is drawn as its own line, shifted up by ``scale * i`` so
    that the traces stack on a single axis.

    Args:
      vals_bxtxn: Array indexed as [batch, time, channel].
      bidx: Batch index to draw.  If None, the mean over the batch axis is
        drawn instead.
      n_to_plot: Upper bound on how many leading channels to draw; it is
        clamped to the number of channels available.
      scale: Vertical spacing between successive traces.
      color: Matplotlib color used for every trace.
      title: Axis title; only set when truthy.
    """
    vals_txn = (np.mean(vals_bxtxn, axis=0) if bidx is None
                else vals_bxtxn[bidx, :, :])

    _, num_channels = vals_txn.shape
    n_to_plot = min(n_to_plot, num_channels)

    shown_txn = vals_txn[:, :n_to_plot]
    offsets = scale * np.arange(n_to_plot)
    plt.plot(shown_txn + offsets, color=color, lw=1.0)
    plt.axis('tight')

    if title:
        plt.title(title)
|
|
|
|
|
def plot_lfads_timeseries(data_bxtxn, model_vals, ext_input_bxtxi=None,
                          truth_bxtxn=None, bidx=None, output_dist="poisson",
                          conversion_factor=1.0, subplot_cidx=0,
                          col_title=None):
    """Fill one column of the 7x2 LFADS summary grid for a single data split.

    Panels drawn (subplot index ``row + subplot_cidx`` in a 7x2 grid):
      1: inferred rates (Poisson) or means +/- one std (Gaussian), overlaid
         with ``truth_bxtxn`` in black when given;
      3: controller outputs and/or external inputs, when present;
      5: inferred means overlaid on the observed data;
      7: factors;
      9: generator states;
      11: heat map of the data (channel x time);
      13: heat map of the inferred means (channel x time).

    Args:
      data_bxtxn: Observed data indexed as [batch, time, channel].
      model_vals: Dict of model outputs.  Reads 'output_dist_params',
        'factors', 'gen_states' and, if present, 'controller_outputs'.
      ext_input_bxtxi: Optional external inputs, [batch, time, input].
      truth_bxtxn: Optional ground truth, [batch, time, channel].
      bidx: Batch index to plot; if None each panel shows the batch average.
      output_dist: 'poisson' or 'gaussian'.  For 'gaussian', the last axis of
        'output_dist_params' holds the means followed by the variances.
      conversion_factor: Multiplier applied to the Poisson rates in the line
        plots (not in the heat map).
      subplot_cidx: Grid column to draw into (0 = left, 1 = right).
      col_title: Prefix for every panel title; None is treated as "".

    Raises:
      NotImplementedError: If ``output_dist`` is not 'poisson' or 'gaussian'.
    """
    # Validate before drawing anything.  The original `assert 'NIY'` was always
    # true and let unsupported distributions fail later with a NameError.
    if output_dist not in ('poisson', 'gaussian'):
        raise NotImplementedError(
            "Unsupported output_dist: %r" % (output_dist,))
    if col_title is None:
        col_title = ""

    n_to_plot = 10
    scale = 1.0
    nrows = 7

    # Row 1: inferred rates / means, optionally overlaid with the truth.
    plt.subplot(nrows, 2, 1 + subplot_cidx)
    if output_dist == 'poisson':
        means = conversion_factor * model_vals['output_dist_params']
        plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
                         title=col_title +
                         " rates (LFADS - red, Truth - black)")
    else:
        means, variances = np.split(model_vals['output_dist_params'], 2,
                                    axis=2)
        stds = np.sqrt(variances)
        plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
                         title=col_title +
                         " means (LFADS - red, Truth - black)")
        plot_time_series(means + stds, bidx, n_to_plot=n_to_plot,
                         scale=scale, color='c')
        plot_time_series(means - stds, bidx, n_to_plot=n_to_plot,
                         scale=scale, color='c')

    if truth_bxtxn is not None:
        plot_time_series(truth_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
                         scale=scale)

    # Row 3: controller outputs and external inputs share one panel.
    has_controller = "controller_outputs" in model_vals
    has_ext_input = ext_input_bxtxi is not None
    if has_controller or has_ext_input:
        # Select the panel up front so that external inputs are not drawn on
        # the row-1 panel when there are no controller outputs.
        plt.subplot(nrows, 2, 3 + subplot_cidx)

    input_title = ""
    if has_controller:
        input_title += " Controller Output"
        # NOTE(review): [0:-1] slices the first (presumably batch) axis and so
        # drops the last trial; confirm it is not meant to trim time steps.
        u_t = model_vals['controller_outputs'][0:-1]
        plot_time_series(u_t, bidx, n_to_plot=n_to_plot, color='c',
                         scale=1.0, title=col_title + input_title)

    if has_ext_input:
        input_title += " External Input"
        # Pass bidx like every other panel so the same trial is shown.
        plot_time_series(ext_input_bxtxi, bidx, n_to_plot=n_to_plot,
                         color='b', scale=scale,
                         title=col_title + input_title)

    # Row 5: inferred means on top of the observed data.
    plt.subplot(nrows, 2, 5 + subplot_cidx)
    plot_time_series(means, bidx,
                     n_to_plot=n_to_plot, scale=1.0,
                     title=col_title +
                     " Spikes (LFADS - red, Spikes - black)")
    plot_time_series(data_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
                     scale=1.0)

    # Row 7: factors.
    plt.subplot(nrows, 2, 7 + subplot_cidx)
    plot_time_series(model_vals['factors'], bidx, n_to_plot=n_to_plot,
                     color='b', scale=2.0, title=col_title + " Factors")

    # Row 9: generator states.
    plt.subplot(nrows, 2, 9 + subplot_cidx)
    plot_time_series(model_vals['gen_states'], bidx, n_to_plot=n_to_plot,
                     color='g', scale=1.0,
                     title=col_title + " Generator State")

    # Rows 11 and 13: channel x time heat maps of the data and the means.
    params_bxtxn = model_vals['output_dist_params']
    if bidx is not None:
        data_nxt = data_bxtxn[bidx, :, :].T
        params_nxt = params_bxtxn[bidx, :, :].T
    else:
        data_nxt = np.mean(data_bxtxn, axis=0).T
        params_nxt = np.mean(params_bxtxn, axis=0).T

    if output_dist == 'poisson':
        means_nxt = params_nxt
    else:
        # After the transpose the stacked [means, variances] lie along axis 0.
        means_nxt = np.vsplit(params_nxt, 2)[0]

    plt.subplot(nrows, 2, 11 + subplot_cidx)
    plt.imshow(data_nxt, aspect='auto', interpolation='nearest')
    plt.title(col_title + ' Data')

    plt.subplot(nrows, 2, 13 + subplot_cidx)
    plt.imshow(means_nxt, aspect='auto', interpolation='nearest')
    plt.title(col_title + ' Means')
|
|
|
|
|
def plot_lfads(train_bxtxd, train_model_vals,
               train_ext_input_bxtxi=None, train_truth_bxtxd=None,
               valid_bxtxd=None, valid_model_vals=None,
               valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
               bidx=None, cf=1.0, output_dist='poisson'):
    """Render the LFADS train/valid summary figure and return its pixels.

    The left column of the grid shows the training split.  The right column
    shows the validation split and is drawn only when validation data or
    model values are provided.

    Args:
      train_bxtxd: Training data, [batch, time, channel].
      train_model_vals: Dict of model outputs for the training data.
      train_ext_input_bxtxi: Optional training external inputs.
      train_truth_bxtxd: Optional training ground truth.
      valid_bxtxd: Optional validation data, [batch, time, channel].
      valid_model_vals: Optional dict of model outputs for validation data.
      valid_ext_input_bxtxi: Optional validation external inputs.
      valid_truth_bxtxd: Optional validation ground truth.
      bidx: Batch index to plot; None plots batch averages.
      cf: Conversion factor applied to Poisson rates.
      output_dist: 'poisson' or 'gaussian'.

    Returns:
      uint8 numpy array of shape (height, width, 3) with the RGB pixels of
      the rendered figure.
    """
    f = plt.figure(figsize=(18, 20), tight_layout=True)
    try:
        plot_lfads_timeseries(train_bxtxd, train_model_vals,
                              train_ext_input_bxtxi,
                              truth_bxtxn=train_truth_bxtxd,
                              conversion_factor=cf, bidx=bidx,
                              output_dist=output_dist, col_title='Train')

        # The valid_* arguments default to None.  Only skip the column when
        # nothing was given, so a half-specified split still fails loudly.
        if valid_bxtxd is not None or valid_model_vals is not None:
            plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
                                  valid_ext_input_bxtxi,
                                  truth_bxtxn=valid_truth_bxtxd,
                                  conversion_factor=cf, bidx=bidx,
                                  output_dist=output_dist,
                                  subplot_cidx=1, col_title='Valid')

        f.canvas.draw()
        width, height = f.canvas.get_width_height()
        # buffer_rgba() is the supported Agg pixel accessor; tostring_rgb()
        # and np.fromstring() are deprecated/removed in recent releases.
        rgba = np.frombuffer(f.canvas.buffer_rgba(), dtype=np.uint8)
        # Drop alpha, and copy so the result owns (writable) memory that
        # outlives the closed figure.
        data_hxwx3 = rgba.reshape(height, width, 4)[:, :, :3].copy()
    finally:
        # Close this specific figure, not merely the current one, even if
        # plotting fails.
        plt.close(f)

    return data_hxwx3
|
|
|