| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Baseline for Image to Text Models. |
| | |
| | B = batch size |
| | H = height |
| | W = width |
| | N = number of image tokens |
| | I = Input sequence length |
| | O = Output sequence length |
| | d = hidden dims |
| | C = vocabulary size |
| | K = number of candidate |
| | L = sequence length of retrieved document |
| | M = sequence length of compressed tokens |
| | """ |
| | import functools |
| | from typing import Any, Dict, Mapping, Optional, Tuple, List |
| |
|
| | from absl import logging |
| | import flax.linen as nn |
| | import jax |
| | import jax.numpy as jnp |
| | import ml_collections |
| | import numpy as np |
| | from scenic.dataset_lib import dataset_utils |
| | from scenic.model_lib.base_models import base_model |
| | from scenic.projects.knowledge_visual_language.data import data_utils |
| | from scenic.projects.knowledge_visual_language.models import constants |
| | from scenic.projects.knowledge_visual_language.models import fusion_in_decoder_soft |
| | from scenic.projects.knowledge_visual_language.models import layers |
| | from scenic.projects.knowledge_visual_language.models import local_memory |
| | from scenic.projects.knowledge_visual_language.models import losses |
| | from scenic.projects.knowledge_visual_language.models import metrics |
| | from t5x import decoding |
| |
|
# Module-level alias for the knowledge-base object exposed by `local_memory`.
# `KnowledgeFIDModel.__init__` calls `local_kb.initialize(...)`; the Flax module
# below reads its sizes (e.g. `n_data_per_shard`, `n_local_device`, `k`,
# `n_data`, `n_kb_dataset`, `ret_top_specs`) when setting up and tracing.
local_kb = local_memory.kb
| |
|
| |
|
class KnowledgeFIDModule(fusion_in_decoder_soft.FusionInDecoderSoftModule):
  """FID model (https://arxiv.org/pdf/2007.01282.pdf) with a retrieval module over a knowledge memory.

  On top of the parent fusion-in-decoder module, this module holds a sharded
  retrieval index in the `memory` variable collection plus a gate producing a
  score per knowledge-base dataset. It runs a distributed maximum inner product
  search (MIPS) over the index and fuses the retrieved entries with the query.

  Attributes:
    retr_k: Number of candidates kept by the final global top-k.
    data_k: Number of leading global ranks used to select retrieved raw data;
      the remaining ranks select retrieved memory entries.
    axis_index_groups: Device groups used for the `all_gather` of local top-k
      candidates (built one row per host by `KnowledgeFIDModel`, assuming
      host-major device ordering -- TODO confirm).
    across_index_groups: Device groups used for the cross-group `all_to_all`;
      its length is used as the number of local devices.
  """

  retr_k: int
  data_k: int
  axis_index_groups: Optional[List[List[int]]] = None
  across_index_groups: Optional[List[List[int]]] = None

  def setup(self):
    super().setup()
    # This device's shard of retrieval keys, one `key_dim` row per entry.
    self.local_keys = self.variable(
        'memory',
        'keys',
        functools.partial(jnp.zeros, dtype=jnp.bfloat16),
        (local_kb.n_data_per_shard, self.key_dim),
    )
    # Dataset id per entry, with n_data_per_shard * n_local_device slots; it is
    # indexed by the device-offset ids built in `_dist_mips_across`.
    self.local_dataset_idxs = self.variable(
        'memory',
        'idxs',
        functools.partial(jnp.zeros, dtype=jnp.int16),
        (local_kb.n_data_per_shard * local_kb.n_local_device),
    )
    # One logit per knowledge-base dataset, computed from the query embedding.
    self.dataset_gate = nn.Dense(
        features=local_kb.n_kb_dataset, dtype=self.dtype, name='dataset_gate'
    )

  def _get_corpus_scores(self, corpus_scores, topk_ids):
    """Looks up the dataset of each candidate and that dataset's gate score.

    Args:
      corpus_scores: Per-example dataset gate scores.
      topk_ids: Candidate entry ids indexing `local_dataset_idxs`.

    Returns:
      A pair (gate scores selected per candidate via
      `layers.batch_index_select`, dataset ids of the candidates).
    """
    corpus_ids = jnp.take(self.local_dataset_idxs.value, topk_ids, axis=0)
    return layers.batch_index_select(corpus_scores, corpus_ids), corpus_ids

  def _dist_mips_local(
      self,
      query,
      corpus_scores,
      local_device_id,
      recall_target=0.99,
      exact=False,
  ):
    """Host-local MIPS; no longer implemented.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError(
        'jax.experimental.host_callback has been removed.'
    )

  def _dist_mips_across(
      self,
      query,
      corpus_scores,
      local_device_id,
      recall_target=0.99,
      exact=False,
  ):
    """MIPS over the key shards of all devices, then gathers the winners.

    Must run inside a collective context with axis name 'batch'.

    Args:
      query: Query embeddings of this device's examples.
      corpus_scores: Dataset gate scores of this device's examples.
      local_device_id: Index of this device among its host's devices.
      recall_target: Recall target forwarded to `jax.lax.approx_max_k`.
      exact: If True use exact `jax.lax.top_k` instead of approximate top-k.

    Returns:
      Tuple (global_topk_scores, ret_memory, ret_data, local_topk_ids,
      global_rank_ids, global_corpus_ids).
    """
    logging.info('mips global!!!')
    logging.info(self.local_keys.value.shape)
    logging.info(local_kb.n_data)
    n_local_device = len(self.across_index_groups)
    logging.info(n_local_device)
    # Every device scores the queries of all devices against its own shard.
    global_query = jax.lax.all_gather(
        x=query, axis_name='batch', axis=0, tiled=True
    )
    logging.info(global_query.shape)

    global_corpus_scores = jax.lax.all_gather(
        x=corpus_scores, axis_name='batch', axis=0, tiled=True
    )
    logging.info(global_corpus_scores.shape)

    # Step 1: top-k over this device's key shard.
    local_scores = jax.lax.dot(global_query, self.local_keys.value.transpose())
    local_k = local_kb.k
    if exact:
      local_topk_scores, local_topk_ids = jax.lax.top_k(local_scores, k=local_k)
    else:
      local_topk_scores, local_topk_ids = jax.lax.approx_max_k(
          local_scores,
          k=local_k,
          recall_target=recall_target,
          reduction_input_size_override=local_kb.n_data,
          aggregate_to_topk=True,
      )

    # Shift shard-local ids by this device's shard offset.
    local_topk_ids_offset = (
        local_topk_ids + local_device_id * local_kb.n_data_per_shard
    )
    logging.info(local_topk_ids.shape)

    # Step 2: concatenate candidates of the devices in the same axis group.
    host_topk_scores = jax.lax.all_gather(
        x=local_topk_scores,
        axis_name='batch',
        axis=1,
        axis_index_groups=self.axis_index_groups,
        tiled=True,
    )
    logging.info(host_topk_scores.shape)

    host_topk_ids = jax.lax.all_gather(
        x=local_topk_ids_offset,
        axis_name='batch',
        axis=1,
        axis_index_groups=self.axis_index_groups,
        tiled=True,
    )

    # Re-weight every candidate by the gate score of its dataset, then keep
    # the best `local_k` of the group.
    host_corpus_scores, host_corpus_ids = self._get_corpus_scores(
        global_corpus_scores, host_topk_ids
    )

    host_topk_scores, host_rank_ids = jax.lax.top_k(
        host_topk_scores * host_corpus_scores, k=local_k
    )

    host_topk_ids = layers.batch_index_select(host_topk_ids, host_rank_ids)
    logging.info(host_topk_ids.shape)

    # NOTE(review): tiled `all_gather` orders `global_query` rows device-major
    # (device d owns rows d*B .. d*B+B-1), but this reshape keeps rows
    # `i * n_local_device + local_device_id`. These match this device's own
    # queries only when B == 1 or there is one device per host -- confirm the
    # intended layout.
    host_topk_ids = jnp.reshape(host_topk_ids, (-1, n_local_device, local_k))
    host_topk_ids = host_topk_ids[:, local_device_id]
    logging.info(host_topk_ids.shape)

    host_topk_scores = jnp.reshape(
        host_topk_scores, (-1, n_local_device, local_k)
    )
    host_topk_scores = host_topk_scores[:, local_device_id]
    logging.info('host_topk_scores')
    logging.info(host_topk_scores.shape)

    # Step 3: fetch the selected entries from host memory. The former
    # `host_callback.call` no longer exists in JAX (and was never imported
    # here); `jax.pure_callback` is its replacement for side-effect-free host
    # lookups. The callback receives `host_topk_ids` and must return values
    # matching `local_kb.ret_top_specs`.
    ret_memory, ret_data = jax.pure_callback(
        local_memory.retrieve_top_memory,
        local_kb.ret_top_specs,
        host_topk_ids,
    )

    # Step 4: exchange candidates across groups and take the global top-k.
    global_topk_scores = jax.lax.all_to_all(
        x=host_topk_scores,
        axis_name='batch',
        split_axis=0,
        concat_axis=1,
        axis_index_groups=self.across_index_groups,
        tiled=True,
    )
    logging.info('global_topk_scores')
    logging.info(global_topk_scores.shape)

    global_topk_scores, global_rank_ids = jax.lax.top_k(
        global_topk_scores, k=self.retr_k
    )
    logging.info(global_topk_scores.shape)

    # The first `data_k` ranks select raw data, the rest select memory entries.
    global_data_ids = global_rank_ids[:, : int(self.data_k)]
    global_memory_ids = global_rank_ids[:, int(self.data_k) :]

    def _gather_val(local_ret_vals, top_ids):
      """Exchanges one retrieved leaf across groups and picks `top_ids`."""
      logging.info(local_ret_vals.shape)
      global_ret_vals = jax.lax.all_to_all(
          x=local_ret_vals,
          axis_name='batch',
          split_axis=0,
          concat_axis=1,
          axis_index_groups=self.across_index_groups,
          tiled=True,
      )
      logging.info(global_ret_vals.shape)

      global_ret_vals = layers.batch_index_select(global_ret_vals, top_ids)
      logging.info(global_ret_vals.shape)

      return global_ret_vals

    logging.info('_gather_val!!!')

    ret_memory = jax.tree_util.tree_map(
        lambda local_val: _gather_val(local_val, global_memory_ids), ret_memory
    )

    ret_data = jax.tree_util.tree_map(
        lambda local_val: _gather_val(local_val, global_data_ids), ret_data
    )

    # Every retrieved memory token is valid.
    ret_memory['masks'] = jnp.ones(ret_memory['values'].shape[:3]).astype(bool)

    for k in ret_memory:
      logging.info(k)
      logging.info(ret_memory[k].shape)
      logging.info(ret_memory[k].dtype)

    # Route the dataset ids of the candidates the same way as their scores.
    host_corpus_ids = layers.batch_index_select(host_corpus_ids, host_rank_ids)
    host_corpus_ids = jnp.reshape(
        host_corpus_ids, (-1, n_local_device, local_k)
    )[:, local_device_id]
    logging.info('corpus_ids')
    logging.info(host_corpus_ids.shape)

    global_corpus_ids = jax.lax.all_to_all(
        x=host_corpus_ids,
        axis_name='batch',
        split_axis=0,
        concat_axis=1,
        axis_index_groups=self.across_index_groups,
        tiled=True,
    )
    logging.info(global_corpus_ids.shape)

    global_corpus_ids = layers.batch_index_select(
        global_corpus_ids, global_rank_ids
    )
    logging.info(global_corpus_ids.shape)

    return (
        global_topk_scores,
        ret_memory,
        ret_data,
        local_topk_ids,
        global_rank_ids,
        global_corpus_ids,
    )

  def t5_decode(
      self,
      encoded,
      encoder_input_tokens: jnp.ndarray,
      decoder_input_tokens: jnp.ndarray,
      decoder_target_tokens: jnp.ndarray,
      enable_dropout: bool = True,
      decode: bool = False,
      max_decode_length: Optional[int] = None,
  ):
    """Wraps `self.out_decoder` (no packing) to enable autoregressive decoding.

    Exposed as a separate method so it can be passed as `method=` to
    `apply` during cached single-step decoding.
    """
    return self.out_decoder(
        encoded=encoded,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_input_tokens,
        decoder_target_tokens=decoder_target_tokens,
        enable_dropout=enable_dropout,
        decode=decode,
        max_decode_length=max_decode_length,
    )

  def __call__(
      self,
      decoder_input_tokens,
      decoder_target_tokens,
      encoder_input_image=None,
      encoder_input_tokens=None,
      retr_texts=None,
      retr_images=None,
      device_id=0,
      train=False,
      decode=False,
      max_decode_length=None,
      use_memory=False,
      use_psudo_retr=False,
      retrieve_local=False,
      no_memory=False,
      debug=False,
      frozen_base=True,
      only_encode=False,
      **args
  ):
    """Conduct online retrieval and retrieval-augmented generation.

    Args:
      decoder_input_tokens: # B×O.
      decoder_target_tokens: # B×O.
      encoder_input_image: # B×W×H×3.
      encoder_input_tokens: # B×I.
      retr_texts: # B×K×L.
      retr_images: # B×K×W×H×3.
      device_id: index of TPU device.
      train: whether using train mode.
      decode: whether in decode mode.
      max_decode_length: maximum decode token length.
      use_memory: whether to retrieve from the on-device memory.
      use_psudo_retr: whether to use pseudo-retrieved ground truth for
        guidance (only used when training and `retr_texts` is given).
      retrieve_local: whether to retrieve only within the local host (not
        implemented) or across hosts.
      no_memory: whether to skip retrieval entirely.
      debug: whether to add intermediate tensors to the output.
      frozen_base: whether to freeze the whole encoder.
      only_encode: skip decoding and only return encoded tokens.
      **args: other possible arguments (ignored).

    Returns:
      Output dictionary containing final and intermediate results.
    """
    bsz = decoder_input_tokens.shape[0]

    out_dict = {}
    base_vals, base_masks, base_query = self.encode_query(
        encoder_input_image=encoder_input_image,
        encoder_input_tokens=encoder_input_tokens,
        frozen_base=frozen_base,
    )
    base_query = self.dropout(base_query, deterministic=not train)
    base_vals = self.dropout(base_vals, deterministic=not train)
    if debug:
      out_dict['base_query'] = base_query
      out_dict['base_masks'] = base_masks
    corpus_scores = jax.nn.softmax(self.dataset_gate(base_query), axis=-1)
    out_dict['corpus_scores'] = corpus_scores
    if no_memory:
      fused_emb, fused_mask, attn_weights_all_layers = self.fusion_encoder(
          fused_input_embs=base_vals, fused_mask=base_masks, use_dropout=train
      )
    else:
      if use_memory:
        # Retrieval itself is not trained through the query.
        detached_query = jax.lax.stop_gradient(base_query)
        if retrieve_local:
          (
              topk_scores,
              ret_memory,
              ret_data,
              local_topk_ids,
              global_topk_ids,
              global_corpus_ids,
          ) = self._dist_mips_local(
              query=detached_query,
              corpus_scores=corpus_scores,
              local_device_id=device_id,
          )
        else:
          (
              topk_scores,
              ret_memory,
              ret_data,
              local_topk_ids,
              global_topk_ids,
              global_corpus_ids,
          ) = self._dist_mips_across(
              query=detached_query,
              corpus_scores=corpus_scores,
              local_device_id=device_id,
          )
        out_dict['topk_scores'] = topk_scores

        # Encode the retrieved raw (image, text) data.
        retr_keys, retr_vals, retr_masks, _, disentangle_reg = (
            self.encode_topk_knowledge(
                bsz=bsz,
                retr_images=ret_data['image'],
                retr_texts=ret_data['text_tokens'],
                train=train,
                random_drop_image=False,
                frozen_base=frozen_base,
            )
        )

        global_corpus_scores = layers.batch_index_select(
            corpus_scores, global_corpus_ids
        )

        if debug:
          out_dict['detached_query'] = detached_query
          out_dict['global_corpus_scores'] = global_corpus_scores
          out_dict['global_corpus_ids'] = global_corpus_ids
          out_dict['local_topk_ids'] = local_topk_ids
          out_dict['global_topk_ids'] = global_topk_ids
          out_dict['retr_keys'] = retr_keys
          out_dict['retr_masks'] = retr_masks
          out_dict['base_vals'] = base_vals
          out_dict['retr_vals'] = retr_vals

        out_dict['retr_data'] = ret_data
        out_dict['base_norm'] = layers.l2_norm(base_vals).mean()
        out_dict['data_norm'] = layers.l2_norm(retr_vals).mean()
        out_dict['vals_norm'] = layers.l2_norm(ret_memory['values'][0]).mean()
        # Relative mismatch between retrieved-data and query value norms.
        out_dict['gap'] = jnp.abs(
            1 - jnp.divide(out_dict['data_norm'], out_dict['base_norm'])
        )

        if train and retr_texts is not None and use_psudo_retr:
          logging.info('global keys!!!')
          ground_truth_keys, ground_truth_vals, _, _, _ = self.encode_knowledge(
              retr_texts=retr_texts,
              retr_images=retr_images,
              bsz=bsz,
              train=train,
              random_drop_image=True,
              frozen_base=frozen_base,
          )
          # Keys of every device's ground truth, for an in-batch contrastive
          # similarity matrix.
          global_keys = jnp.concatenate(
              jax.lax.all_gather(
                  x=ground_truth_keys, axis_name='batch', axis=0
              ),
              axis=0,
          )
          logging.info(global_keys.shape)
          inbatch_sim = jax.lax.dot(base_query, global_keys.transpose())
          out_dict['inbatch_sim'] = inbatch_sim
          if debug:
            out_dict['global_keys'] = global_keys
            out_dict['ground_truth_keys'] = ground_truth_keys
            out_dict['ground_truth_vals'] = ground_truth_vals

          # With probability 0.02 per example, replace all retrieved keys and
          # values of that example with its ground-truth key and value.
          k = retr_keys.shape[1]
          ground_truth_keys = jnp.repeat(
              jnp.expand_dims(ground_truth_keys, axis=1), axis=1, repeats=k
          )
          ground_truth_vals = jnp.repeat(
              jnp.expand_dims(ground_truth_vals, axis=1), axis=1, repeats=k
          )
          replace_mask = jax.random.bernoulli(
              self.make_rng('dropout'), p=0.02, shape=(bsz, 1, 1)
          )
          keys_mask = jnp.broadcast_to(replace_mask, retr_keys.shape)
          retr_keys = jax.lax.select(keys_mask, ground_truth_keys, retr_keys)
          vals_mask = jnp.broadcast_to(
              jnp.expand_dims(replace_mask, axis=-1), retr_vals.shape
          )
          retr_vals = jax.lax.select(vals_mask, ground_truth_vals, retr_vals)

        logging.info('Concat memory and data!!!')
        logging.info(retr_keys.shape)
        logging.info(ret_memory['keys'].shape)
        logging.info(global_corpus_scores.shape)

        # Append retrieved memory entries after the retrieved data entries,
        # then scale every key by its candidate's dataset gate score.
        retr_keys = jnp.concatenate([retr_keys, ret_memory['keys']], axis=1)
        retr_keys = retr_keys * jnp.expand_dims(global_corpus_scores, axis=-1)
        retr_vals = jnp.concatenate([retr_vals, ret_memory['values']], axis=1)
        retr_masks = jnp.concatenate([retr_masks, ret_memory['masks']], axis=1)
      elif retr_texts is not None:
        # Externally provided retrieval, treated as a single candidate.
        retr_keys, retr_vals, retr_masks, _, disentangle_reg = (
            self.encode_topk_knowledge(
                bsz=bsz,
                retr_images=jnp.expand_dims(retr_images, axis=1),
                retr_texts=jnp.expand_dims(retr_texts, axis=1),
                train=train,
                random_drop_image=False,
            )
        )
      else:
        # No retrieval source: the input itself serves as the only candidate.
        retr_keys, retr_vals, retr_masks, _, disentangle_reg = (
            self.encode_topk_knowledge(
                bsz=bsz,
                retr_images=jnp.expand_dims(encoder_input_image, axis=1),
                retr_texts=jnp.expand_dims(encoder_input_tokens, axis=1),
                train=train,
                random_drop_image=False,
            )
        )

      fused_emb, fused_mask, retr_scores, attn_weights_all_layers = (
          self.fuse_topk_knowledge(
              base_query=base_query,
              base_vals=base_vals,
              base_masks=base_masks,
              retr_keys=retr_keys,
              retr_vals=retr_vals,
              retr_masks=retr_masks,
              train=train,
          )
      )
      out_dict['disentangle_reg'] = jnp.mean(disentangle_reg)
      out_dict['retr_scores'] = retr_scores

    out_dict['fused_emb'] = fused_emb
    out_dict['fused_mask'] = fused_mask
    logging.info('fused_emb.shape')
    logging.info(fused_emb.shape)
    out_dict['attn_weights_all_layers'] = attn_weights_all_layers

    if not only_encode:
      out_dict['predicted_logits'] = self.t5_decode(
          encoded=fused_emb,
          encoder_input_tokens=fused_mask,
          decoder_input_tokens=decoder_input_tokens,
          decoder_target_tokens=decoder_target_tokens,
          enable_dropout=train,
          decode=decode,
          max_decode_length=max_decode_length,
      )
    return out_dict
| |
|
| |
|
class KnowledgeFIDModel(base_model.BaseModel):
  """FID model with a retrieval module over a knowledge memory."""

  def __init__(
      self,
      config: Optional[ml_collections.ConfigDict],
      dataset_meta_data: Dict[str, Any],
      kb_datasets: Dict[str, dataset_utils.Dataset],
  ) -> None:
    """Builds device groups, initializes the knowledge base and flax module.

    Args:
      config: Experiment config; `config.model` supplies `retr_k` and
        `retr_data_ratio`.
      dataset_meta_data: Dataset meta data, stored as-is.
      kb_datasets: Knowledge-base datasets passed to `local_kb.initialize`.
    """
    self.config = config
    self.dataset_meta_data = dataset_meta_data
    self.retr_k = self.config.model.retr_k
    self.retr_data_ratio = self.config.model.retr_data_ratio
    n_device = jax.device_count()
    # Share of the final `retr_k` candidates drawn from raw data.
    self.data_k = int(np.ceil(self.retr_k * self.retr_data_ratio))
    device_per_axis = jax.local_device_count()
    # NOTE(review): jax.device_count() is never smaller than
    # jax.local_device_count(), so this first branch is unreachable as written.
    if n_device < device_per_axis:
      self.axis_index_groups = None
      self.across_index_groups = None
    else:
      # Rows group consecutive device indices (one row per host, assuming
      # host-major ordering); the transpose groups same-position devices.
      self.axis_index_groups = np.arange(n_device).reshape(
          [n_device // device_per_axis, device_per_axis]
      )
      self.across_index_groups = self.axis_index_groups.T.tolist()
      self.axis_index_groups = self.axis_index_groups.tolist()
    logging.info('axis_index_groups')
    logging.info(self.axis_index_groups)
    logging.info(self.across_index_groups)
    local_kb.initialize(kb_datasets=kb_datasets)
    self.flax_model = self.build_flax_model()

  def build_flax_model(self) -> nn.Module:
    """Instantiates the underlying `KnowledgeFIDModule`."""
    return KnowledgeFIDModule(
        self.config.model,
        retr_k=self.retr_k,
        data_k=self.data_k,
        axis_index_groups=self.axis_index_groups,
        across_index_groups=self.across_index_groups,
    )

  def loss_function_dict(
      self, output: constants.JTensorDict, batch: constants.JTensorDict
  ) -> Dict[str, Any]:
    """Returns the generation loss plus optional auxiliary losses.

    Args:
      output: Output of model in OrderedDict.
      batch: Batch of data that has 'decoder_target_tokens' as ground-truth.

    Returns:
      Dict with 'gen_loss' (token NLL), 'contra_loss' and 'contra_accs'
      (in-batch contrastive terms when 'inbatch_sim' is present, else 0.0) and
      'total_loss'.
    """
    model_config = self.config.model
    gen_loss = losses.nll_loss(
        targets=batch['decoder_target_tokens'],
        pred=output['predicted_logits'],
        target_masks=batch['decoder_target_tokens'] > 0,
        label_smoothing=self.config.model.get('label_smoothing'),
    )
    loss_dict = {'gen_loss': gen_loss}

    if 'inbatch_sim' in output:
      score_matrix = output['inbatch_sim']
      bsz = score_matrix.shape[0]
      # Columns hold all devices' ground-truth keys in device order, so the
      # positive for row i on this device is column i + bsz * device_index.
      labels = jnp.arange(bsz) + bsz * jax.lax.axis_index(axis_name='batch')
      contra_loss = losses.nll_loss(
          pred=score_matrix / self.config.model.get('temperature'),
          targets=labels,
      )
      loss_dict['contra_loss'] = contra_loss
      r = model_config.retrieval_ratio
      loss = gen_loss * (1 - r) + contra_loss * r
      accs = jnp.equal(jnp.argmax(score_matrix, axis=1), labels)
      loss_dict['contra_accs'] = accs
    else:
      loss_dict['contra_loss'] = 0.0
      loss_dict['contra_accs'] = 0.0
      loss = gen_loss

    if 'disentangle' in model_config and 'disentangle_reg' in output:
      loss += output['disentangle_reg'] * 1e-2
    if 'gap' in model_config and 'gap' in output:
      loss += output['gap'] * 1e-4
    loss_dict['total_loss'] = loss
    return loss_dict

  def get_metrics_fn(self, split: Optional[str] = None) -> Any:
    """Returns a callable metric function for the model.

    Args:
      split: The split for which we calculate the metrics. It should be one of
        the ['train', 'validation', 'test'].

    Returns:
      A metric function with the API `metrics_fn(outputs, batch)`.
    """
    return metrics.token_accuracy

  def get_vqa_metrics(
      self,
      logits: jnp.ndarray,
      batch: constants.JTensorDict,
      split: Optional[str] = None,
  ) -> dict[str, float]:
    """Returns the VQA Accuracy for the validation / test set.

    Args:
      logits: Output of model in shape [B, L, C].
      batch: Batch of data that has 'decoder_target' as ground-truth.
      split: The split for which we calculate the metrics. It should be one of
        the ['train', 'validation', 'test'].

    Returns:
      VQA accuracy.
    """
    return metrics.vqa_accuracy(logits, batch)

  def single_decode_step(
      self,
      decoding_state: decoding.DecodingState,
      variables: constants.PyTree,
      encoded_inputs: jnp.ndarray,
      input_masks: jnp.ndarray,
      max_decode_length: int,
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Single autoregressive decode step with caching.

    Returns:
      (logits for the current token with the length-1 axis squeezed, updated
      cache).
    """
    flat_ids = decoding_state.cur_token
    flat_cache = decoding_state.cache
    flat_logits, new_vars = self.flax_model.apply(
        {'cache': flat_cache, **variables},
        encoded=encoded_inputs,
        encoder_input_tokens=input_masks,
        decoder_input_tokens=flat_ids,
        decoder_target_tokens=flat_ids,
        decode=True,
        enable_dropout=False,
        max_decode_length=max_decode_length,
        mutable=['cache'],
        method=self.flax_model.t5_decode,
    )
    flat_logits = jnp.squeeze(flat_logits, axis=1)
    new_flat_cache = new_vars['cache']
    return flat_logits, new_flat_cache

  def apply_with_autoregressive_decoding(
      self,
      variables: constants.PyTree,
      decoder_input_tokens: jnp.ndarray,
      decoder_target_tokens: jnp.ndarray,
      encoder_input_image: Optional[jnp.ndarray] = None,
      encoder_input_tokens: Optional[jnp.ndarray] = None,
      num_decodes: int = 1,
      debug: bool = False,
      beam_search: bool = True,
      decoder_params: Optional[dict[str, Any]] = None,
      return_all_decodes: bool = False,
      use_memory=False,
      retrieve_local=False,
      **args
  ):
    """Apply inference with autoregressive decoding.

    Apply t5x autoregressive decoding with cache using either their
    beam_search or temperature_sample decoding technique.

    Args:
      variables: variables of the models.
      decoder_input_tokens: # B×O.
      decoder_target_tokens: # B×O.
      encoder_input_image: # B×W×H×3.
      encoder_input_tokens: # B×I.
      num_decodes: number of outputs generated per input for the decode search.
      debug: Whether in debug mode or not.
      beam_search: If True, do beam search. If False, do temperature sampling.
      decoder_params: Additional decoding parameters. These provide additional
        parameters to beam_search or temperature_sample (see decoder module).
      return_all_decodes: If True, return all decodes. Otherwise only return the
        top scored decoding.
      use_memory: whether use on-device memory.
      retrieve_local: whether only retrieve in local host or across hosts.
      **args: other possible arguments (ignored).

    Returns:
      Tuple (decodes, scores, retr_top_image). `decodes` and `scores` are the
      outputs of the t5x decoding routine. If `return_all_decodes` is False,
      they are reduced to the last entry along axis 1. `retr_top_image` is
      the first retrieved image per example, or None when nothing was
      retrieved (`use_memory=False`).
    """
    # First pass in decode mode only to initialize the decoder cache.
    _, model_state_with_cache = self.flax_model.apply(
        variables=variables,
        encoder_input_image=encoder_input_image,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_input_tokens,
        decoder_target_tokens=decoder_target_tokens,
        train=False,
        only_encode=False,
        decode=True,
        mutable=['cache'],
        debug=debug,
        use_memory=use_memory,
        retrieve_local=retrieve_local,
    )

    # Second pass to obtain the fused encoder outputs to decode from.
    out_dict = self.flax_model.apply(
        variables=variables,
        encoder_input_image=encoder_input_image,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_input_tokens,
        decoder_target_tokens=decoder_target_tokens,
        train=False,
        only_encode=True,
        debug=debug,
        use_memory=use_memory,
        retrieve_local=retrieve_local,
    )
    # 'retr_data' is only produced by the memory-retrieval path; indexing it
    # unconditionally raised KeyError for the default use_memory=False.
    retr_data = out_dict.get('retr_data')
    retr_top_image = (
        retr_data['image'][:, 0] if retr_data is not None else None
    )

    # Replicate encoder outputs once per beam / sample.
    beam_expand_fn = functools.partial(
        decoding.flat_batch_beam_expand, beam_size=num_decodes
    )
    encoded_inputs = jax.tree_util.tree_map(
        beam_expand_fn, out_dict['fused_emb']
    )
    encoded_masks = jax.tree_util.tree_map(
        beam_expand_fn, out_dict['fused_mask']
    )
    bsz = decoder_input_tokens.shape[0]
    max_decode_length = decoder_input_tokens.shape[-1]

    tokens_ids_to_logits = functools.partial(
        self.single_decode_step,
        variables=variables,
        encoded_inputs=encoded_inputs,
        input_masks=encoded_masks,
        max_decode_length=max_decode_length,
    )

    if decoder_params is None:
      decoder_params = {}

    # Prompt: BOS followed by zeros up to the maximum decode length.
    decoder_prompt_inputs = jnp.zeros([bsz, max_decode_length - 1])
    bos_inputs = jnp.ones([bsz, 1]) * data_utils.BOS_ID
    decoder_prompt_inputs = jnp.concatenate(
        (bos_inputs, decoder_prompt_inputs), axis=-1, dtype=jnp.int32
    )
    if beam_search:
      decodes, scores = decoding.beam_search(
          inputs=decoder_prompt_inputs,
          cache=model_state_with_cache['cache'],
          tokens_to_logits=tokens_ids_to_logits,
          eos_id=data_utils.EOS_ID,
          num_decodes=num_decodes,
          cache_offset=0,
          **decoder_params
      )
    else:
      decodes, scores = decoding.temperature_sample(
          inputs=decoder_prompt_inputs,
          cache=model_state_with_cache['cache'],
          tokens_to_logits=tokens_ids_to_logits,
          eos_id=data_utils.EOS_ID,
          num_decodes=num_decodes,
          cache_offset=0,
          initial_index=jnp.zeros([bsz], dtype=jnp.int32),
          **decoder_params
      )
    if return_all_decodes:
      return decodes, scores, retr_top_image
    else:
      return decodes[:, -1, :], scores[:, -1], retr_top_image
| |
|