fcxfcx's picture
Upload 2446 files
1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Baseline for Image to Text Models.
B = batch size
H = height
W = width
N = number of image tokens
I = Input sequence length
O = Output sequence length
d = hidden dims
C = vocabulary size
K = number of candidates
L = sequence length of retrieved document
M = sequence length of compressed tokens
"""
import functools
from typing import Any, Dict, Mapping, Optional, Tuple, List
from absl import logging
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
from scenic.dataset_lib import dataset_utils
from scenic.model_lib.base_models import base_model
from scenic.projects.knowledge_visual_language.data import data_utils
from scenic.projects.knowledge_visual_language.models import constants
from scenic.projects.knowledge_visual_language.models import fusion_in_decoder_soft
from scenic.projects.knowledge_visual_language.models import layers
from scenic.projects.knowledge_visual_language.models import local_memory
from scenic.projects.knowledge_visual_language.models import losses
from scenic.projects.knowledge_visual_language.models import metrics
from t5x import decoding
# Module-level handle to the knowledge-base object exposed by `local_memory`.
# It is initialized in `KnowledgeFIDModel.__init__` via `local_kb.initialize`.
local_kb = local_memory.kb
class KnowledgeFIDModule(fusion_in_decoder_soft.FusionInDecoderSoftModule):
  """FID model with a retrieval module over a knowledge memory.

  FID paper: https://arxiv.org/pdf/2007.01282.pdf

  Attributes:
    retr_k: number of candidates kept by the final (global) top-k.
    data_k: how many of the `retr_k` ranks are used to index the retrieved raw
      data; the remaining ranks index the retrieved memory entries.
    axis_index_groups: device groups used by the `all_gather` collectives that
      merge per-device top-k results in `_dist_mips_across`.
    across_index_groups: device groups used by the `all_to_all` collectives in
      `_dist_mips_across`; its length is used as the local device count.
  """

  retr_k: int
  data_k: int
  axis_index_groups: Optional[List[List[int]]] = None
  across_index_groups: Optional[List[List[int]]] = None

  def setup(self):
    """Creates the `memory` variables and the dataset gating layer."""
    super().setup()
    # Zero-initialized bfloat16 key matrix held in the 'memory' collection.
    self.local_keys = self.variable(
        'memory',
        'keys',
        functools.partial(jnp.zeros, dtype=jnp.bfloat16),
        (local_kb.n_data_per_shard, self.key_dim),
    )
    # Zero-initialized int16 vector; `_get_corpus_scores` looks up retrieved
    # ids in it to obtain a corpus (dataset) id per entry.
    self.local_dataset_idxs = self.variable(
        'memory',
        'idxs',
        functools.partial(jnp.zeros, dtype=jnp.int16),
        (local_kb.n_data_per_shard * local_kb.n_local_device),
    )
    # Projects the query to one logit per knowledge-base dataset.
    self.dataset_gate = nn.Dense(
        features=local_kb.n_kb_dataset, dtype=self.dtype, name='dataset_gate'
    )

  def _get_corpus_scores(self, corpus_scores, topk_ids):
    """Looks up the corpus id of each retrieved entry and its gate score.

    Args:
      corpus_scores: per-query scores over corpora, indexed by corpus id.
      topk_ids: ids of retrieved entries, used to index the stored
        `local_dataset_idxs`.

    Returns:
      A tuple `(scores, corpus_ids)` where `corpus_ids` are the stored dataset
      indices of `topk_ids` and `scores` are `corpus_scores` selected at those
      ids.
    """
    corpus_ids = jnp.take(self.local_dataset_idxs.value, topk_ids, axis=0)
    return layers.batch_index_select(corpus_scores, corpus_ids), corpus_ids

  def _dist_mips_local(
      self,
      query,
      corpus_scores,
      local_device_id,
      recall_target=0.99,
      exact=False,
  ):
    """Host-local retrieval; currently unavailable.

    Raises:
      NotImplementedError: always.
    """
    raise NotImplementedError(
        'jax.experimental.host_callback has been removed.'
    )

  def _dist_mips_across(
      self,
      query,
      corpus_scores,
      local_device_id,
      recall_target=0.99,
      exact=False,
  ):
    """Runs maximum-inner-product search over the memory across all devices.

    Must be called inside a collective context with axis name 'batch'.

    Args:
      query: query embeddings of this device.
      corpus_scores: per-query scores over corpora (see `__call__`).
      local_device_id: index of this device within its host.
      recall_target: recall target forwarded to `jax.lax.approx_max_k`.
      exact: if True use exact `jax.lax.top_k` instead of approximate search.

    Returns:
      A tuple `(global_topk_scores, ret_memory, ret_data, local_topk_ids,
      global_rank_ids, global_corpus_ids)`.
    """
    # must have n_host > retr_k
    logging.info('mips global!!!')
    logging.info(self.local_keys.value.shape)
    logging.info(local_kb.n_data)
    n_local_device = len(self.across_index_groups)
    logging.info(n_local_device)
    global_query = jax.lax.all_gather(
        x=query, axis_name='batch', axis=0, tiled=True
    )
    logging.info(global_query.shape)
    # global_query: (per_bsz * n_local_device * n_hosts) * d
    global_corpus_scores = jax.lax.all_gather(
        x=corpus_scores, axis_name='batch', axis=0, tiled=True
    )
    logging.info(global_corpus_scores.shape)
    # global_corpus_scores: (per_bsz * n_local_device * n_hosts) * n_kb
    local_scores = jax.lax.dot(global_query, self.local_keys.value.transpose())
    local_k = local_kb.k
    if exact:
      local_topk_scores, local_topk_ids = jax.lax.top_k(local_scores, k=local_k)
    else:
      local_topk_scores, local_topk_ids = jax.lax.approx_max_k(
          local_scores,
          k=local_k,
          recall_target=recall_target,
          reduction_input_size_override=local_kb.n_data,
          aggregate_to_topk=True,
      )
    # Shift shard-local ids by this device's shard offset.
    local_topk_ids_offset = (
        local_topk_ids + local_device_id * local_kb.n_data_per_shard
    )
    logging.info(local_topk_ids.shape)
    # local_topk_ids: (per_bsz * n_local_device * n_hosts) * K
    host_topk_scores = jax.lax.all_gather(
        x=local_topk_scores,
        axis_name='batch',
        axis=1,
        axis_index_groups=self.axis_index_groups,
        tiled=True,
    )
    logging.info(host_topk_scores.shape)
    # host_topk_scores: (per_bsz * n_local_device * n_hosts) * (n_hosts * K)
    host_topk_ids = jax.lax.all_gather(
        x=local_topk_ids_offset,
        axis_name='batch',
        axis=1,
        axis_index_groups=self.axis_index_groups,
        tiled=True,
    )
    # host_topk_ids: (per_bsz * n_local_device * n_hosts) * (n_hosts * K)
    host_corpus_scores, host_corpus_ids = self._get_corpus_scores(
        global_corpus_scores, host_topk_ids
    )
    # host_corpus_scores: (per_bsz * n_local_device * n_hosts) * (n_hosts * K)
    # Re-rank the gathered candidates by similarity weighted with gate score.
    host_topk_scores, host_rank_ids = jax.lax.top_k(
        host_topk_scores * host_corpus_scores, k=local_k
    )
    # host_topk_scores: (per_bsz * n_local_device * n_hosts) * K
    host_topk_ids = layers.batch_index_select(host_topk_ids, host_rank_ids)
    logging.info(host_topk_ids.shape)
    # host_topk_ids: (per_bsz * n_local_device * n_hosts) * K
    # Keep only the slice of queries this local device is responsible for.
    host_topk_ids = jnp.reshape(host_topk_ids, (-1, n_local_device, local_k))
    host_topk_ids = host_topk_ids[:, local_device_id]
    logging.info(host_topk_ids.shape)
    # host_topk_ids: (per_bsz * n_hosts)
    host_topk_scores = jnp.reshape(
        host_topk_scores, (-1, n_local_device, local_k)
    )
    host_topk_scores = host_topk_scores[:, local_device_id]
    logging.info('host_topk_scores')
    logging.info(host_topk_scores.shape)

    def _retrieve_on_host(ids):
      # Recent JAX versions hand `jax.Array`s to callbacks; convert so the
      # host-side lookup receives a numpy array as it did with host_callback.
      return local_memory.retrieve_top_memory(np.asarray(ids))

    # `jax.experimental.host_callback` no longer exists; `jax.pure_callback` is
    # its replacement for host functions that return values.
    # NOTE(review): assumes the lookup has no side effects during a step.
    ret_memory, ret_data = jax.pure_callback(
        _retrieve_on_host, local_kb.ret_top_specs, host_topk_ids
    )
    global_topk_scores = jax.lax.all_to_all(
        x=host_topk_scores,
        axis_name='batch',
        split_axis=0,
        concat_axis=1,
        axis_index_groups=self.across_index_groups,
        tiled=True,
    )
    logging.info('global_topk_scores')
    logging.info(global_topk_scores.shape)
    # global_topk_scores: per_bsz * (n_device * k)
    global_topk_scores, global_rank_ids = jax.lax.top_k(
        global_topk_scores, k=self.retr_k
    )
    logging.info(global_topk_scores.shape)
    # global_topk_scores: per_bsz * retr_k
    # The first `data_k` ranks select raw data, the rest select memory entries.
    global_data_ids = global_rank_ids[:, : int(self.data_k)]
    global_memory_ids = global_rank_ids[:, int(self.data_k) :]

    def _gather_val(local_ret_vals, top_ids):
      """Exchanges retrieved values across hosts and selects `top_ids`."""
      logging.info(local_ret_vals.shape)
      global_ret_vals = jax.lax.all_to_all(
          x=local_ret_vals,
          axis_name='batch',
          split_axis=0,
          concat_axis=1,
          axis_index_groups=self.across_index_groups,
          tiled=True,
      )
      logging.info(global_ret_vals.shape)
      # global_ret_vals: per_bsz * (n_device * k) * dshape
      global_ret_vals = layers.batch_index_select(global_ret_vals, top_ids)
      logging.info(global_ret_vals.shape)
      # global_ret_vals: per_bsz * retr_k * dshape
      return global_ret_vals

    logging.info('_gather_val!!!')
    ret_memory = jax.tree_util.tree_map(
        lambda local_val: _gather_val(local_val, global_memory_ids), ret_memory
    )
    ret_data = jax.tree_util.tree_map(
        lambda local_val: _gather_val(local_val, global_data_ids), ret_data
    )
    # Retrieved memory entries are treated as fully valid (all-True mask).
    ret_memory['masks'] = jnp.ones(ret_memory['values'].shape[:3]).astype(bool)
    for k in ret_memory:
      logging.info(k)
      logging.info(ret_memory[k].shape)
      logging.info(ret_memory[k].dtype)
    host_corpus_ids = layers.batch_index_select(host_corpus_ids, host_rank_ids)
    host_corpus_ids = jnp.reshape(
        host_corpus_ids, (-1, n_local_device, local_k)
    )[:, local_device_id]
    logging.info('corpus_ids')
    logging.info(host_corpus_ids.shape)
    # host_corpus_ids: (per_bsz * n_local_device, k)
    global_corpus_ids = jax.lax.all_to_all(
        x=host_corpus_ids,
        axis_name='batch',
        split_axis=0,
        concat_axis=1,
        axis_index_groups=self.across_index_groups,
        tiled=True,
    )
    logging.info(global_corpus_ids.shape)
    # global_corpus_ids: (per_bsz, n_local_device)
    global_corpus_ids = layers.batch_index_select(
        global_corpus_ids, global_rank_ids
    )
    logging.info(global_corpus_ids.shape)
    # global_corpus_ids: (per_bsz, 10)
    return (
        global_topk_scores,
        ret_memory,
        ret_data,
        local_topk_ids,
        global_rank_ids,
        global_corpus_ids,
    )

  def t5_decode(
      self,
      encoded,
      encoder_input_tokens: jnp.ndarray,  # Only needed for masks.
      decoder_input_tokens: jnp.ndarray,
      decoder_target_tokens: jnp.ndarray,
      enable_dropout: bool = True,
      decode: bool = False,
      max_decode_length: Optional[int] = None,
  ):
    """Wraps the decoder call (no packing) to enable autoregressive decoding."""
    # Without this wrapper flax.model.apply does not know self._t5_decoder yet
    # when doing a single (autoregressive) decode step.
    return self.out_decoder(
        encoded=encoded,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_input_tokens,
        decoder_target_tokens=decoder_target_tokens,
        enable_dropout=enable_dropout,
        decode=decode,
        max_decode_length=max_decode_length,
    )

  def __call__(
      self,
      decoder_input_tokens,
      decoder_target_tokens,
      encoder_input_image=None,
      encoder_input_tokens=None,
      retr_texts=None,
      retr_images=None,
      device_id=0,
      train=False,
      decode=False,
      max_decode_length=None,
      use_memory=False,
      use_psudo_retr=False,
      retrieve_local=False,
      no_memory=False,
      debug=False,
      frozen_base=True,
      only_encode=False,
      **args
  ):
    """Conducts online retrieval and retrieval-augmented generation.

    Args:
      decoder_input_tokens: # B×O.
      decoder_target_tokens: # B×O.
      encoder_input_image: # B×W×H×3.
      encoder_input_tokens: # B×I.
      retr_texts: # B×K×L.
      retr_images: # B×K×W×H×3.
      device_id: index of TPU device.
      train: whether using train mode.
      decode: whether in decode mode.
      max_decode_length: maximum decode token length.
      use_memory: whether to use on-device memory.
      use_psudo_retr: whether to use pseudo-retrieved ground truth for
        guidance.
      retrieve_local: whether to retrieve only in the local host instead of
        across hosts.
      no_memory: whether to skip retrieval entirely.
      debug: whether to use debug mode (adds intermediate tensors to the
        output).
      frozen_base: whether to freeze the whole encoder.
      only_encode: skip decoding and only return encoded tokens.
      **args: other possible arguments (unused here).

    Returns:
      Output dictionary containing final and intermediate results.
    """
    bsz = decoder_input_tokens.shape[0]
    out_dict = {}
    base_vals, base_masks, base_query = self.encode_query(
        encoder_input_image=encoder_input_image,
        encoder_input_tokens=encoder_input_tokens,
        frozen_base=frozen_base,
    )
    base_query = self.dropout(base_query, deterministic=not train)
    base_vals = self.dropout(base_vals, deterministic=not train)
    if debug:
      out_dict['base_query'] = base_query
      out_dict['base_masks'] = base_masks
    # Soft gate over knowledge-base datasets, computed from the query.
    corpus_scores = jax.nn.softmax(self.dataset_gate(base_query), axis=-1)
    out_dict['corpus_scores'] = corpus_scores
    if no_memory:
      fused_emb, fused_mask, attn_weights_all_layers = self.fusion_encoder(
          fused_input_embs=base_vals, fused_mask=base_masks, use_dropout=train
      )  # B×(I+N)×d
    else:
      if use_memory:
        # Retrieval itself is not differentiated through.
        detached_query = jax.lax.stop_gradient(base_query)
        mips_fn = (
            self._dist_mips_local if retrieve_local else self._dist_mips_across
        )
        (
            topk_scores,
            ret_memory,
            ret_data,
            local_topk_ids,
            global_topk_ids,
            global_corpus_ids,
        ) = mips_fn(
            query=detached_query,
            corpus_scores=corpus_scores,
            local_device_id=device_id,
        )
        out_dict['topk_scores'] = topk_scores
        # encode the retrieved data
        retr_keys, retr_vals, retr_masks, _, disentangle_reg = (
            self.encode_topk_knowledge(
                bsz=bsz,
                retr_images=ret_data['image'],
                retr_texts=ret_data['text_tokens'],
                train=train,
                random_drop_image=False,
                frozen_base=frozen_base,
            )
        )
        global_corpus_scores = layers.batch_index_select(
            corpus_scores, global_corpus_ids
        )
        if debug:
          out_dict['detached_query'] = detached_query
          out_dict['global_corpus_scores'] = global_corpus_scores
          out_dict['global_corpus_ids'] = global_corpus_ids
          out_dict['local_topk_ids'] = local_topk_ids
          out_dict['global_topk_ids'] = global_topk_ids
          out_dict['retr_keys'] = retr_keys
          out_dict['retr_masks'] = retr_masks
          out_dict['base_vals'] = base_vals
          out_dict['retr_vals'] = retr_vals
        # NOTE(review): the original indentation was lost; `retr_data` and the
        # norm statistics are exposed unconditionally because the decoding path
        # reads `retr_data` and the loss reads `gap` — confirm.
        out_dict['retr_data'] = ret_data
        out_dict['base_norm'] = layers.l2_norm(base_vals).mean()
        out_dict['data_norm'] = layers.l2_norm(retr_vals).mean()
        out_dict['vals_norm'] = layers.l2_norm(ret_memory['values'][0]).mean()
        out_dict['gap'] = jnp.abs(
            1 - jnp.divide(out_dict['data_norm'], out_dict['base_norm'])
        )
        if train and retr_texts is not None and use_psudo_retr:
          logging.info('global keys!!!')
          ground_truth_keys, ground_truth_vals, _, _, _ = self.encode_knowledge(
              retr_texts=retr_texts,
              retr_images=retr_images,
              bsz=bsz,
              train=train,
              random_drop_image=True,
              frozen_base=frozen_base,
          )
          global_keys = jnp.concatenate(
              jax.lax.all_gather(
                  x=ground_truth_keys, axis_name='batch', axis=0
              ),
              axis=0,
          )
          logging.info(global_keys.shape)
          # Similarity of each query to every ground-truth key in the global
          # batch; consumed by the contrastive loss.
          inbatch_sim = jax.lax.dot(base_query, global_keys.transpose())
          out_dict['inbatch_sim'] = inbatch_sim
          if debug:
            out_dict['global_keys'] = global_keys
            out_dict['ground_truth_keys'] = ground_truth_keys
            out_dict['ground_truth_vals'] = ground_truth_vals
          # replace retrieved knowledge with ground-truth ones for
          # stabilization.
          k = retr_keys.shape[1]
          ground_truth_keys = jnp.repeat(
              jnp.expand_dims(ground_truth_keys, axis=1), axis=1, repeats=k
          )
          ground_truth_vals = jnp.repeat(
              jnp.expand_dims(ground_truth_vals, axis=1), axis=1, repeats=k
          )
          # Per-example Bernoulli(0.02) switch for the replacement.
          replace_mask = jax.random.bernoulli(
              self.make_rng('dropout'), p=0.02, shape=(bsz, 1, 1)
          )
          keys_mask = jnp.broadcast_to(replace_mask, retr_keys.shape)
          retr_keys = jax.lax.select(keys_mask, ground_truth_keys, retr_keys)
          vals_mask = jnp.broadcast_to(
              jnp.expand_dims(replace_mask, axis=-1), retr_vals.shape
          )
          retr_vals = jax.lax.select(vals_mask, ground_truth_vals, retr_vals)
        logging.info('Concat memory and data!!!')
        logging.info(retr_keys.shape)
        logging.info(ret_memory['keys'].shape)
        logging.info(global_corpus_scores.shape)
        # concat retrieved memory (90%) with re-encoded ones (10%)
        retr_keys = jnp.concatenate([retr_keys, ret_memory['keys']], axis=1)
        retr_keys = retr_keys * jnp.expand_dims(global_corpus_scores, axis=-1)
        retr_vals = jnp.concatenate([retr_vals, ret_memory['values']], axis=1)
        retr_masks = jnp.concatenate([retr_masks, ret_memory['masks']], axis=1)
      elif retr_texts is not None:
        # No memory: encode the provided retrieval as a single candidate.
        retr_keys, retr_vals, retr_masks, _, disentangle_reg = (
            self.encode_topk_knowledge(
                bsz=bsz,
                retr_images=jnp.expand_dims(retr_images, axis=1),
                retr_texts=jnp.expand_dims(retr_texts, axis=1),
                train=train,
                random_drop_image=False,
            )
        )
      else:
        # Nothing to retrieve from: use the encoder input itself as the single
        # knowledge candidate.
        retr_keys, retr_vals, retr_masks, _, disentangle_reg = (
            self.encode_topk_knowledge(
                bsz=bsz,
                retr_images=jnp.expand_dims(encoder_input_image, axis=1),
                retr_texts=jnp.expand_dims(encoder_input_tokens, axis=1),
                train=train,
                random_drop_image=False,
            )
        )
      fused_emb, fused_mask, retr_scores, attn_weights_all_layers = (
          self.fuse_topk_knowledge(
              base_query=base_query,
              base_vals=base_vals,
              base_masks=base_masks,
              retr_keys=retr_keys,
              retr_vals=retr_vals,
              retr_masks=retr_masks,
              train=train,
          )
      )  # B×(I+N+M*K)×d
      out_dict['disentangle_reg'] = jnp.mean(disentangle_reg)
      out_dict['retr_scores'] = retr_scores
    out_dict['fused_emb'] = fused_emb
    out_dict['fused_mask'] = fused_mask
    logging.info('fused_emb.shape')
    logging.info(fused_emb.shape)
    out_dict['attn_weights_all_layers'] = attn_weights_all_layers
    if not only_encode:
      # decode: generate decoding results.
      out_dict['predicted_logits'] = self.t5_decode(
          encoded=fused_emb,
          encoder_input_tokens=fused_mask,
          decoder_input_tokens=decoder_input_tokens,
          decoder_target_tokens=decoder_target_tokens,
          enable_dropout=train,
          decode=decode,
          max_decode_length=max_decode_length,
      )
    return out_dict
class KnowledgeFIDModel(base_model.BaseModel):
  """FID model with a retrieval module over a knowledge memory."""

  def __init__(
      self,
      config: Optional[ml_collections.ConfigDict],
      dataset_meta_data: Dict[str, Any],
      kb_datasets: Dict[str, dataset_utils.Dataset],
  ) -> None:
    """Sets up device groups, initializes the knowledge base, builds the model.

    Args:
      config: experiment configuration; `config.model.retr_k` and
        `config.model.retr_data_ratio` are read here.
      dataset_meta_data: dataset meta data, stored on the instance.
      kb_datasets: knowledge-base datasets passed to `local_kb.initialize`.
    """
    self.config = config
    self.dataset_meta_data = dataset_meta_data
    self.retr_k = self.config.model.retr_k
    self.retr_data_ratio = self.config.model.retr_data_ratio
    n_device = jax.device_count()
    # Number of retrieved entries taken from raw data (rounded up).
    self.data_k = int(np.ceil(self.retr_k * self.retr_data_ratio))
    device_per_axis = jax.local_device_count()
    if n_device < device_per_axis:
      self.axis_index_groups = None
      self.across_index_groups = None
    else:
      # Rows group consecutive device indices in chunks of `device_per_axis`;
      # the transpose groups devices that share the same position in a row.
      self.axis_index_groups = np.arange(n_device).reshape(
          [n_device // device_per_axis, device_per_axis]
      )
      self.across_index_groups = self.axis_index_groups.T.tolist()
      self.axis_index_groups = self.axis_index_groups.tolist()
    logging.info('axis_index_groups')
    logging.info(self.axis_index_groups)
    logging.info(self.across_index_groups)
    # The flax module reads `local_kb` sizes in `setup`, so initialize first.
    local_kb.initialize(kb_datasets=kb_datasets)
    self.flax_model = self.build_flax_model()

  def build_flax_model(self) -> nn.Module:
    """Builds the `KnowledgeFIDModule` from the stored configuration."""
    return KnowledgeFIDModule(
        self.config.model,
        retr_k=self.retr_k,
        data_k=self.data_k,
        axis_index_groups=self.axis_index_groups,
        across_index_groups=self.across_index_groups,
    )

  def loss_function_dict(
      self, output: constants.JTensorDict, batch: constants.JTensorDict
  ) -> Dict[str, Any]:
    """Computes the generation loss plus optional auxiliary losses.

    Must run inside a collective context with axis name 'batch' when
    'inbatch_sim' is present in `output`.

    Args:
      output: Output dictionary of the model.
      batch: Batch of data with 'decoder_target_tokens' as ground truth.

    Returns:
      Dictionary with keys 'gen_loss', 'contra_loss', 'contra_accs' and
      'total_loss'.
    """
    model_config = self.config.model
    gen_loss = losses.nll_loss(
        targets=batch['decoder_target_tokens'],
        pred=output['predicted_logits'],
        target_masks=batch['decoder_target_tokens'] > 0,
        label_smoothing=model_config.get('label_smoothing'),
    )
    loss_dict = {'gen_loss': gen_loss}
    if 'inbatch_sim' in output:
      score_matrix = output['inbatch_sim']
      bsz = score_matrix.shape[0]
      # Columns span the gathered global batch, so the positive for local row
      # i sits at this device's offset plus i.
      labels = jnp.arange(bsz) + bsz * jax.lax.axis_index(axis_name='batch')
      contra_loss = losses.nll_loss(
          pred=score_matrix / model_config.get('temperature'),
          targets=labels,
      )
      loss_dict['contra_loss'] = contra_loss
      r = model_config.retrieval_ratio
      loss = gen_loss * (1 - r) + contra_loss * r
      accs = jnp.equal(jnp.argmax(score_matrix, axis=1), labels)
      loss_dict['contra_accs'] = accs
    else:
      loss_dict['contra_loss'] = 0.0
      loss_dict['contra_accs'] = 0.0
      loss = gen_loss
    # Optional regularizers, enabled by the presence of the config key.
    if 'disentangle' in model_config and 'disentangle_reg' in output:
      loss += output['disentangle_reg'] * 1e-2
    if 'gap' in model_config and 'gap' in output:
      loss += output['gap'] * 1e-4
    loss_dict['total_loss'] = loss
    return loss_dict

  def get_metrics_fn(self, split: Optional[str] = None) -> Any:
    """Returns a callable metric function for the model.

    Args:
      split: The split for which we calculate the metrics. It should be one of
        the ['train', 'validation', 'test']. Unused here.

    Returns:
      A metric function with the following API: `metrics_fn(outputs, batch)`.
    """
    return metrics.token_accuracy

  def get_vqa_metrics(
      self,
      logits: jnp.ndarray,
      batch: constants.JTensorDict,
      split: Optional[str] = None,
  ) -> dict[str, float]:
    """Returns the VQA accuracy for the validation / test set.

    Args:
      logits: Output of model in shape [B, L, C].
      batch: Batch of data that holds the ground truth.
      split: The split for which we calculate the metrics. It should be one of
        the ['train', 'validation', 'test']. Unused here.

    Returns:
      The result of `metrics.vqa_accuracy(logits, batch)`.
    """
    return metrics.vqa_accuracy(logits, batch)

  def single_decode_step(
      self,
      decoding_state: decoding.DecodingState,
      variables: constants.PyTree,
      encoded_inputs: jnp.ndarray,
      input_masks: jnp.ndarray,
      max_decode_length: int,
  ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]:
    """Single autoregressive decode step with caching.

    Args:
      decoding_state: current decoding state; its `cur_token` and `cache` are
        used.
      variables: model variables.
      encoded_inputs: encoder outputs consumed by the decoder.
      input_masks: masks for `encoded_inputs`.
      max_decode_length: maximum decode token length.

    Returns:
      A tuple of the logits for the current step and the updated cache.
    """
    flat_ids = decoding_state.cur_token
    flat_cache = decoding_state.cache
    # flat_ids: [batch * beam, seq_len=1]
    # cache is expanded inside beam_search to become flat_cache
    # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len]
    # flat_logits: [batch * beam, seq_len=1, vocab]
    flat_logits, new_vars = self.flax_model.apply(
        {'cache': flat_cache, **variables},
        encoded=encoded_inputs,
        encoder_input_tokens=input_masks,
        decoder_input_tokens=flat_ids,
        decoder_target_tokens=flat_ids,
        decode=True,
        enable_dropout=False,
        max_decode_length=max_decode_length,
        mutable=['cache'],
        method=self.flax_model.t5_decode,
    )
    # Remove sequence length dimension since it's always 1 during decoding.
    flat_logits = jnp.squeeze(flat_logits, axis=1)
    new_flat_cache = new_vars['cache']
    return flat_logits, new_flat_cache

  def apply_with_autoregressive_decoding(
      self,
      variables: constants.PyTree,
      decoder_input_tokens: jnp.ndarray,
      decoder_target_tokens: jnp.ndarray,
      encoder_input_image: Optional[jnp.ndarray] = None,
      encoder_input_tokens: Optional[jnp.ndarray] = None,
      num_decodes: int = 1,
      debug: bool = False,
      beam_search: bool = True,
      decoder_params: Optional[dict[str, Any]] = None,
      return_all_decodes: bool = False,
      use_memory=False,
      retrieve_local=False,
      **args
  ):
    """Applies inference with autoregressive decoding.

    Apply t5x autoregressive decoding with cache using either their
    beam_search or temperature_sample decoding technique.

    Args:
      variables: variables of the models.
      decoder_input_tokens: # B×O.
      decoder_target_tokens: # B×O.
      encoder_input_image: # B×W×H×3.
      encoder_input_tokens: # B×I.
      num_decodes: number of outputs generated per input for the decode search.
      debug: Whether in debug mode or not.
      beam_search: If True, do beam search. If False, do temperature sampling.
      decoder_params: Additional decoding parameters. These provide additional
        parameters to beam_search or temperature_sample (see decoder module).
      return_all_decodes: If True, return all decodes. Otherwise only return the
        top scored decoding.
      use_memory: whether to use on-device memory.
      retrieve_local: whether to retrieve only in the local host instead of
        across hosts.
      **args: other possible arguments (unused here).

    Returns:
      A tuple `(decodes, scores, retr_top_image)`. `retr_top_image` is the
      first retrieved image per example, or None when the model output has no
      'retr_data' entry (e.g. when no memory retrieval was performed).
    """
    # Prepare zeroed-out autoregressive cache.
    _, model_state_with_cache = self.flax_model.apply(
        variables=variables,
        encoder_input_image=encoder_input_image,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_input_tokens,
        decoder_target_tokens=decoder_target_tokens,
        train=False,
        only_encode=False,
        decode=True,
        mutable=['cache'],
        debug=debug,
        use_memory=use_memory,
        retrieve_local=retrieve_local,
    )
    # Call model to get the features consumed by the decoder. Skip the
    # the decoding part itself.
    out_dict = self.flax_model.apply(
        variables=variables,
        encoder_input_image=encoder_input_image,
        encoder_input_tokens=encoder_input_tokens,
        decoder_input_tokens=decoder_input_tokens,
        decoder_target_tokens=decoder_target_tokens,
        train=False,
        only_encode=True,
        debug=debug,
        use_memory=use_memory,
        retrieve_local=retrieve_local,
    )
    # 'retr_data' is only produced when the model actually retrieved from
    # memory; indexing it unconditionally raised KeyError otherwise.
    retr_data = out_dict.get('retr_data')
    retr_top_image = None if retr_data is None else retr_data['image'][:, 0]
    # Prepare transformer fast-decoder call for beam search: for beam search, we
    # need to set up our decoder model to handle a batch size equal to
    # batch_size * num_decodes, where each batch item's data is expanded
    # in-place rather than tiled.
    # i.e. if we denote each batch element subtensor as el[n]:
    # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]
    # [batch * num_decodes, input_len, emb_dim]
    beam_expand_fn = functools.partial(
        decoding.flat_batch_beam_expand, beam_size=num_decodes
    )
    encoded_inputs = jax.tree_util.tree_map(
        beam_expand_fn, out_dict['fused_emb']
    )
    encoded_masks = jax.tree_util.tree_map(
        beam_expand_fn, out_dict['fused_mask']
    )
    bsz = decoder_input_tokens.shape[0]
    max_decode_length = decoder_input_tokens.shape[-1]
    # Define the token2logit function for a single decoding step.
    tokens_ids_to_logits = functools.partial(
        self.single_decode_step,
        variables=variables,
        encoded_inputs=encoded_inputs,
        input_masks=encoded_masks,
        max_decode_length=max_decode_length,
    )
    if decoder_params is None:
      decoder_params = {}
    # For beam search, `decoder_prompt_inputs` is only used to obtain batch size
    # and max decode length information. For temperature sampling,
    # `decoder_prompt_inputs` will be filled with the sampled ids.
    decoder_prompt_inputs = jnp.zeros([bsz, max_decode_length - 1])
    bos_inputs = jnp.ones([bsz, 1]) * data_utils.BOS_ID
    decoder_prompt_inputs = jnp.concatenate(
        (bos_inputs, decoder_prompt_inputs), axis=-1, dtype=jnp.int32
    )
    if beam_search:
      decodes, scores = decoding.beam_search(
          inputs=decoder_prompt_inputs,
          cache=model_state_with_cache['cache'],
          tokens_to_logits=tokens_ids_to_logits,
          eos_id=data_utils.EOS_ID,
          num_decodes=num_decodes,
          cache_offset=0,
          **decoder_params
      )
    else:
      decodes, scores = decoding.temperature_sample(
          inputs=decoder_prompt_inputs,
          cache=model_state_with_cache['cache'],
          tokens_to_logits=tokens_ids_to_logits,
          eos_id=data_utils.EOS_ID,
          num_decodes=num_decodes,
          cache_offset=0,
          initial_index=jnp.zeros([bsz], dtype=jnp.int32),
          **decoder_params
      )
    if return_all_decodes:
      return decodes, scores, retr_top_image
    # The last entry along the decode axis is returned as the top decode.
    return decodes[:, -1, :], scores[:, -1], retr_top_image