Spaces:
Runtime error
Runtime error
Upload mac_cell.py
Browse files- mac_cell.py +592 -0
mac_cell.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
import ops
|
| 6 |
+
from config import config
|
| 7 |
+
|
| 8 |
+
MACCellTuple = collections.namedtuple("MACCellTuple", ("control", "memory"))
|
| 9 |
+
|
| 10 |
+
'''
|
| 11 |
+
The MAC cell.
|
| 12 |
+
|
| 13 |
+
Recurrent cell for multi-step reasoning. Presented in https://arxiv.org/abs/1803.03067.
|
| 14 |
+
The cell has recurrent control and memory states that interact with the question
|
| 15 |
+
and knowledge base (image) respectively.
|
| 16 |
+
|
| 17 |
+
The hidden state structure is MACCellTuple(control, memory)
|
| 18 |
+
|
| 19 |
+
At each step the cell performs by calling to three subunits: control, read and write.
|
| 20 |
+
|
| 21 |
+
1. The Control Unit computes the control state by computing attention over the question words.
|
| 22 |
+
The control state represents the current reasoning operation the cell performs.
|
| 23 |
+
|
| 24 |
+
2. The Read Unit retrieves information from the knowledge base, given the control and previous
|
| 25 |
+
memory values, by computing 2-stages attention over the knowledge base.
|
| 26 |
+
|
| 27 |
+
3. The Write Unit integrates the retrieved information to the previous hidden memory state,
|
| 28 |
+
given the value of the control state, to perform the current reasoning operation.
|
| 29 |
+
'''
|
| 30 |
+
class MACCell(tf.nn.rnn_cell.RNNCell):

    '''Initialize the MAC cell.
    (Note that in the current version the cell is stateful --
    updating its own internals when being called)

    Args:
        vecQuestions: the vector representation of the questions.
        [batchSize, ctrlDim]

        questionWords: the question words embeddings.
        [batchSize, questionLength, ctrlDim]

        questionCntxWords: the encoder outputs -- the "contextual" question words.
        [batchSize, questionLength, ctrlDim]

        questionLengths: the length of each question.
        [batchSize]

        memoryDropout: dropout on the memory state (Tensor scalar).
        readDropout: dropout inside the read unit (Tensor scalar).
        writeDropout: dropout on the new information that gets into the write unit (Tensor scalar).

        batchSize: batch size (Tensor scalar).
        train: train or test mod (Tensor boolean).
        reuse: reuse cell

        knowledgeBase: representation of the knowledge base (image).
        [batchSize, kbSize (Height * Width), memDim]
    '''
    def __init__(self, vecQuestions, questionWords, questionCntxWords, questionLengths,
        knowledgeBase, memoryDropout, readDropout, writeDropout,
        batchSize, train, reuse = None):

        # question representations consumed by the control unit
        self.vecQuestions = vecQuestions
        self.questionWords = questionWords
        self.questionCntxWords = questionCntxWords
        self.questionLengths = questionLengths

        # knowledge base consumed by the read unit
        self.knowledgeBase = knowledgeBase

        # keep-probability tensors, keyed by the unit they apply to
        self.dropouts = {}
        self.dropouts["memory"] = memoryDropout
        self.dropouts["read"] = readDropout
        self.dropouts["write"] = writeDropout

        # dummy per-example output returned from __call__ (the cell has no real
        # outputs -- see output_size; results live in the recurrent state)
        self.none = tf.zeros((batchSize, 1), dtype = tf.float32)

        self.batchSize = batchSize
        self.train = train
        self.reuse = reuse
        # NOTE(review): __call__ reads self.iteration, which is not set here --
        # presumably assigned by the caller before each step; confirm against
        # the model code that drives this cell.
| 81 |
+
'''
|
| 82 |
+
Cell state size.
|
| 83 |
+
'''
|
| 84 |
+
@property
|
| 85 |
+
def state_size(self):
|
| 86 |
+
return MACCellTuple(config.ctrlDim, config.memDim)
|
| 87 |
+
|
| 88 |
+
'''
|
| 89 |
+
Cell output size. Currently it doesn't have any outputs.
|
| 90 |
+
'''
|
| 91 |
+
@property
|
| 92 |
+
def output_size(self):
|
| 93 |
+
return 1
|
| 94 |
+
|
| 95 |
+
    # pass encoder hidden states to control?
    '''
    The Control Unit: computes the new control state -- the reasoning operation,
    by summing up the word embeddings according to a computed attention distribution.

    The unit is recurrent: it receives the whole question and the previous control state,
    merge them together (resulting in the "continuous control"), and then uses that
    to compute attentions over the question words. Finally, it combines the words
    together according to the attention distribution to get the new control state.

    Args:
        controlInput: external inputs to control unit (the question vector).
        [batchSize, ctrlDim]

        inWords: the representation of the words used to compute the attention.
        [batchSize, questionLength, ctrlDim]

        outWords: the representation of the words that are summed up.
        (by default inWords == outWords)
        [batchSize, questionLength, ctrlDim]

        questionLengths: the length of each question.
        [batchSize]

        control: the previous control hidden state value.
        [batchSize, ctrlDim]

        contControl: optional corresponding continuous control state
        (before casting the attention over the words).
        [batchSize, ctrlDim]

    Returns:
        the new control state
        [batchSize, ctrlDim]

        the continuous (pre-attention) control
        [batchSize, ctrlDim]
    '''
    def control(self, controlInput, inWords, outWords, questionLengths,
            control, contControl = None, name = "", reuse = None):

        with tf.variable_scope("control" + name, reuse = reuse):
            # dim tracks the (growing) feature dimension through the optional
            # concatenations below; it must be kept in sync with the tensors.
            dim = config.ctrlDim

            ## Step 1: compute "continuous" control state given previous control and question.
            # control inputs: question and previous control
            newContControl = controlInput
            if config.controlFeedPrev:
                # feed either the previous post-attention control or the previous
                # continuous (pre-attention) control
                newContControl = control if config.controlFeedPrevAtt else contControl
                if config.controlFeedInputs:
                    newContControl = tf.concat([newContControl, controlInput], axis = -1)
                    dim += config.ctrlDim

                # merge inputs together
                newContControl = ops.linear(newContControl, dim, config.ctrlDim,
                    act = config.controlContAct, name = "contControl")
                dim = config.ctrlDim

            ## Step 2: compute attention distribution over words and sum them up accordingly.
            # compute interactions with question words (elementwise product,
            # broadcast over the question-length axis)
            interactions = tf.expand_dims(newContControl, axis = 1) * inWords

            # optionally concatenate words
            if config.controlConcatWords:
                interactions = tf.concat([interactions, inWords], axis = -1)
                dim += config.ctrlDim

            # optional projection
            if config.controlProj:
                interactions = ops.linear(interactions, dim, config.ctrlDim,
                    act = config.controlProjAct)
                dim = config.ctrlDim

            # compute attention distribution over words and summarize them accordingly
            logits = ops.inter2logits(interactions, dim)
            # self.interL = (interW, interb)

            # if config.controlCoverage:
            #     logits += coverageBias * coverage

            # mask out positions beyond each question's length before softmax
            attention = tf.nn.softmax(ops.expMask(logits, questionLengths))
            self.attentions["question"].append(attention)

            # if config.controlCoverage:
            #     coverage += attention # Add logits instead?

            # weighted sum of the words by the attention distribution
            newControl = ops.att2Smry(attention, outWords)

            # ablation: use continuous control (pre-attention) instead
            if config.controlContinuous:
                newControl = newContControl

            return newControl, newContControl
+
|
| 189 |
+
'''
|
| 190 |
+
The read unit extracts relevant information from the knowledge base given the
|
| 191 |
+
cell's memory and control states. It computes attention distribution over
|
| 192 |
+
the knowledge base by comparing it first to the memory and then to the control.
|
| 193 |
+
Finally, it uses the attention distribution to sum up the knowledge base accordingly,
|
| 194 |
+
resulting in an extraction of relevant information.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
knowledge base: representation of the knowledge base (image).
|
| 198 |
+
[batchSize, kbSize (Height * Width), memDim]
|
| 199 |
+
|
| 200 |
+
memory: the cell's memory state
|
| 201 |
+
[batchSize, memDim]
|
| 202 |
+
|
| 203 |
+
control: the cell's control state
|
| 204 |
+
[batchSize, ctrlDim]
|
| 205 |
+
|
| 206 |
+
Returns the information extracted.
|
| 207 |
+
[batchSize, memDim]
|
| 208 |
+
'''
|
| 209 |
+
def read(self, knowledgeBase, memory, control, name = "", reuse = None):
|
| 210 |
+
with tf.variable_scope("read" + name, reuse = reuse):
|
| 211 |
+
dim = config.memDim
|
| 212 |
+
|
| 213 |
+
## memory dropout
|
| 214 |
+
if config.memoryVariationalDropout:
|
| 215 |
+
memory = ops.applyVarDpMask(memory, self.memDpMask, self.dropouts["memory"])
|
| 216 |
+
else:
|
| 217 |
+
memory = tf.nn.dropout(memory, self.dropouts["memory"])
|
| 218 |
+
|
| 219 |
+
## Step 1: knowledge base / memory interactions
|
| 220 |
+
# parameters for knowledge base and memory projection
|
| 221 |
+
proj = None
|
| 222 |
+
if config.readProjInputs:
|
| 223 |
+
proj = {"dim": config.attDim, "shared": config.readProjShared, "dropout": self.dropouts["read"] }
|
| 224 |
+
dim = config.attDim
|
| 225 |
+
|
| 226 |
+
# parameters for concatenating knowledge base elements
|
| 227 |
+
concat = {"x": config.readMemConcatKB, "proj": config.readMemConcatProj}
|
| 228 |
+
|
| 229 |
+
# compute interactions between knowledge base and memory
|
| 230 |
+
interactions, interDim = ops.mul(x = knowledgeBase, y = memory, dim = config.memDim,
|
| 231 |
+
proj = proj, concat = concat, interMod = config.readMemAttType, name = "memInter")
|
| 232 |
+
|
| 233 |
+
projectedKB = proj.get("x") if proj else None
|
| 234 |
+
|
| 235 |
+
# project memory interactions back to hidden dimension
|
| 236 |
+
if config.readMemProj:
|
| 237 |
+
interactions = ops.linear(interactions, interDim, dim, act = config.readMemAct,
|
| 238 |
+
name = "memKbProj")
|
| 239 |
+
else:
|
| 240 |
+
dim = interDim
|
| 241 |
+
|
| 242 |
+
## Step 2: compute interactions with control
|
| 243 |
+
if config.readCtrl:
|
| 244 |
+
# compute interactions with control
|
| 245 |
+
if config.ctrlDim != dim:
|
| 246 |
+
control = ops.linear(control, ctrlDim, dim, name = "ctrlProj")
|
| 247 |
+
|
| 248 |
+
interactions, interDim = ops.mul(interactions, control, dim,
|
| 249 |
+
interMod = config.readCtrlAttType, concat = {"x": config.readCtrlConcatInter},
|
| 250 |
+
name = "ctrlInter")
|
| 251 |
+
|
| 252 |
+
# optionally concatenate knowledge base elements
|
| 253 |
+
if config.readCtrlConcatKB:
|
| 254 |
+
if config.readCtrlConcatProj:
|
| 255 |
+
addedInp, addedDim = projectedKB, config.attDim
|
| 256 |
+
else:
|
| 257 |
+
addedInp, addedDim = knowledgeBase, config.memDim
|
| 258 |
+
interactions = tf.concat([interactions, addedInp], axis = -1)
|
| 259 |
+
dim += addedDim
|
| 260 |
+
|
| 261 |
+
# optional nonlinearity
|
| 262 |
+
interactions = ops.activations[config.readCtrlAct](interactions)
|
| 263 |
+
|
| 264 |
+
## Step 3: sum attentions up over the knowledge base
|
| 265 |
+
# transform vectors to attention distribution
|
| 266 |
+
attention = ops.inter2att(interactions, dim, dropout = self.dropouts["read"])
|
| 267 |
+
|
| 268 |
+
self.attentions["kb"].append(attention)
|
| 269 |
+
|
| 270 |
+
# optionally use projected knowledge base instead of original
|
| 271 |
+
if config.readSmryKBProj:
|
| 272 |
+
knowledgeBase = projectedKB
|
| 273 |
+
|
| 274 |
+
# sum up the knowledge base according to the distribution
|
| 275 |
+
information = ops.att2Smry(attention, knowledgeBase)
|
| 276 |
+
|
| 277 |
+
return information
|
| 278 |
+
|
| 279 |
+
    '''
    The write unit integrates newly retrieved information (from the read unit),
    with the cell's previous memory hidden state, resulting in a new memory value.
    The unit optionally supports:
    1. Self-attention to previous control / memory states, in order to consider previous steps
    in the reasoning process.
    2. Gating between the new memory and previous memory states, to allow dynamic adjustment
    of the reasoning process length.

    Args:
        memory: the cell's memory state
        [batchSize, memDim]

        info: the information to integrate with the memory
        [batchSize, memDim]

        control: the cell's control state
        [batchSize, ctrlDim]

        contControl: optional corresponding continuous control state
        (before casting the attention over the words).
        [batchSize, ctrlDim]

    Return the new memory
    [batchSize, memDim]
    '''
    def write(self, memory, info, control, contControl = None, name = "", reuse = None):
        with tf.variable_scope("write" + name, reuse = reuse):

            # optionally project info
            if config.writeInfoProj:
                info = ops.linear(info, config.memDim, config.memDim, name = "info")

            # optional info nonlinearity
            info = ops.activations[config.writeInfoAct](info)

            # compute self-attention vector based on previous controls and memories
            if config.writeSelfAtt:
                selfControl = control
                if config.writeSelfAttMod == "CONT":
                    selfControl = contControl
                # elif config.writeSelfAttMod == "POST":
                #     selfControl = postControl
                selfControl = ops.linear(selfControl, config.ctrlDim, config.ctrlDim, name = "ctrlProj")

                # compare the current control against the history of controls
                # accumulated in self.controls (shape grows by one per step)
                interactions = self.controls * tf.expand_dims(selfControl, axis = 1)

                # if config.selfAttShareInter:
                #     selfAttlogits = self.linearP(selfAttInter, config.encDim, 1, self.interL[0], self.interL[1], name = "modSelfAttInter")
                attention = ops.inter2att(interactions, config.ctrlDim, name = "selfAttention")
                self.attentions["self"].append(attention)
                # summarize the history of memories by the control-similarity attention
                selfSmry = ops.att2Smry(attention, self.memories)

            # get write unit inputs: previous memory, the new info, optionally self-attention / control
            # dim tracks the concatenated feature dimension below
            newMemory, dim = memory, config.memDim
            if config.writeInputs == "INFO":
                newMemory = info
            elif config.writeInputs == "SUM":
                newMemory += info
            elif config.writeInputs == "BOTH":
                newMemory, dim = ops.concat(newMemory, info, dim, mul = config.writeConcatMul)
            # else: MEM

            if config.writeSelfAtt:
                newMemory = tf.concat([newMemory, selfSmry], axis = -1)
                dim += config.memDim

            if config.writeMergeCtrl:
                newMemory = tf.concat([newMemory, control], axis = -1)
                # NOTE(review): adds memDim although control is ctrlDim-sized --
                # correct only when ctrlDim == memDim; confirm config invariant
                dim += config.memDim

            # project memory back to memory dimension
            if config.writeMemProj or (dim != config.memDim):
                newMemory = ops.linear(newMemory, dim, config.memDim, name = "newMemory")

            # optional memory nonlinearity
            newMemory = ops.activations[config.writeMemAct](newMemory)

            # write unit gate: interpolate between the new and previous memory,
            # with the gate computed from the control state
            if config.writeGate:
                gateDim = config.memDim
                if config.writeGateShared:
                    gateDim = 1

                z = tf.sigmoid(ops.linear(control, config.ctrlDim, gateDim, name = "gate", bias = config.writeGateBias))

                self.attentions["gate"].append(z)

                newMemory = newMemory * z + memory * (1 - z)

            # optional batch normalization
            if config.memoryBN:
                newMemory = tf.contrib.layers.batch_norm(newMemory, decay = config.bnDecay,
                    center = config.bnCenter, scale = config.bnScale,
                    is_training = self.train, updates_collections = None)

        return newMemory
+
|
| 377 |
+
def memAutoEnc(newMemory, info, control, name = "", reuse = None):
|
| 378 |
+
with tf.variable_scope("memAutoEnc" + name, reuse = reuse):
|
| 379 |
+
# inputs to auto encoder
|
| 380 |
+
features = info if config.autoEncMemInputs == "INFO" else newMemory
|
| 381 |
+
features = ops.linear(features, config.memDim, config.ctrlDim,
|
| 382 |
+
act = config.autoEncMemAct, name = "aeMem")
|
| 383 |
+
|
| 384 |
+
# reconstruct control
|
| 385 |
+
if config.autoEncMemLoss == "CONT":
|
| 386 |
+
loss = tf.reduce_mean(tf.squared_difference(control, features))
|
| 387 |
+
else:
|
| 388 |
+
interactions, dim = ops.mul(self.questionCntxWords, features, config.ctrlDim,
|
| 389 |
+
concat = {"x": config.autoEncMemCnct}, mulBias = config.mulBias, name = "aeMem")
|
| 390 |
+
|
| 391 |
+
logits = ops.inter2logits(interactions, dim)
|
| 392 |
+
logits = self.expMask(logits, self.questionLengths)
|
| 393 |
+
|
| 394 |
+
# reconstruct word attentions
|
| 395 |
+
if config.autoEncMemLoss == "PROB":
|
| 396 |
+
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
|
| 397 |
+
labels = self.attentions["question"][-1], logits = logits))
|
| 398 |
+
|
| 399 |
+
# reconstruct control through words attentions
|
| 400 |
+
else:
|
| 401 |
+
attention = tf.nn.softmax(logits)
|
| 402 |
+
summary = ops.att2Smry(attention, self.questionCntxWords)
|
| 403 |
+
loss = tf.reduce_mean(tf.squared_difference(control, summary))
|
| 404 |
+
|
| 405 |
+
return loss
|
| 406 |
+
|
| 407 |
+
    '''
    Call the cell to get new control and memory states.

    Args:
        inputs: in the current implementation the cell don't get recurrent inputs
        every iteration (argument for comparability with rnn interface).

        state: the cell current state (control, memory)
        MACCellTuple([batchSize, ctrlDim],[batchSize, memDim])

    Returns the new state -- the new memory and control values.
    MACCellTuple([batchSize, ctrlDim],[batchSize, memDim])
    '''
    def __call__(self, inputs, state, scope = None):
        scope = scope or type(self).__name__
        with tf.variable_scope(scope, reuse = self.reuse): # as tfscope
            control = state.control
            memory = state.memory

            # cell sharing: reuse variables across iterations unless the config
            # asks for per-iteration (unshared) parameters
            # NOTE(review): self.iteration is never set inside this class --
            # presumably assigned by the caller before each step; confirm
            inputName = "qInput"
            inputNameU = "qInputU"
            inputReuseU = inputReuse = (self.iteration > 0)
            if config.controlInputUnshared:
                inputNameU = "qInput%d" % self.iteration
                inputReuseU = None

            cellName = ""
            cellReuse = (self.iteration > 0)
            if config.unsharedCells:
                cellName = str(self.iteration)
                cellReuse = None

            ## control unit
            # prepare question input to control
            controlInput = ops.linear(self.vecQuestions, config.ctrlDim, config.ctrlDim,
                name = inputName, reuse = inputReuse)

            controlInput = ops.activations[config.controlInputAct](controlInput)

            controlInput = ops.linear(controlInput, config.ctrlDim, config.ctrlDim,
                name = inputNameU, reuse = inputReuseU)

            newControl, self.contControl = self.control(controlInput, self.inWords, self.outWords,
                self.questionLengths, control, self.contControl, name = cellName, reuse = cellReuse)

            # read unit
            # ablation: use whole question as control
            if config.controlWholeQ:
                newControl = self.vecQuestions
                # ops.linear(self.vecQuestions, config.ctrlDim, projDim, name = "qMod")

            info = self.read(self.knowledgeBase, memory, newControl, name = cellName, reuse = cellReuse)

            if config.writeDropout < 1.0:
                # write unit
                info = tf.nn.dropout(info, self.dropouts["write"])

            newMemory = self.write(memory, info, newControl, self.contControl, name = cellName, reuse = cellReuse)

            # add auto encoder loss for memory
            # if config.autoEncMem:
            #     self.autoEncLosses["memory"] += memAutoEnc(newMemory, info, newControl)

            # append as standard list?
            # accumulate the per-step states (used by the write unit's
            # self-attention over previous controls / memories)
            self.controls = tf.concat([self.controls, tf.expand_dims(newControl, axis = 1)], axis = 1)
            self.memories = tf.concat([self.memories, tf.expand_dims(newMemory, axis = 1)], axis = 1)
            self.infos = tf.concat([self.infos, tf.expand_dims(info, axis = 1)], axis = 1)

            # self.contControls = tf.concat([self.contControls, tf.expand_dims(contControl, axis = 1)], axis = 1)
            # self.postControls = tf.concat([self.controls, tf.expand_dims(postControls, axis = 1)], axis = 1)

            newState = MACCellTuple(newControl, newMemory)
            return self.none, newState
+
|
| 482 |
+
'''
|
| 483 |
+
Initializes the a hidden state to based on the value of the initType:
|
| 484 |
+
"PRM" for parametric initialization
|
| 485 |
+
"ZERO" for zero initialization
|
| 486 |
+
"Q" to initialize to question vectors.
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
name: the state variable name.
|
| 490 |
+
dim: the dimension of the state.
|
| 491 |
+
initType: the type of the initialization
|
| 492 |
+
batchSize: the batch size
|
| 493 |
+
|
| 494 |
+
Returns the initialized hidden state.
|
| 495 |
+
'''
|
| 496 |
+
def initState(self, name, dim, initType, batchSize):
|
| 497 |
+
if initType == "PRM":
|
| 498 |
+
prm = tf.get_variable(name, shape = (dim, ),
|
| 499 |
+
initializer = tf.random_normal_initializer())
|
| 500 |
+
initState = tf.tile(tf.expand_dims(prm, axis = 0), [batchSize, 1])
|
| 501 |
+
elif initType == "ZERO":
|
| 502 |
+
initState = tf.zeros((batchSize, dim), dtype = tf.float32)
|
| 503 |
+
else: # "Q"
|
| 504 |
+
initState = self.vecQuestions
|
| 505 |
+
return initState
|
| 506 |
+
|
| 507 |
+
'''
|
| 508 |
+
Add a parametric null word to the questions.
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
words: the words to add a null word to.
|
| 512 |
+
[batchSize, questionLentgth]
|
| 513 |
+
|
| 514 |
+
lengths: question lengths.
|
| 515 |
+
[batchSize]
|
| 516 |
+
|
| 517 |
+
Returns the updated word sequence and lengths.
|
| 518 |
+
'''
|
| 519 |
+
def addNullWord(words, lengths):
|
| 520 |
+
nullWord = tf.get_variable("zeroWord", shape = (1 , config.ctrlDim), initializer = tf.random_normal_initializer())
|
| 521 |
+
nullWord = tf.tile(tf.expand_dims(nullWord, axis = 0), [self.batchSize, 1, 1])
|
| 522 |
+
words = tf.concat([nullWord, words], axis = 1)
|
| 523 |
+
lengths += 1
|
| 524 |
+
return words, lengths
|
| 525 |
+
|
| 526 |
+
'''
|
| 527 |
+
Initializes the cell internal state (currently it's stateful). In particular,
|
| 528 |
+
1. Data-structures (lists of attention maps and accumulated losses).
|
| 529 |
+
2. The memory and control states.
|
| 530 |
+
3. The knowledge base (optionally merging it with the question vectors)
|
| 531 |
+
4. The question words used by the cell (either the original word embeddings, or the
|
| 532 |
+
encoder outputs, with optional projection).
|
| 533 |
+
|
| 534 |
+
Args:
|
| 535 |
+
batchSize: the batch size
|
| 536 |
+
|
| 537 |
+
Returns the initial cell state.
|
| 538 |
+
'''
|
| 539 |
+
def zero_state(self, batchSize, dtype = tf.float32):
|
| 540 |
+
## initialize data-structures
|
| 541 |
+
self.attentions = {"kb": [], "question": [], "self": [], "gate": []}
|
| 542 |
+
self.autoEncLosses = {"control": tf.constant(0.0), "memory": tf.constant(0.0)}
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
## initialize state
|
| 546 |
+
initialControl = self.initState("initCtrl", config.ctrlDim, config.initCtrl, batchSize)
|
| 547 |
+
initialMemory = self.initState("initMem", config.memDim, config.initMem, batchSize)
|
| 548 |
+
|
| 549 |
+
self.controls = tf.expand_dims(initialControl, axis = 1)
|
| 550 |
+
self.memories = tf.expand_dims(initialMemory, axis = 1)
|
| 551 |
+
self.infos = tf.expand_dims(initialMemory, axis = 1)
|
| 552 |
+
|
| 553 |
+
self.contControl = initialControl
|
| 554 |
+
# self.contControls = tf.expand_dims(initialControl, axis = 1)
|
| 555 |
+
# self.postControls = tf.expand_dims(initialControl, axis = 1)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
## initialize knowledge base
|
| 559 |
+
# optionally merge question into knowledge base representation
|
| 560 |
+
if config.initKBwithQ != "NON":
|
| 561 |
+
iVecQuestions = ops.linear(self.vecQuestions, config.ctrlDim, config.memDim, name = "questions")
|
| 562 |
+
|
| 563 |
+
concatMul = (config.initKBwithQ == "MUL")
|
| 564 |
+
cnct, dim = ops.concat(self.knowledgeBase, iVecQuestions, config.memDim, mul = concatMul, expandY = True)
|
| 565 |
+
self.knowledgeBase = ops.linear(cnct, dim, config.memDim, name = "initKB")
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
## initialize question words
|
| 569 |
+
# choose question words to work with (original embeddings or encoder outputs)
|
| 570 |
+
words = self.questionCntxWords if config.controlContextual else self.questionWords
|
| 571 |
+
|
| 572 |
+
# optionally add parametric "null" word in the to all questions
|
| 573 |
+
if config.addNullWord:
|
| 574 |
+
words, questionLengths = self.addNullWord(words, questionLengths)
|
| 575 |
+
|
| 576 |
+
# project words
|
| 577 |
+
self.inWords = self.outWords = words
|
| 578 |
+
if config.controlInWordsProj or config.controlOutWordsProj:
|
| 579 |
+
pWords = ops.linear(words, config.ctrlDim, config.ctrlDim, name = "wordsProj")
|
| 580 |
+
self.inWords = pWords if config.controlInWordsProj else words
|
| 581 |
+
self.outWords = pWords if config.controlOutWordsProj else words
|
| 582 |
+
|
| 583 |
+
# if config.controlCoverage:
|
| 584 |
+
# self.coverage = tf.zeros((batchSize, tf.shape(words)[1]), dtype = tf.float32)
|
| 585 |
+
# self.coverageBias = tf.get_variable("coverageBias", shape = (),
|
| 586 |
+
# initializer = config.controlCoverageBias)
|
| 587 |
+
|
| 588 |
+
## initialize memory variational dropout mask
|
| 589 |
+
if config.memoryVariationalDropout:
|
| 590 |
+
self.memDpMask = ops.generateVarDpMask((batchSize, config.memDim), self.dropouts["memory"])
|
| 591 |
+
|
| 592 |
+
return MACCellTuple(initialControl, initialMemory)
|