Spaces:
Runtime error
Runtime error
Upload mi_lstm_cell.py
Browse files- mi_lstm_cell.py +77 -0
mi_lstm_cell.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class MiLSTMCell(tf.nn.rnn_cell.RNNCell):
    """LSTM cell whose gates combine input and hidden state multiplicatively.

    For every gate the pre-activation is

        beta_1 * Wx + beta_2 * Uh + beta_3 * (Wx * Uh) + b

    where ``Wx`` is the projected input, ``Uh`` the projected previous hidden
    state, ``beta_*`` are learned per-unit scales (initialised to one) and
    ``b`` is a learned additive bias (initialised to zero).
    """

    def __init__(self, num_units, forget_bias=1.0, input_size=None,
                 state_is_tuple=True, activation=tf.tanh, reuse=None):
        """Store the cell configuration.

        Args:
            num_units: size of the cell state, hidden state and output.
            forget_bias: constant added to the forget-gate pre-activation
                before the sigmoid.
            input_size: accepted but not used by this cell.
            state_is_tuple: accepted but not used; the state is always an
                ``LSTMStateTuple``.
            activation: function applied to the candidate values and to the
                new cell state.
            reuse: forwarded to ``tf.variable_scope`` when the cell is called.
        """
        # Initialise the base class so that attributes inherited from it are
        # set up before any of them is accessed.
        super(MiLSTMCell, self).__init__()
        self.numUnits = num_units
        self.forgetBias = forget_bias
        self.activation = activation
        self.reuse = reuse

    @property
    def state_size(self):
        """Sizes of the ``(c, h)`` state pair, both equal to ``num_units``."""
        return tf.nn.rnn_cell.LSTMStateTuple(self.numUnits, self.numUnits)

    @property
    def output_size(self):
        """Size of the cell output (``num_units``)."""
        return self.numUnits

    def mulWeights(self, inp, inDim, outDim, name=""):
        """Multiply ``inp`` by a weight matrix of shape ``(inDim, outDim)``.

        The matrix is the variable ``"weights"`` inside the variable scope
        ``"weights" + name`` and uses Xavier initialisation.

        Returns:
            ``tf.matmul(inp, W)``.
        """
        with tf.variable_scope("weights" + name):
            W = tf.get_variable(
                "weights", shape=(inDim, outDim),
                initializer=tf.contrib.layers.xavier_initializer())
        return tf.matmul(inp, W)

    def addBiases(self, inp1, inp2, dim, name=""):
        """Combine two projections with multiplicative and additive biases.

        Computes ``beta_1 * inp1 + beta_2 * inp2 + beta_3 * inp1 * inp2 + b``.
        ``beta`` is one variable of shape ``(3 * dim,)`` initialised to ones
        (scope ``"multiplicativeBias" + name``); ``b`` has shape ``(dim,)``
        and is initialised to zeros (scope ``"additiveBiases" + name``).

        Args:
            inp1: first term; concatenated along axis 1, so expected to have
                ``dim`` columns.
            inp2: second term, same shape as ``inp1``.
            dim: number of columns of each term.
            name: suffix appended to both variable-scope names.

        Returns:
            The combined pre-activation, same shape as ``inp1``.
        """
        with tf.variable_scope("additiveBiases" + name):
            b = tf.get_variable("biases", shape=(dim,),
                                initializer=tf.zeros_initializer())
        with tf.variable_scope("multiplicativeBias" + name):
            beta = tf.get_variable("biases", shape=(3 * dim,),
                                   initializer=tf.ones_initializer())

        # Scale all three terms with a single multiply, then split them back.
        stacked = tf.concat([inp1, inp2, inp1 * inp2], axis=1)
        first, second, product = tf.split(beta * stacked,
                                          num_or_size_splits=3, axis=1)
        return first + second + product + b

    def _gate(self, inputs, h, inputSize, weightsName, biasName):
        """Build one gate's pre-activation from the input and hidden state."""
        Wx = self.mulWeights(inputs, inputSize, self.numUnits,
                             name="Wx" + weightsName)
        Uh = self.mulWeights(h, self.numUnits, self.numUnits,
                             name="Uh" + weightsName)
        return self.addBiases(Wx, Uh, self.numUnits, name=biasName)

    def __call__(self, inputs, state, scope=None):
        """Run one step of the cell.

        Args:
            inputs: input tensor; its second dimension must be statically
                known because it fixes the input weight shape.
            state: ``(c, h)`` pair holding the previous cell and hidden state.
            scope: variable-scope name; defaults to the class name.

        Returns:
            ``(newH, LSTMStateTuple(newC, newH))``.
        """
        scope = scope or type(self).__name__
        with tf.variable_scope(scope, reuse=self.reuse):
            c, h = state
            inputSize = int(inputs.shape[1])

            i = self._gate(inputs, h, inputSize, "i", "i")
            # The candidate's bias scopes use the suffix "l", not "j". Kept
            # as-is because it is part of the variable names.
            j = self._gate(inputs, h, inputSize, "j", "l")
            f = self._gate(inputs, h, inputSize, "f", "f")
            o = self._gate(inputs, h, inputSize, "o", "o")

            newC = (c * tf.nn.sigmoid(f + self.forgetBias)
                    + tf.nn.sigmoid(i) * self.activation(j))
            newH = self.activation(newC) * tf.nn.sigmoid(o)

            newState = tf.nn.rnn_cell.LSTMStateTuple(newC, newH)
            return newH, newState

    def zero_state(self, batchSize, dtype=tf.float32):
        """Return an all-zeros ``(c, h)`` state for ``batchSize`` sequences."""
        return tf.nn.rnn_cell.LSTMStateTuple(
            tf.zeros((batchSize, self.numUnits), dtype=dtype),
            tf.zeros((batchSize, self.numUnits), dtype=dtype))
|
| 77 |
+
|