import os
from collections import deque

import numpy as np
import tensorflow as tf


def sample(logits):
    # Gumbel-max trick: perturb logits with -log(-log(U)) noise and take the
    # argmax to draw from the corresponding categorical distribution.
    noise = tf.random.uniform(tf.shape(input=logits))
    return tf.argmax(input=logits - tf.math.log(-tf.math.log(noise)), axis=1)


def cat_entropy(logits):
    # Numerically stable entropy of a categorical distribution from logits.
    a0 = logits - tf.reduce_max(input_tensor=logits, axis=1, keepdims=True)
    ea0 = tf.exp(a0)
    z0 = tf.reduce_sum(input_tensor=ea0, axis=1, keepdims=True)
    p0 = ea0 / z0
    return tf.reduce_sum(input_tensor=p0 * (tf.math.log(z0) - a0), axis=1)


def cat_entropy_softmax(p0):
    # Entropy when the input is already a softmax distribution.
    return -tf.reduce_sum(input_tensor=p0 * tf.math.log(p0 + 1e-6), axis=1)


def ortho_init(scale=1.0):
    def _ortho_init(shape, dtype, partition_info=None):
        # lasagne ortho init for tf
        shape = tuple(shape)
        if len(shape) == 2:
            flat_shape = shape
        elif len(shape) == 4:  # assumes NHWC
            flat_shape = (np.prod(shape[:-1]), shape[-1])
        else:
            raise NotImplementedError
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v  # pick the one with the correct shape
        q = q.reshape(shape)
        return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
    return _ortho_init


def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False):
    if data_format == 'NHWC':
        channel_ax = 3
        strides = [1, stride, stride, 1]
        bshape = [1, 1, 1, nf]
    elif data_format == 'NCHW':
        channel_ax = 1
        strides = [1, 1, stride, stride]
        bshape = [1, nf, 1, 1]
    else:
        raise NotImplementedError
    bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1]
    try:
        nin = x.get_shape()[channel_ax].value  # TF1 Dimension
    except AttributeError:
        nin = x.get_shape()[channel_ax]  # TF2 plain int
    wshape = [rf, rf, nin, nf]
    with tf.compat.v1.variable_scope(scope):
        w = tf.compat.v1.get_variable("w", wshape, initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", bias_var_shape, initializer=tf.compat.v1.constant_initializer(0.0))
        if not one_dim_bias and data_format == 'NHWC':
            b = tf.reshape(b, bshape)
        return tf.nn.conv2d(input=x, filters=w, strides=strides, padding=pad, data_format=data_format) + b


def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
    with tf.compat.v1.variable_scope(scope):
        try:
            nin = x.get_shape()[1].value
        except AttributeError:
            nin = x.get_shape()[1]
        w = tf.compat.v1.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", [nh], initializer=tf.compat.v1.constant_initializer(init_bias))
        return tf.matmul(x, w) + b


def batch_to_seq(h, nbatch, nsteps, flat=False):
    # Split a flat [nbatch*nsteps, ...] batch into a list of nsteps tensors,
    # each of shape [nbatch, ...], one per time step.
    if flat:
        h = tf.reshape(h, [nbatch, nsteps])
    else:
        h = tf.reshape(h, [nbatch, nsteps, -1])
    return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)]


def seq_to_batch(h, flat=False):
    # Inverse of batch_to_seq: stack per-step tensors back into one batch.
    shape = h[0].get_shape().as_list()
    if not flat:
        assert len(shape) > 1
        try:
            nh = h[0].get_shape()[-1].value
        except AttributeError:
            nh = h[0].get_shape()[-1]
        return tf.reshape(tf.concat(axis=1, values=h), [-1, nh])
    else:
        return tf.reshape(tf.stack(values=h, axis=1), [-1])


def lstm(xs, ms, s, scope, nh, init_scale=1.0):
    # Step-unrolled LSTM; ms are episode-done masks that zero the recurrent
    # state at episode boundaries, and s packs (c, h) along the feature axis.
    try:
        nbatch, nin = [v.value for v in xs[0].get_shape()]
    except AttributeError:
        nbatch, nin = xs[0].get_shape().as_list()
    with tf.compat.v1.variable_scope(scope):
        wx = tf.compat.v1.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
        wh = tf.compat.v1.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
        b = tf.compat.v1.get_variable("b", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))
    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    for idx, (x, m) in enumerate(zip(xs, ms)):
        c = c*(1-m)
        h = h*(1-m)
        z = tf.matmul(x, wx) + tf.matmul(h, wh) + b
        i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
        i = tf.nn.sigmoid(i)
        f = tf.nn.sigmoid(f)
        o = tf.nn.sigmoid(o)
        u = tf.tanh(u)
        c = f*c + i*u
        h = o*tf.tanh(c)
        xs[idx] = h
    s = tf.concat(axis=1, values=[c, h])
    return xs, s


def _ln(x, g, b, e=1e-5, axes=[1]):
    # Layer normalization with learned gain g and bias b.
    u, s = tf.nn.moments(x=x, axes=axes, keepdims=True)
    x = (x-u)/tf.sqrt(s+e)
    x = x*g+b
    return x


def lnlstm(xs, ms, s, scope, nh, init_scale=1.0):
    # LSTM with layer normalization on the input, recurrent, and cell paths
    # (gx/bx, gh/bh, gc/bc are the corresponding LN gains and biases).
    try:
        nbatch, nin = [v.value for v in xs[0].get_shape()]
    except AttributeError:
        nbatch, nin = xs[0].get_shape().as_list()
    with tf.compat.v1.variable_scope(scope):
        wx = tf.compat.v1.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
        gx = tf.compat.v1.get_variable("gx", [nh*4], initializer=tf.compat.v1.constant_initializer(1.0))
        bx = tf.compat.v1.get_variable("bx", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))
        wh = tf.compat.v1.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
        gh = tf.compat.v1.get_variable("gh", [nh*4], initializer=tf.compat.v1.constant_initializer(1.0))
        bh = tf.compat.v1.get_variable("bh", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))
        b = tf.compat.v1.get_variable("b", [nh*4], initializer=tf.compat.v1.constant_initializer(0.0))
        gc = tf.compat.v1.get_variable("gc", [nh], initializer=tf.compat.v1.constant_initializer(1.0))
        bc = tf.compat.v1.get_variable("bc", [nh], initializer=tf.compat.v1.constant_initializer(0.0))
    c, h = tf.split(axis=1, num_or_size_splits=2, value=s)
    for idx, (x, m) in enumerate(zip(xs, ms)):
        c = c*(1-m)
        h = h*(1-m)
        z = _ln(tf.matmul(x, wx), gx, bx) + _ln(tf.matmul(h, wh), gh, bh) + b
        i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z)
        i = tf.nn.sigmoid(i)
        f = tf.nn.sigmoid(f)
        o = tf.nn.sigmoid(o)
        u = tf.tanh(u)
        c = f*c + i*u
        h = o*tf.tanh(_ln(c, gc, bc))
        xs[idx] = h
    s = tf.concat(axis=1, values=[c, h])
    return xs, s


def conv_to_fc(x):
    # Flatten a conv feature map to [batch, features].
    try:
        nh = np.prod([v.value for v in x.get_shape()[1:]])
    except AttributeError:
        nh = np.prod([v for v in x.get_shape()[1:]])
    x = tf.reshape(x, [-1, nh])
    return x


def discount_with_dones(rewards, dones, gamma):
    # Backwards-accumulated discounted returns; a done flag of 1 resets the
    # running return so credit does not leak across episode boundaries.
    discounted = []
    r = 0
    for reward, done in zip(rewards[::-1], dones[::-1]):
        r = reward + gamma*r*(1.-done)  # fixed off by one bug
        discounted.append(r)
    return discounted[::-1]


def find_trainable_variables(key):
    return tf.compat.v1.trainable_variables(key)


def make_path(f):
    return os.makedirs(f, exist_ok=True)


# Learning-rate schedules as functions of training progress p in [0, 1].
def constant(p):
    return 1


def linear(p):
    return 1-p


def middle_drop(p):
    # Linear decay that floors at eps*0.1 once the decayed value 1-p falls
    # below eps.
    eps = 0.75
    if 1-p < eps:
        return eps*0.1
    return 1-p
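

# --- Usage sketch (illustrative; not part of the original module) ---
# A hedged example of composing conv/conv_to_fc/fc/sample into a small policy
# head over NHWC image observations. The 84x84x4 observation shape, the layer
# sizes, and the action count of 6 are assumptions made for this sketch; the
# helpers build v1-style graphs, so eager execution is disabled under TF2.
def _demo_policy_head():
    tf.compat.v1.disable_eager_execution()
    obs = tf.compat.v1.placeholder(tf.float32, [None, 84, 84, 4])
    h = tf.nn.relu(conv(obs, 'c1', nf=32, rf=8, stride=4))  # 84x84 -> 20x20
    h = conv_to_fc(h)                                       # flatten to [batch, 20*20*32]
    h = tf.nn.relu(fc(h, 'fc1', nh=256))
    logits = fc(h, 'pi', nh=6, init_scale=0.01)
    action = sample(logits)  # Gumbel-max draw, one action per batch row
    return obs, logits, action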
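

# --- Usage sketch (illustrative; not part of the original module) ---
# A hedged example of the recurrent path: a flat [nbatch*nsteps, nin] batch is
# split into per-step tensors, run through the mask-aware lstm, and stacked
# back into a batch. All sizes below are arbitrary illustrative choices.
def _demo_lstm_rollout():
    tf.compat.v1.disable_eager_execution()
    nbatch, nsteps, nin, nh = 4, 8, 16, 32
    X = tf.compat.v1.placeholder(tf.float32, [nbatch*nsteps, nin])
    M = tf.compat.v1.placeholder(tf.float32, [nbatch*nsteps])  # done masks
    S = tf.compat.v1.placeholder(tf.float32, [nbatch, nh*2])   # packed (c, h)
    xs = batch_to_seq(X, nbatch, nsteps)
    ms = batch_to_seq(M, nbatch, nsteps)
    hs, snew = lstm(xs, ms, S, 'lstm1', nh=nh)
    h = seq_to_batch(hs)  # back to [nbatch*nsteps, nh]
    return h, snew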
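

# --- Usage sketch (illustrative; not part of the original module) ---
# A small worked example of discount_with_dones: with gamma=0.99, the done
# flag at step 3 closes an episode, so step 4 starts a fresh return. The
# inputs are made-up numbers for illustration.
def _demo_discount_with_dones():
    rewards = [1.0, 1.0, 1.0, 1.0]
    dones = [0, 0, 1, 0]
    # Expected result: [1 + 0.99*1.99, 1 + 0.99*1, 1.0, 1.0]
    #                = [2.9701, 1.99, 1.0, 1.0]
    return discount_with_dones(rewards, dones, gamma=0.99)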