# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""JAX implementation of CLRS basic network."""

import functools

from typing import Dict, List, Optional, Tuple

import chex

from clrs._src import decoders
from clrs._src import encoders
from clrs._src import probing
from clrs._src import processors
from clrs._src import samplers
from clrs._src import specs

import haiku as hk
import jax
import jax.numpy as jnp


_Array = chex.Array
_DataPoint = probing.DataPoint
_Features = samplers.Features
_FeaturesChunked = samplers.FeaturesChunked
_Location = specs.Location
_Spec = specs.Spec
_Stage = specs.Stage
_Trajectory = samplers.Trajectory
_Type = specs.Type


@chex.dataclass
class _MessagePassingScanState:
  hint_preds: chex.Array
  output_preds: chex.Array
  hiddens: chex.Array
  lstm_state: Optional[hk.LSTMState]


@chex.dataclass
class _MessagePassingOutputChunked:
  hint_preds: chex.Array
  output_preds: chex.Array


@chex.dataclass
class MessagePassingStateChunked:
  inputs: chex.Array
  hints: chex.Array
  is_first: chex.Array
  hint_preds: chex.Array
  hiddens: chex.Array
  lstm_state: Optional[hk.LSTMState]


class Net(hk.Module):
  """Building blocks (networks) used to encode and decode messages."""

  def __init__(
      self,
      spec: List[_Spec],
      hidden_dim: int,
      encode_hints: bool,
      decode_hints: bool,
      processor_factory: processors.ProcessorFactory,
      use_lstm: bool,
      encoder_init: str,
      dropout_prob: float,
      hint_teacher_forcing: float,
      hint_repred_mode='soft',
      nb_dims=None,
      nb_msg_passing_steps=1,
      name: str = 'net',
  ):
    """Constructs a `Net`."""
    super().__init__(name=name)

    self._dropout_prob = dropout_prob
    self._hint_teacher_forcing = hint_teacher_forcing
    self._hint_repred_mode = hint_repred_mode
    self.spec = spec
    self.hidden_dim = hidden_dim
    self.encode_hints = encode_hints
    self.decode_hints = decode_hints
    self.processor_factory = processor_factory
    self.nb_dims = nb_dims
    self.use_lstm = use_lstm
    self.encoder_init = encoder_init
    self.nb_msg_passing_steps = nb_msg_passing_steps

  def _msg_passing_step(self,
                        mp_state: _MessagePassingScanState,
                        i: int,
                        hints: List[_DataPoint],
                        repred: bool,
                        lengths: chex.Array,
                        batch_size: int,
                        nb_nodes: int,
                        inputs: _Trajectory,
                        first_step: bool,
                        spec: _Spec,
                        encs: Dict[str, List[hk.Module]],
                        decs: Dict[str, Tuple[hk.Module]],
                        return_hints: bool,
                        return_all_outputs: bool
                        ):
    if self.decode_hints and not first_step:
      assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
      hard_postprocess = (self._hint_repred_mode == 'hard' or
                          (self._hint_repred_mode == 'hard_on_eval' and repred))
      decoded_hint = decoders.postprocess(spec,
                                          mp_state.hint_preds,
                                          sinkhorn_temperature=0.1,
                                          sinkhorn_steps=25,
                                          hard=hard_postprocess)
    if repred and self.decode_hints and not first_step:
      cur_hint = []
      for hint in decoded_hint:
        cur_hint.append(decoded_hint[hint])
    else:
      cur_hint = []
      needs_noise = (self.decode_hints and not first_step and
                     self._hint_teacher_forcing < 1.0)
      if needs_noise:
        # For noisy teacher forcing, choose which examples in the batch to force
        force_mask = jax.random.bernoulli(
            hk.next_rng_key(), self._hint_teacher_forcing,
            (batch_size,))
      else:
        force_mask = None
      for hint in hints:
        hint_data = jnp.asarray(hint.data)[i]
        _, loc, typ = spec[hint.name]
        if needs_noise:
          if (typ == _Type.POINTER and
              decoded_hint[hint.name].type_ == _Type.SOFT_POINTER):
            # When using soft pointers, the decoded hints cannot be summarised
            # as indices (as would happen in hard postprocessing), so we need
            # to raise the ground-truth hint (potentially used for teacher
            # forcing) to its one-hot version.
            hint_data = hk.one_hot(hint_data, nb_nodes)
            typ = _Type.SOFT_POINTER
          hint_data = jnp.where(_expand_to(force_mask, hint_data),
                                hint_data,
                                decoded_hint[hint.name].data)
        cur_hint.append(
            probing.DataPoint(
                name=hint.name, location=loc, type_=typ, data=hint_data))

    hiddens, output_preds_cand, hint_preds, lstm_state = self._one_step_pred(
        inputs, cur_hint, mp_state.hiddens,
        batch_size, nb_nodes, mp_state.lstm_state,
        spec, encs, decs, repred)

    if first_step:
      output_preds = output_preds_cand
    else:
      output_preds = {}
      for outp in mp_state.output_preds:
        is_not_done = _is_not_done_broadcast(lengths, i,
                                             output_preds_cand[outp])
        output_preds[outp] = is_not_done * output_preds_cand[outp] + (
            1.0 - is_not_done) * mp_state.output_preds[outp]

    new_mp_state = _MessagePassingScanState(  # pytype: disable=wrong-arg-types  # numpy-scalars
        hint_preds=hint_preds,
        output_preds=output_preds,
        hiddens=hiddens,
        lstm_state=lstm_state)
    # Save memory by not stacking unnecessary fields
    accum_mp_state = _MessagePassingScanState(  # pytype: disable=wrong-arg-types  # numpy-scalars
        hint_preds=hint_preds if return_hints else None,
        output_preds=output_preds if return_all_outputs else None,
        hiddens=None, lstm_state=None)

    # In line with jax.scan, the first returned value is the state we carry
    # over; the second value is the output that will be stacked over steps.
    return new_mp_state, accum_mp_state

  def __call__(self, features_list: List[_Features], repred: bool,
               algorithm_index: int,
               return_hints: bool,
               return_all_outputs: bool):
    """Process one batch of data.

    Args:
      features_list: A list of _Features objects, each with the inputs, hints
        and lengths for a batch of data corresponding to one algorithm.
        The list should have either length 1, at train/evaluation time,
        or length equal to the number of algorithms this Net is meant to
        process, at initialization.
      repred: False during training, when we have access to ground-truth hints.
        True in validation/test mode, when we have to use our own
        hint predictions.
      algorithm_index: Which algorithm is being processed. It can be -1 at
        initialisation (either because we are initialising the parameters of
        the module or because we are initialising the message-passing state),
        meaning that all algorithms should be processed, in which case
        `features_list` should have length equal to the number of specs of
        the Net. Otherwise, `algorithm_index` should be
        between 0 and `len(self.spec) - 1`, meaning only one of the
        algorithms will be processed, and `features_list` should have length 1.
      return_hints: Whether to accumulate and return the predicted hints,
        when they are decoded.
      return_all_outputs: Whether to return the full sequence of outputs, or
        just the last step's output.

    Returns:
      A 2-tuple with (output predictions, hint predictions)
      for the selected algorithm.
    """
    if algorithm_index == -1:
      algorithm_indices = range(len(features_list))
    else:
      algorithm_indices = [algorithm_index]
    assert len(algorithm_indices) == len(features_list)

    self.encoders, self.decoders = self._construct_encoders_decoders()
    self.processor = self.processor_factory(self.hidden_dim)

    # Optionally construct LSTM.
    if self.use_lstm:
      self.lstm = hk.LSTM(
          hidden_size=self.hidden_dim,
          name='processor_lstm')
      lstm_init = self.lstm.initial_state
    else:
      self.lstm = None
      lstm_init = lambda x: 0

    for algorithm_index, features in zip(algorithm_indices, features_list):
      inputs = features.inputs
      hints = features.hints
      lengths = features.lengths

      batch_size, nb_nodes = _data_dimensions(features)

      nb_mp_steps = max(1, hints[0].data.shape[0] - 1)
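      # `hints[0].data` has T time-slices (one per algorithm state), so we take
      # one message-passing step per transition between consecutive states,
      # i.e. T - 1 steps (and at least one).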
      hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))

      if self.use_lstm:
        lstm_state = lstm_init(batch_size * nb_nodes)
        lstm_state = jax.tree_util.tree_map(
            lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]),
            lstm_state)
      else:
        lstm_state = None

      mp_state = _MessagePassingScanState(  # pytype: disable=wrong-arg-types  # numpy-scalars
          hint_preds=None, output_preds=None,
          hiddens=hiddens, lstm_state=lstm_state)

      # Do the first step outside of the scan because it has a different
      # computation graph.
      common_args = dict(
          hints=hints,
          repred=repred,
          inputs=inputs,
          batch_size=batch_size,
          nb_nodes=nb_nodes,
          lengths=lengths,
          spec=self.spec[algorithm_index],
          encs=self.encoders[algorithm_index],
          decs=self.decoders[algorithm_index],
          return_hints=return_hints,
          return_all_outputs=return_all_outputs,
          )
      mp_state, lean_mp_state = self._msg_passing_step(
          mp_state,
          i=0,
          first_step=True,
          **common_args)

      # Then scan through the rest.
      scan_fn = functools.partial(
          self._msg_passing_step,
          first_step=False,
          **common_args)

      output_mp_state, accum_mp_state = hk.scan(
          scan_fn,
          mp_state,
          jnp.arange(nb_mp_steps - 1) + 1,
          length=nb_mp_steps - 1)

    # We only return the last algorithm's output. That's because
    # the output only matters when a single algorithm is processed; the case
    # `algorithm_index==-1` (meaning all algorithms should be processed)
    # is used only to init parameters.
    accum_mp_state = jax.tree_util.tree_map(
        lambda init, tail: jnp.concatenate([init[None], tail], axis=0),
        lean_mp_state, accum_mp_state)

    def invert(d):
      """Dict of lists -> list of dicts."""
      if d:
        return [dict(zip(d, i)) for i in zip(*d.values())]

    if return_all_outputs:
      output_preds = {k: jnp.stack(v)
                      for k, v in accum_mp_state.output_preds.items()}
    else:
      output_preds = output_mp_state.output_preds
    hint_preds = invert(accum_mp_state.hint_preds)

    return output_preds, hint_preds

  def _construct_encoders_decoders(self):
    """Constructs encoders and decoders, separate for each algorithm."""
    encoders_ = []
    decoders_ = []
    enc_algo_idx = None
    for (algo_idx, spec) in enumerate(self.spec):
      enc = {}
      dec = {}
      for name, (stage, loc, t) in spec.items():
        if stage == _Stage.INPUT or (
            stage == _Stage.HINT and self.encode_hints):
          # Build input encoders.
          if name == specs.ALGO_IDX_INPUT_NAME:
            if enc_algo_idx is None:
              enc_algo_idx = [hk.Linear(self.hidden_dim,
                                        name=f'{name}_enc_linear')]
            enc[name] = enc_algo_idx
          else:
            enc[name] = encoders.construct_encoders(
                stage, loc, t, hidden_dim=self.hidden_dim,
                init=self.encoder_init,
                name=f'algo_{algo_idx}_{name}')

        if stage == _Stage.OUTPUT or (
            stage == _Stage.HINT and self.decode_hints):
          # Build output decoders.
          dec[name] = decoders.construct_decoders(
              loc, t, hidden_dim=self.hidden_dim,
              nb_dims=self.nb_dims[algo_idx][name],
              name=f'algo_{algo_idx}_{name}')
      encoders_.append(enc)
      decoders_.append(dec)

    return encoders_, decoders_

  def _one_step_pred(
      self,
      inputs: _Trajectory,
      hints: _Trajectory,
      hidden: _Array,
      batch_size: int,
      nb_nodes: int,
      lstm_state: Optional[hk.LSTMState],
      spec: _Spec,
      encs: Dict[str, List[hk.Module]],
      decs: Dict[str, Tuple[hk.Module]],
      repred: bool,
  ):
    """Generates one-step predictions."""

    # Initialise empty node/edge/graph features and adjacency matrix.
    node_fts = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
    edge_fts = jnp.zeros((batch_size, nb_nodes, nb_nodes, self.hidden_dim))
    graph_fts = jnp.zeros((batch_size, self.hidden_dim))
    adj_mat = jnp.repeat(
        jnp.expand_dims(jnp.eye(nb_nodes), 0), batch_size, axis=0)
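    # The base adjacency matrix contains only self-loops; additional edges are
    # accumulated from the encoded inputs (and hints) in the loop below.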
    # ENCODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Encode node/edge/graph features from inputs and (optionally) hints.
    trajectories = [inputs]
    if self.encode_hints:
      trajectories.append(hints)

    for trajectory in trajectories:
      for dp in trajectory:
        try:
          dp = encoders.preprocess(dp, nb_nodes)
          assert dp.type_ != _Type.SOFT_POINTER
          adj_mat = encoders.accum_adj_mat(dp, adj_mat)
          encoder = encs[dp.name]
          edge_fts = encoders.accum_edge_fts(encoder, dp, edge_fts)
          node_fts = encoders.accum_node_fts(encoder, dp, node_fts)
          graph_fts = encoders.accum_graph_fts(encoder, dp, graph_fts)
        except Exception as e:
          raise Exception(f'Failed to process {dp}') from e

    # PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    nxt_hidden = hidden
    for _ in range(self.nb_msg_passing_steps):
      nxt_hidden, nxt_edge = self.processor(
          node_fts,
          edge_fts,
          graph_fts,
          adj_mat,
          nxt_hidden,
          batch_size=batch_size,
          nb_nodes=nb_nodes,
      )

    if not repred:  # dropout only during training
      nxt_hidden = hk.dropout(hk.next_rng_key(), self._dropout_prob, nxt_hidden)

    if self.use_lstm:
      # lstm doesn't accept multiple batch dimensions (in our case, batch and
      # nodes), so we vmap over the (first) batch dimension.
      nxt_hidden, nxt_lstm_state = jax.vmap(self.lstm)(nxt_hidden, lstm_state)
    else:
      nxt_lstm_state = None

    h_t = jnp.concatenate([node_fts, hidden, nxt_hidden], axis=-1)
    if nxt_edge is not None:
      e_t = jnp.concatenate([edge_fts, nxt_edge], axis=-1)
    else:
      e_t = edge_fts

    # DECODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Decode features and (optionally) hints.
    hint_preds, output_preds = decoders.decode_fts(
        decoders=decs,
        spec=spec,
        h_t=h_t,
        adj_mat=adj_mat,
        edge_fts=e_t,
        graph_fts=graph_fts,
        inf_bias=self.processor.inf_bias,
        inf_bias_edge=self.processor.inf_bias_edge,
        repred=repred,
    )

    return nxt_hidden, output_preds, hint_preds, nxt_lstm_state


class NetChunked(Net):
  """A Net that will process time-chunked data instead of full samples."""

  def _msg_passing_step(self,
                        mp_state: MessagePassingStateChunked,
                        xs,
                        repred: bool,
                        init_mp_state: bool,
                        batch_size: int,
                        nb_nodes: int,
                        spec: _Spec,
                        encs: Dict[str, List[hk.Module]],
                        decs: Dict[str, Tuple[hk.Module]],
                        ):
| """Perform one message passing step. | |
| This function is unrolled along the time axis to process a data chunk. | |
| Args: | |
| mp_state: message-passing state. Includes the inputs, hints, | |
| beginning-of-sample markers, hint predictions, hidden and lstm state | |
| to be used for prediction in the current step. | |
| xs: A 3-tuple of with the next timestep's inputs, hints, and | |
| beginning-of-sample markers. These will replace the contents of | |
| the `mp_state` at the output, in readiness for the next unroll step of | |
| the chunk (or the first step of the next chunk). Besides, the next | |
| timestep's hints are necessary to compute diffs when `decode_diffs` | |
| is True. | |
| repred: False during training, when we have access to ground-truth hints. | |
| True in validation/test mode, when we have to use our own | |
| hint predictions. | |
| init_mp_state: Indicates if we are calling the method just to initialise | |
| the message-passing state, before the beginning of training or | |
| validation. | |
| batch_size: Size of batch dimension. | |
| nb_nodes: Number of nodes in graph. | |
| spec: The spec of the algorithm being processed. | |
| encs: encoders for the algorithm being processed. | |
| decs: decoders for the algorithm being processed. | |
| Returns: | |
| A 2-tuple with the next mp_state and an output consisting of | |
| hint predictions and output predictions. | |
| """ | |
    def _as_prediction_data(hint):
      if hint.type_ == _Type.POINTER:
        return hk.one_hot(hint.data, nb_nodes)
      return hint.data

    nxt_inputs, nxt_hints, nxt_is_first = xs
    inputs = mp_state.inputs
    is_first = mp_state.is_first
    hints = mp_state.hints
    if init_mp_state:
      prev_hint_preds = {h.name: _as_prediction_data(h) for h in hints}
      hints_for_pred = hints
    else:
      prev_hint_preds = mp_state.hint_preds
      if self.decode_hints:
        if repred:
          force_mask = jnp.zeros(batch_size, dtype=bool)
        elif self._hint_teacher_forcing == 1.0:
          force_mask = jnp.ones(batch_size, dtype=bool)
        else:
          force_mask = jax.random.bernoulli(
              hk.next_rng_key(), self._hint_teacher_forcing,
              (batch_size,))
        assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
        hard_postprocess = (
            self._hint_repred_mode == 'hard' or
            (self._hint_repred_mode == 'hard_on_eval' and repred))
        decoded_hints = decoders.postprocess(spec,
                                             prev_hint_preds,
                                             sinkhorn_temperature=0.1,
                                             sinkhorn_steps=25,
                                             hard=hard_postprocess)
        hints_for_pred = []
        for h in hints:
          typ = h.type_
          hint_data = h.data
          if (typ == _Type.POINTER and
              decoded_hints[h.name].type_ == _Type.SOFT_POINTER):
            hint_data = hk.one_hot(hint_data, nb_nodes)
            typ = _Type.SOFT_POINTER
          hints_for_pred.append(probing.DataPoint(
              name=h.name, location=h.location, type_=typ,
              data=jnp.where(_expand_to(is_first | force_mask, hint_data),
                             hint_data, decoded_hints[h.name].data)))
      else:
        hints_for_pred = hints

    hiddens = jnp.where(is_first[..., None, None], 0.0, mp_state.hiddens)
    if self.use_lstm:
      lstm_state = jax.tree_util.tree_map(
          lambda x: jnp.where(is_first[..., None, None], 0.0, x),
          mp_state.lstm_state)
    else:
      lstm_state = None
    hiddens, output_preds, hint_preds, lstm_state = self._one_step_pred(
        inputs, hints_for_pred, hiddens,
        batch_size, nb_nodes, lstm_state,
        spec, encs, decs, repred)

    new_mp_state = MessagePassingStateChunked(  # pytype: disable=wrong-arg-types  # numpy-scalars
        hiddens=hiddens, lstm_state=lstm_state, hint_preds=hint_preds,
        inputs=nxt_inputs, hints=nxt_hints, is_first=nxt_is_first)
    mp_output = _MessagePassingOutputChunked(  # pytype: disable=wrong-arg-types  # numpy-scalars
        hint_preds=hint_preds,
        output_preds=output_preds)
    return new_mp_state, mp_output

  def __call__(self, features_list: List[_FeaturesChunked],
               mp_state_list: List[MessagePassingStateChunked],
               repred: bool, init_mp_state: bool,
               algorithm_index: int):
    """Process one chunk of data.

    Args:
      features_list: A list of _FeaturesChunked objects, each with the
        inputs, hints and beginning- and end-of-sample markers for
        a chunk (i.e., fixed time length) of data corresponding to one
        algorithm. All features are expected
        to have dimensions chunk_length x batch_size x ...
        The list should have either length 1, at train/evaluation time,
        or length equal to the number of algorithms this Net is meant to
        process, at initialization.
      mp_state_list: list of message-passing states. Each message-passing state
        includes the inputs, hints, beginning-of-sample markers,
        hint prediction, hidden and lstm state from the end of the previous
        chunk, for one algorithm. The length of the list should be the same
        as the length of `features_list`.
      repred: False during training, when we have access to ground-truth hints.
        True in validation/test mode, when we have to use our own hint
        predictions.
      init_mp_state: Indicates if we are calling the network just to initialise
        the message-passing state, before the beginning of training or
        validation. If True, `algorithm_index` (see below) must be -1 in order
        to initialize the message-passing state of all algorithms.
      algorithm_index: Which algorithm is being processed. It can be -1 at
        initialisation (either because we are initialising the parameters of
        the module or because we are initialising the message-passing state),
        meaning that all algorithms should be processed, in which case
        `features_list` and `mp_state_list` should have length equal to the
        number of specs of the Net. Otherwise, `algorithm_index` should be
        between 0 and `len(self.spec) - 1`, meaning only one of the
        algorithms will be processed, and `features_list` and `mp_state_list`
        should have length 1.

    Returns:
      A 2-tuple consisting of:
      - A 2-tuple with (output predictions, hint predictions)
        for the selected algorithm. Each of these has
        chunk_length x batch_size x ... data, where the first time
        slice contains outputs for the mp_state
        that was passed as input, and the last time slice contains outputs
        for the next-to-last slice of the input features. The outputs that
        correspond to the final time slice of the input features will be
        calculated when the next chunk is processed, using the data in the
        mp_state returned here (see below). If `init_mp_state` is True,
        we return None instead of the 2-tuple.
      - The mp_state (message-passing state) for the next chunk of data
        of the selected algorithm. If `init_mp_state` is True, we return
        initial mp states for all the algorithms.
    """
    if algorithm_index == -1:
      algorithm_indices = range(len(features_list))
    else:
      algorithm_indices = [algorithm_index]
      assert not init_mp_state  # init state only allowed with all algorithms
    assert len(algorithm_indices) == len(features_list)
    assert len(algorithm_indices) == len(mp_state_list)

    self.encoders, self.decoders = self._construct_encoders_decoders()
    self.processor = self.processor_factory(self.hidden_dim)

    # Optionally construct LSTM.
    if self.use_lstm:
      self.lstm = hk.LSTM(
          hidden_size=self.hidden_dim,
          name='processor_lstm')
      lstm_init = self.lstm.initial_state
    else:
      self.lstm = None
      lstm_init = lambda x: 0

    if init_mp_state:
      output_mp_states = []
      for algorithm_index, features, mp_state in zip(
          algorithm_indices, features_list, mp_state_list):
        inputs = features.inputs
        hints = features.hints
        batch_size, nb_nodes = _data_dimensions_chunked(features)

        if self.use_lstm:
          lstm_state = lstm_init(batch_size * nb_nodes)
          lstm_state = jax.tree_util.tree_map(
              lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]),
              lstm_state)
          mp_state.lstm_state = lstm_state
        mp_state.inputs = jax.tree_util.tree_map(lambda x: x[0], inputs)
        mp_state.hints = jax.tree_util.tree_map(lambda x: x[0], hints)
        mp_state.is_first = jnp.zeros(batch_size, dtype=int)
        mp_state.hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
        next_is_first = jnp.ones(batch_size, dtype=int)

        mp_state, _ = self._msg_passing_step(
            mp_state,
            (mp_state.inputs, mp_state.hints, next_is_first),
            repred=repred,
            init_mp_state=True,
            batch_size=batch_size,
            nb_nodes=nb_nodes,
            spec=self.spec[algorithm_index],
            encs=self.encoders[algorithm_index],
            decs=self.decoders[algorithm_index],
            )
        output_mp_states.append(mp_state)
      return None, output_mp_states

    for algorithm_index, features, mp_state in zip(
        algorithm_indices, features_list, mp_state_list):
      inputs = features.inputs
      hints = features.hints
      is_first = features.is_first
      batch_size, nb_nodes = _data_dimensions_chunked(features)

      scan_fn = functools.partial(
          self._msg_passing_step,
          repred=repred,
          init_mp_state=False,
          batch_size=batch_size,
          nb_nodes=nb_nodes,
          spec=self.spec[algorithm_index],
          encs=self.encoders[algorithm_index],
          decs=self.decoders[algorithm_index],
          )

      mp_state, scan_output = hk.scan(
          scan_fn,
          mp_state,
          (inputs, hints, is_first),
          )

    # We only return the last algorithm's output and state. That's because
    # the output only matters when a single algorithm is processed; the case
    # `algorithm_index==-1` (meaning all algorithms should be processed)
    # is used only to init parameters.
    return (scan_output.output_preds, scan_output.hint_preds), mp_state


def _data_dimensions(features: _Features) -> Tuple[int, int]:
  """Returns (batch_size, nb_nodes)."""
  for inp in features.inputs:
    if inp.location in [_Location.NODE, _Location.EDGE]:
      return inp.data.shape[:2]
  assert False


def _data_dimensions_chunked(features: _FeaturesChunked) -> Tuple[int, int]:
  """Returns (batch_size, nb_nodes)."""
  for inp in features.inputs:
    if inp.location in [_Location.NODE, _Location.EDGE]:
      return inp.data.shape[1:3]
  assert False


def _expand_to(x: _Array, y: _Array) -> _Array:
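  """Appends trailing singleton axes to `x` until it matches `y`'s rank.

  For example, a [B] mask used against [B, N, H] data becomes [B, 1, 1],
  suitable for broadcasting in `jnp.where`.
  """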
  while len(y.shape) > len(x.shape):
    x = jnp.expand_dims(x, -1)
  return x


def _is_not_done_broadcast(lengths, i, tensor):
  """Returns a 0/1 mask that is 1.0 while a sample's trajectory is not yet
  finished at step `i`, expanded to `tensor`'s rank for broadcasting."""
  is_not_done = (lengths > i + 1) * 1.0
  while len(is_not_done.shape) < len(tensor.shape):  # pytype: disable=attribute-error  # numpy-scalars
    is_not_done = jnp.expand_dims(is_not_done, -1)
  return is_not_done
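

# Illustrative usage sketch (not part of the library): `Net` is a Haiku module,
# so it has to be built and applied under `hk.transform`. The spec, dimensions
# and processor factory below are placeholder assumptions for demonstration.
#
#   def _forward(features_list, repred, algorithm_index):
#     net = Net(
#         spec=[algo_spec],                      # one _Spec per algorithm
#         hidden_dim=128,
#         encode_hints=True,
#         decode_hints=True,
#         processor_factory=processor_factory,   # built via clrs._src.processors
#         use_lstm=False,
#         encoder_init='default',
#         dropout_prob=0.0,
#         hint_teacher_forcing=0.0,
#         nb_dims=nb_dims,                       # per-algorithm output dimensions
#     )
#     return net(features_list, repred, algorithm_index,
#                return_hints=True, return_all_outputs=False)
#
#   forward = hk.transform(_forward)
#   params = forward.init(rng, [features], repred=False, algorithm_index=0)
#   outputs, hints = forward.apply(params, rng, [features], repred=True,
#                                  algorithm_index=0)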