# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""JAX implementation of baseline processor networks."""

import abc
from typing import Any, Callable, List, Optional, Tuple

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

# Aliases for array values and generic callables used throughout this module.
_Array = chex.Array
_Fn = Callable[..., Any]

# Large finite constant used to mask out non-neighbour entries before a
# max-reduction (see PGN.__call__).
BIG_NUMBER = 1e6

# Suffix appended to every processor module's name in Processor.__init__,
# so processor parameters can be recognised by name.
PROCESSOR_TAG = 'clrs_processor'
class Processor(hk.Module):
  """Processor abstract base class."""

  def __init__(self, name: str):
    # Tag the module name so processor parameters are identifiable by name.
    if not name.endswith(PROCESSOR_TAG):
      name = name + '_' + PROCESSOR_TAG
    super().__init__(name=name)

  @abc.abstractmethod
  def __call__(
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **kwargs,
  ) -> Tuple[_Array, Optional[_Array]]:
    """Processor inference step.

    Args:
      node_fts: Node features.
      edge_fts: Edge features.
      graph_fts: Graph features.
      adj_mat: Graph adjacency matrix.
      hidden: Hidden features.
      **kwargs: Extra kwargs.

    Returns:
      Output of processor inference step as a 2-tuple of (node, edge)
      embeddings. The edge embeddings can be None.
    """
    pass

  def inf_bias(self):
    # Subclasses override to request an inference bias on node pointers.
    return False

  def inf_bias_edge(self):
    # Subclasses override to request an inference bias on edge pointers.
    return False
class GAT(Processor):
  """Graph Attention Network (Velickovic et al., ICLR 2018)."""

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gat_aggr',
  ):
    """Initializes the GAT processor.

    Args:
      out_size: Width of the output node embeddings, shared across heads.
      nb_heads: Number of attention heads; must divide `out_size`.
      activation: Optional activation applied to the output embeddings.
      residual: Whether to add a learned linear skip connection.
      use_ln: Whether to layer-normalise the output embeddings.
      name: Name of the Haiku module.
    """
    super().__init__(name=name)
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """GAT inference step."""
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    # Node features concatenated with the recurrent hidden state.
    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    # Additive mask: 0 where an edge exists, -1e9 where it does not, so
    # non-edges vanish under the softmax below.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))  # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    # One scalar attention contribution per head for source nodes, target
    # nodes, edges and the graph, respectively.
    a_1 = hk.Linear(self.nb_heads)
    a_2 = hk.Linear(self.nb_heads)
    a_e = hk.Linear(self.nb_heads)
    a_g = hk.Linear(self.nb_heads)

    values = m(z)  # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))  # [B, H, N, F]

    att_1 = jnp.expand_dims(a_1(z), axis=-1)
    att_2 = jnp.expand_dims(a_2(z), axis=-1)
    att_e = a_e(edge_fts)
    att_g = jnp.expand_dims(a_g(graph_fts), axis=-1)

    # Attention logit for pair (i, j) decomposes additively into the four
    # broadcast contributions below (GAT-style linear attention).
    logits = (
        jnp.transpose(att_1, (0, 2, 1, 3)) +  # + [B, H, N, 1]
        jnp.transpose(att_2, (0, 2, 3, 1)) +  # + [B, H, 1, N]
        jnp.transpose(att_e, (0, 3, 1, 2)) +  # + [B, H, N, N]
        jnp.expand_dims(att_g, axis=-1)       # + [B, H, 1, 1]
    )                                         # = [B, H, N, N]
    coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    # Second element is the (absent) edge embedding.
    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
class GATFull(GAT):
  """Graph Attention Network with full adjacency matrix."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    """Runs GAT attention over a fully-connected graph."""
    # Ignore the provided adjacency: every pair of nodes is connected.
    full_adj = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, full_adj, hidden)
class GATv2(Processor):
  """Graph Attention Network v2 (Brody et al., ICLR 2022)."""

  def __init__(
      self,
      out_size: int,
      nb_heads: int,
      mid_size: Optional[int] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      residual: bool = True,
      use_ln: bool = False,
      name: str = 'gatv2_aggr',
  ):
    """Initializes the GATv2 processor.

    Args:
      out_size: Width of the output node embeddings, shared across heads.
      nb_heads: Number of attention heads; must divide both `out_size`
        and `mid_size`.
      mid_size: Width of the pre-attention representation; defaults to
        `out_size`.
      activation: Optional activation applied to the output embeddings.
      residual: Whether to add a learned linear skip connection.
      use_ln: Whether to layer-normalise the output embeddings.
      name: Name of the Haiku module.
    """
    super().__init__(name=name)
    if mid_size is None:
      self.mid_size = out_size
    else:
      self.mid_size = mid_size
    self.out_size = out_size
    self.nb_heads = nb_heads
    if out_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the width!')
    self.head_size = out_size // nb_heads
    if self.mid_size % nb_heads != 0:
      raise ValueError('The number of attention heads must divide the message!')
    self.mid_head_size = self.mid_size // nb_heads
    self.activation = activation
    self.residual = residual
    self.use_ln = use_ln

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """GATv2 inference step."""
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    # Node features concatenated with the recurrent hidden state.
    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m = hk.Linear(self.out_size)
    skip = hk.Linear(self.out_size)

    # Additive mask: 0 where an edge exists, -1e9 where it does not, so
    # non-edges vanish under the softmax below.
    bias_mat = (adj_mat - 1.0) * 1e9
    bias_mat = jnp.tile(bias_mat[..., None],
                        (1, 1, 1, self.nb_heads))  # [B, N, N, H]
    bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2))  # [B, H, N, N]

    # Pre-attention projections for source nodes, target nodes, edges and
    # the graph; summed (broadcast) before the per-head nonlinear scoring.
    w_1 = hk.Linear(self.mid_size)
    w_2 = hk.Linear(self.mid_size)
    w_e = hk.Linear(self.mid_size)
    w_g = hk.Linear(self.mid_size)

    # One scalar scoring layer per head (applied after the leaky ReLU —
    # the key difference from GAT, where the projection precedes it).
    a_heads = []
    for _ in range(self.nb_heads):
      a_heads.append(hk.Linear(1))

    values = m(z)  # [B, N, H*F]
    values = jnp.reshape(
        values,
        values.shape[:-1] + (self.nb_heads, self.head_size))  # [B, N, H, F]
    values = jnp.transpose(values, (0, 2, 1, 3))  # [B, H, N, F]

    pre_att_1 = w_1(z)
    pre_att_2 = w_2(z)
    pre_att_e = w_e(edge_fts)
    pre_att_g = w_g(graph_fts)

    pre_att = (
        jnp.expand_dims(pre_att_1, axis=1) +     # + [B, 1, N, H*F]
        jnp.expand_dims(pre_att_2, axis=2) +     # + [B, N, 1, H*F]
        pre_att_e +                              # + [B, N, N, H*F]
        jnp.expand_dims(pre_att_g, axis=(1, 2))  # + [B, 1, 1, H*F]
    )                                            # = [B, N, N, H*F]

    pre_att = jnp.reshape(
        pre_att,
        pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)
    )  # [B, N, N, H, F]

    pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4))  # [B, H, N, N, F]

    # This part is not very efficient, but we agree to keep it this way to
    # enhance readability, assuming `nb_heads` will not be large.
    logit_heads = []
    for head in range(self.nb_heads):
      logit_heads.append(
          jnp.squeeze(
              a_heads[head](jax.nn.leaky_relu(pre_att[:, head])),
              axis=-1)
      )  # [B, N, N]

    logits = jnp.stack(logit_heads, axis=1)  # [B, H, N, N]

    coefs = jax.nn.softmax(logits + bias_mat, axis=-1)
    ret = jnp.matmul(coefs, values)  # [B, H, N, F]
    ret = jnp.transpose(ret, (0, 2, 1, 3))  # [B, N, H, F]
    ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,))  # [B, N, H*F]

    if self.residual:
      ret += skip(z)

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    # Second element is the (absent) edge embedding.
    return ret, None  # pytype: disable=bad-return-type  # numpy-scalars
class GATv2Full(GATv2):
  """Graph Attention Network v2 with full adjacency matrix."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    """Runs GATv2 attention over a fully-connected graph."""
    # Ignore the provided adjacency: every pair of nodes is connected.
    dense_adj = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, dense_adj, hidden)
def get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts):
  """Triplet messages, as done by Dudzik and Velickovic (2022)."""
  # One linear projection per role: the three node roles of a triplet
  # (i, j, k), the three incident edges, and the graph itself. Module
  # creation order matches call order, so Haiku parameter names are stable.
  tri_1 = hk.Linear(nb_triplet_fts)(z)
  tri_2 = hk.Linear(nb_triplet_fts)(z)
  tri_3 = hk.Linear(nb_triplet_fts)(z)
  tri_e_1 = hk.Linear(nb_triplet_fts)(edge_fts)
  tri_e_2 = hk.Linear(nb_triplet_fts)(edge_fts)
  tri_e_3 = hk.Linear(nb_triplet_fts)(edge_fts)
  tri_g = hk.Linear(nb_triplet_fts)(graph_fts)

  # Broadcast every projection against the (B, N, N, N, H) triplet tensor
  # and accumulate the sum.
  broadcast_terms = (
      jnp.expand_dims(tri_1, axis=(2, 3)),     # (B, N, 1, 1, H)
      jnp.expand_dims(tri_2, axis=(1, 3)),     # (B, 1, N, 1, H)
      jnp.expand_dims(tri_3, axis=(1, 2)),     # (B, 1, 1, N, H)
      jnp.expand_dims(tri_e_1, axis=3),        # (B, N, N, 1, H)
      jnp.expand_dims(tri_e_2, axis=2),        # (B, N, 1, N, H)
      jnp.expand_dims(tri_e_3, axis=1),        # (B, 1, N, N, H)
      jnp.expand_dims(tri_g, axis=(1, 2, 3)),  # (B, 1, 1, 1, H)
  )
  total = broadcast_terms[0]
  for term in broadcast_terms[1:]:
    total = total + term
  return total  # (B, N, N, N, H)
class PGN(Processor):
  """Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""

  def __init__(
      self,
      out_size: int,
      mid_size: Optional[int] = None,
      mid_act: Optional[_Fn] = None,
      activation: Optional[_Fn] = jax.nn.relu,
      reduction: _Fn = jnp.max,
      msgs_mlp_sizes: Optional[List[int]] = None,
      use_ln: bool = False,
      use_triplets: bool = False,
      nb_triplet_fts: int = 8,
      gated: bool = False,
      name: str = 'mpnn_aggr',
  ):
    """Initializes the PGN processor.

    Args:
      out_size: Width of the output node embeddings.
      mid_size: Width of the messages; defaults to `out_size`.
      mid_act: Optional activation applied to the messages.
      activation: Optional activation applied to the output embeddings.
      reduction: Permutation-invariant message aggregator (e.g. jnp.max,
        jnp.mean); max and mean receive adjacency-aware handling below.
      msgs_mlp_sizes: If given, layer sizes of an MLP applied to messages.
      use_ln: Whether to layer-normalise the output embeddings.
      use_triplets: Whether to also compute triplet (edge) messages.
      nb_triplet_fts: Number of triplet features.
      gated: Whether to gate the update against the previous hidden state.
      name: Name of the Haiku module.
    """
    super().__init__(name=name)
    if mid_size is None:
      self.mid_size = out_size
    else:
      self.mid_size = mid_size
    self.out_size = out_size
    self.mid_act = mid_act
    self.activation = activation
    self.reduction = reduction
    self._msgs_mlp_sizes = msgs_mlp_sizes
    self.use_ln = use_ln
    self.use_triplets = use_triplets
    self.nb_triplet_fts = nb_triplet_fts
    self.gated = gated

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """MPNN inference step."""
    b, n, _ = node_fts.shape
    assert edge_fts.shape[:-1] == (b, n, n)
    assert graph_fts.shape[:-1] == (b,)
    assert adj_mat.shape == (b, n, n)

    # Node features concatenated with the recurrent hidden state.
    z = jnp.concatenate([node_fts, hidden], axis=-1)
    m_1 = hk.Linear(self.mid_size)
    m_2 = hk.Linear(self.mid_size)
    m_e = hk.Linear(self.mid_size)
    m_g = hk.Linear(self.mid_size)

    o1 = hk.Linear(self.out_size)
    o2 = hk.Linear(self.out_size)

    msg_1 = m_1(z)
    msg_2 = m_2(z)
    msg_e = m_e(edge_fts)
    msg_g = m_g(graph_fts)

    tri_msgs = None

    if self.use_triplets:
      # Triplet messages, as done by Dudzik and Velickovic (2022)
      triplets = get_triplet_msgs(z, edge_fts, graph_fts, self.nb_triplet_fts)

      o3 = hk.Linear(self.out_size)
      # Reduce over the first node axis to obtain edge-level messages.
      tri_msgs = o3(jnp.max(triplets, axis=1))  # (B, N, N, H)

      if self.activation is not None:
        tri_msgs = self.activation(tri_msgs)

    # Pairwise message for (i, j): source + target + edge + graph terms.
    msgs = (
        jnp.expand_dims(msg_1, axis=1) + jnp.expand_dims(msg_2, axis=2) +
        msg_e + jnp.expand_dims(msg_g, axis=(1, 2)))

    if self._msgs_mlp_sizes is not None:
      msgs = hk.nets.MLP(self._msgs_mlp_sizes)(jax.nn.relu(msgs))

    if self.mid_act is not None:
      msgs = self.mid_act(msgs)

    if self.reduction == jnp.mean:
      # Average only over actual neighbours given by the adjacency matrix.
      msgs = jnp.sum(msgs * jnp.expand_dims(adj_mat, -1), axis=1)
      msgs = msgs / jnp.sum(adj_mat, axis=-1, keepdims=True)
    elif self.reduction == jnp.max:
      # Mask non-neighbours with a large negative value before the max.
      maxarg = jnp.where(jnp.expand_dims(adj_mat, -1),
                         msgs,
                         -BIG_NUMBER)
      msgs = jnp.max(maxarg, axis=1)
    else:
      # Generic reduction: zero out non-neighbours and reduce over them.
      msgs = self.reduction(msgs * jnp.expand_dims(adj_mat, -1), axis=1)

    h_1 = o1(z)
    h_2 = o2(msgs)

    ret = h_1 + h_2

    if self.activation is not None:
      ret = self.activation(ret)

    if self.use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    if self.gated:
      # Learned sigmoid gate interpolating between the new update and the
      # previous hidden state; the -3 bias init keeps the gate mostly
      # closed (favouring `hidden`) early in training.
      gate1 = hk.Linear(self.out_size)
      gate2 = hk.Linear(self.out_size)
      gate3 = hk.Linear(self.out_size, b_init=hk.initializers.Constant(-3))
      gate = jax.nn.sigmoid(gate3(jax.nn.relu(gate1(z) + gate2(msgs))))
      ret = ret * gate + hidden * (1-gate)

    return ret, tri_msgs  # pytype: disable=bad-return-type  # numpy-scalars
class DeepSets(PGN):
  """Deep Sets (Zaheer et al., NeurIPS 2017)."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    """Runs PGN message passing restricted to self-loops."""
    assert adj_mat.ndim == 3
    # Replace the adjacency with the identity so each node aggregates only
    # its own message, reducing PGN to a Deep Sets model.
    diag_adj = jnp.ones_like(adj_mat) * jnp.eye(adj_mat.shape[-1])
    return super().__call__(node_fts, edge_fts, graph_fts, diag_adj, hidden)
class MPNN(PGN):
  """Message-Passing Neural Network (Gilmer et al., ICML 2017)."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    """Runs PGN message passing over a fully-connected graph."""
    # Ignore the provided adjacency: every pair of nodes exchanges messages.
    dense_adj = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, dense_adj, hidden)
class PGNMask(PGN):
  """Masked Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""

  def inf_bias(self):
    """This processor requests an inference bias on nodes."""
    return True

  def inf_bias_edge(self):
    """This processor requests an inference bias on edges."""
    return True
class MemNetMasked(Processor):
  """Implementation of End-to-End Memory Networks.

  Inspired by the description in https://arxiv.org/abs/1503.08895.
  """

  def __init__(
      self,
      vocab_size: int,
      sentence_size: int,
      linear_output_size: int,
      embedding_size: int = 16,
      memory_size: Optional[int] = 128,
      num_hops: int = 1,
      nonlin: Callable[[Any], Any] = jax.nn.relu,
      apply_embeddings: bool = True,
      init_func: hk.initializers.Initializer = jnp.zeros,
      use_ln: bool = False,
      name: str = 'memnet') -> None:
    """Constructor.

    Args:
      vocab_size: the number of words in the dictionary (each story, query
        and answer contain symbols coming from this dictionary).
      sentence_size: the dimensionality of each memory.
      linear_output_size: the dimensionality of the output of the last layer
        of the model.
      embedding_size: the dimensionality of the latent space to where all
        memories are projected.
      memory_size: the number of memories provided.
      num_hops: the number of layers in the model.
      nonlin: non-linear transformation applied at the end of each layer.
      apply_embeddings: flag whether to apply embeddings.
      init_func: initialization function for the biases.
      use_ln: whether to use layer normalisation in the model.
      name: the name of the model.
    """
    super().__init__(name=name)
    self._vocab_size = vocab_size
    self._embedding_size = embedding_size
    self._sentence_size = sentence_size
    self._memory_size = memory_size
    self._linear_output_size = linear_output_size
    self._num_hops = num_hops
    self._nonlin = nonlin
    self._apply_embeddings = apply_embeddings
    self._init_func = init_func
    self._use_ln = use_ln
    # Encoding part: i.e. "I" of the paper.
    self._encodings = _position_encoding(sentence_size, embedding_size)

  def __call__(  # pytype: disable=signature-mismatch  # numpy-scalars
      self,
      node_fts: _Array,
      edge_fts: _Array,
      graph_fts: _Array,
      adj_mat: _Array,
      hidden: _Array,
      **unused_kwargs,
  ) -> _Array:
    """MemNet inference step."""
    del hidden  # The MemNet processor does not use the hidden state.
    # Append the graph features as one extra "node" so they participate in
    # the memory computation alongside the true nodes.
    node_and_graph_fts = jnp.concatenate([node_fts, graph_fts[:, None]],
                                         axis=1)
    # Mask out non-edges, then pad one extra row/column for the graph node.
    edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
                              ((0, 0), (0, 1), (0, 1), (0, 0)))
    # Run the memory network once per node position (vmapped over axis 1),
    # using each node's row of edge features as its stories.
    nxt_hidden = jax.vmap(self._apply, (1), 1)(node_and_graph_fts,
                                               edge_fts_padded)

    # Broadcast hidden state corresponding to graph features across the nodes.
    nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
    return nxt_hidden, None  # pytype: disable=bad-return-type  # numpy-scalars

  def _apply(self, queries: _Array, stories: _Array) -> _Array:
    """Apply Memory Network to the queries and stories.

    Args:
      queries: Tensor of shape [batch_size, sentence_size].
      stories: Tensor of shape [batch_size, memory_size, sentence_size].

    Returns:
      Tensor of shape [batch_size, vocab_size].
    """
    if self._apply_embeddings:
      query_biases = hk.get_parameter(
          'query_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)
      stories_biases = hk.get_parameter(
          'stories_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)
      memory_biases = hk.get_parameter(
          'memory_contents',
          shape=[self._memory_size, self._embedding_size],
          init=self._init_func)
      output_biases = hk.get_parameter(
          'output_biases',
          shape=[self._vocab_size - 1, self._embedding_size],
          init=self._init_func)

      # The last vocabulary entry is a fixed all-zero "nil" word, appended
      # to each embedding table below.
      nil_word_slot = jnp.zeros([1, self._embedding_size])

    # This is "A" in the paper.
    if self._apply_embeddings:
      stories_biases = jnp.concatenate([stories_biases, nil_word_slot], axis=0)
      memory_embeddings = jnp.take(
          stories_biases, stories.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(stories.shape) + [self._embedding_size])
      # Pad the memory axis up to the configured memory size.
      memory_embeddings = jnp.pad(
          memory_embeddings,
          ((0, 0), (0, self._memory_size - jnp.shape(memory_embeddings)[1]),
           (0, 0), (0, 0)))
      memory = jnp.sum(memory_embeddings * self._encodings, 2) + memory_biases
    else:
      memory = stories

    # This is "B" in the paper. Also, when there are no queries (only
    # sentences), then these lines are substituted by
    # query_embeddings = 0.1.
    if self._apply_embeddings:
      query_biases = jnp.concatenate([query_biases, nil_word_slot], axis=0)
      query_embeddings = jnp.take(
          query_biases, queries.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(queries.shape) + [self._embedding_size])
      # This is "u" in the paper.
      query_input_embedding = jnp.sum(query_embeddings * self._encodings, 1)
    else:
      query_input_embedding = queries

    # This is "C" in the paper.
    if self._apply_embeddings:
      output_biases = jnp.concatenate([output_biases, nil_word_slot], axis=0)
      output_embeddings = jnp.take(
          output_biases, stories.reshape([-1]).astype(jnp.int32),
          axis=0).reshape(list(stories.shape) + [self._embedding_size])
      output_embeddings = jnp.pad(
          output_embeddings,
          ((0, 0), (0, self._memory_size - jnp.shape(output_embeddings)[1]),
           (0, 0), (0, 0)))
      output = jnp.sum(output_embeddings * self._encodings, 2)
    else:
      output = stories

    intermediate_linear = hk.Linear(self._embedding_size, with_bias=False)

    # Output_linear is "H".
    output_linear = hk.Linear(self._linear_output_size, with_bias=False)

    for hop_number in range(self._num_hops):
      query_input_embedding_transposed = jnp.transpose(
          jnp.expand_dims(query_input_embedding, -1), [0, 2, 1])

      # Calculate probabilities.
      probs = jax.nn.softmax(
          jnp.sum(memory * query_input_embedding_transposed, 2))

      # Calculate output of the layer by multiplying by C.
      transposed_probs = jnp.transpose(jnp.expand_dims(probs, -1), [0, 2, 1])
      transposed_output_embeddings = jnp.transpose(output, [0, 2, 1])

      # This is "o" in the paper.
      layer_output = jnp.sum(transposed_output_embeddings * transposed_probs, 2)

      # Finally the answer
      if hop_number == self._num_hops - 1:
        # Please note that in the TF version we apply the final linear layer
        # in all hops and this results in shape mismatches.
        output_layer = output_linear(query_input_embedding + layer_output)
      else:
        output_layer = intermediate_linear(query_input_embedding + layer_output)

      query_input_embedding = output_layer
      if self._nonlin:
        output_layer = self._nonlin(output_layer)

    # This linear here is "W".
    ret = hk.Linear(self._vocab_size, with_bias=False)(output_layer)

    if self._use_ln:
      ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
      ret = ln(ret)

    return ret
class MemNetFull(MemNetMasked):
  """Memory Networks with full adjacency matrix."""

  def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
               adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
    """Runs the memory network over a fully-connected graph."""
    # Ignore the provided adjacency: treat every node pair as connected.
    dense_adj = jnp.ones_like(adj_mat)
    return super().__call__(node_fts, edge_fts, graph_fts, dense_adj, hidden)
ProcessorFactory = Callable[[int], Processor]


def get_processor_factory(kind: str,
                          use_ln: bool,
                          nb_triplet_fts: int,
                          nb_heads: Optional[int] = None) -> ProcessorFactory:
  """Returns a processor factory.

  Args:
    kind: One of the available types of processor.
    use_ln: Whether the processor passes the output through a layernorm layer.
    nb_triplet_fts: How many triplet features to compute.
    nb_heads: Number of attention heads for GAT processors.

  Returns:
    A callable that takes an `out_size` parameter (equal to the hidden
    dimension of the network) and returns a processor instance.
  """

  def _factory(out_size: int):
    # Attention processors all share the (out_size, nb_heads, use_ln)
    # constructor signature.
    gat_classes = {
        'gat': GAT,
        'gat_full': GATFull,
        'gatv2': GATv2,
        'gatv2_full': GATv2Full,
    }
    # Memory networks share the (vocab, sentence, linear output) signature.
    memnet_classes = {
        'memnet_full': MemNetFull,
        'memnet_masked': MemNetMasked,
    }
    # PGN-family processors, specified as:
    #   kind -> (class, use_triplets, nb_triplet_fts, gated).
    pgn_specs = {
        'deepsets': (DeepSets, False, 0, False),
        'mpnn': (MPNN, False, 0, False),
        'pgn': (PGN, False, 0, False),
        'pgn_mask': (PGNMask, False, 0, False),
        'triplet_mpnn': (MPNN, True, nb_triplet_fts, False),
        'triplet_pgn': (PGN, True, nb_triplet_fts, False),
        'triplet_pgn_mask': (PGNMask, True, nb_triplet_fts, False),
        'gpgn': (PGN, False, nb_triplet_fts, True),
        'gpgn_mask': (PGNMask, False, nb_triplet_fts, True),
        'gmpnn': (MPNN, False, nb_triplet_fts, True),
        'triplet_gpgn': (PGN, True, nb_triplet_fts, True),
        'triplet_gpgn_mask': (PGNMask, True, nb_triplet_fts, True),
        'triplet_gmpnn': (MPNN, True, nb_triplet_fts, True),
    }

    if kind in gat_classes:
      return gat_classes[kind](
          out_size=out_size,
          nb_heads=nb_heads,
          use_ln=use_ln,
      )
    if kind in memnet_classes:
      return memnet_classes[kind](
          vocab_size=out_size,
          sentence_size=out_size,
          linear_output_size=out_size,
      )
    if kind in pgn_specs:
      cls, use_triplets, triplet_fts, gated = pgn_specs[kind]
      return cls(
          out_size=out_size,
          msgs_mlp_sizes=[out_size, out_size],
          use_ln=use_ln,
          use_triplets=use_triplets,
          nb_triplet_fts=triplet_fts,
          gated=gated,
      )
    raise ValueError('Unexpected processor kind ' + kind)

  return _factory
| def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray: | |
| """Position Encoding described in section 4.1 [1].""" | |
| encoding = np.ones((embedding_size, sentence_size), dtype=np.float32) | |
| ls = sentence_size + 1 | |
| le = embedding_size + 1 | |
| for i in range(1, le): | |
| for j in range(1, ls): | |
| encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2) | |
| encoding = 1 + 4 * encoding / embedding_size / sentence_size | |
| return np.transpose(encoding) | |