| | """Baseline for Image to Text Models. |
| | |
| | B = batch size |
| | H = height |
| | W = width |
| | N = number of image tokens |
| | I = Input sequence length |
| | O = Ouput sequence length |
| | d = hidden dims |
| | C = number of vocabulary |
| | K = number of candidate |
| | L = sequence length of retrieved document |
| | M = sequence length of compressed tokens |
| | """ |
from typing import Any, Dict, Optional

import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from scenic.model_lib.base_models import base_model
from scenic.projects.knowledge_visual_language.models import constants
from scenic.projects.knowledge_visual_language.models import layers
from scenic.projects.knowledge_visual_language.models import losses
from scenic.projects.knowledge_visual_language.models import metrics
from scenic.projects.knowledge_visual_language.models import vit as vit_model
from scenic.projects.t5 import layers as t5_model
from scenic.projects.t5 import model as t5_pretrained


class VisionLanguageModule(nn.Module):
  """Basic ViT + T5 vision language model."""

  config: ml_collections.ConfigDict

  def setup(self):
    t5_config = t5_pretrained.CONFIGS[self.config.t5_name]
    self.t5_config = t5_config
    t5_config['dropout_rate'] = self.config.dropout_rate
    self.ndim = t5_config['emb_dim']
    self.dropout_rate = t5_config['dropout_rate']
    self.key_dim = self.config.key_dim
    self.dtype = t5_config['dtype']

    self.shared_token_embedder = t5_model.t5_layers.Embed(
        num_embeddings=t5_config['vocab_size'],
        features=self.ndim,
        dtype=self.dtype,
        attend_dtype=self.dtype,
        embedding_init=nn.initializers.normal(stddev=1.0),
        one_hot=True,
        name='shared_token_embedder',
    )

    self.out_decoder = t5_model.T5Decoder(
        **t5_config,
        shared_embedding=self.shared_token_embedder,
        name='out_decoder',
    )

    self.text_encoder = layers.LowerT5Encoder(
        **t5_config,
        num_fusion_layers=self.config.num_fusion_layers,
        shared_embedding=self.shared_token_embedder,
        name='text_encoder',
    )

    self.fusion_encoder = layers.FusedT5Encoder(
        **t5_config,
        num_fusion_layers=self.config.num_fusion_layers,
        name='fusion_encoder',
    )

    self.img_encoder = vit_model.Model(
        num_classes=self.ndim,
        dropout=self.dropout_rate,
        name='img_encoder',
        variant=self.config.vit_name,
        head_zeroinit=False,
        dtype=jnp.bfloat16,
        num_frozen_layers=self.config.get('vit_num_frozen_layers', -1),
        pool_type='gap',
    )
    self.dropout = nn.Dropout(rate=0.2)

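  # Shapes follow the module docstring: text tokens come in as B×I and images
  # as B×W×H×3; a missing modality is replaced with zero embeddings and masks
  # so downstream concatenation keeps a fixed layout. During training with
  # `random_drop_image=True`, each example's image embedding is dropped with
  # probability 0.2 (modality dropout).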
  def get_base_encoded(
      self,
      image=None,
      text_tokens=None,
      train=False,
      random_drop_image=False,
      bsz=None,
      frozen_base=True,
  ):
    if bsz is None:
      if text_tokens is not None:
        bsz = len(text_tokens)
      elif image is not None:
        bsz = len(image)
    if text_tokens is not None:
      text_query, text_mask = self.text_encoder(
          encoder_input_tokens=text_tokens,
          use_dropout=train,
          frozen_base=frozen_base,
      )
    else:
      text_query = jnp.zeros([bsz, 1, self.ndim], dtype=self.dtype)
      text_mask = jnp.zeros([bsz, 1], dtype=self.dtype)
    if image is not None:
      img_query, img_emb = self.encode_image(image, train=train)
      n_img_tokens = img_query.shape[1]
    else:
      n_img_tokens = 1
      img_query = jnp.zeros([bsz, n_img_tokens, self.ndim], dtype=self.dtype)
      img_emb = jnp.zeros([bsz, self.ndim], dtype=self.dtype)
    if train and random_drop_image:
      image_mask = jax.random.bernoulli(
          self.make_rng('dropout'), p=1 - 0.2, shape=(bsz, 1)
      ).astype(self.dtype)
      img_emb = img_emb * image_mask
      image_mask = jnp.repeat(image_mask, repeats=n_img_tokens, axis=1)
    else:
      image_mask = jnp.ones([bsz, n_img_tokens], dtype=self.dtype)
    base_masks = jnp.concatenate([text_mask, image_mask], axis=1)
    return [text_query, img_query], base_masks, img_emb


class FusionInDecoderSoftModule(VisionLanguageModule):
  """Modification of the FID model (https://arxiv.org/pdf/2007.01282.pdf).

  Takes continuous embeddings of each retrieved document at the middle
  fusion layer, instead of the whole token sequence at the input.
  """

  config: ml_collections.ConfigDict

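  # On top of the base encoders, this module adds: (i) a perceiver that
  # compresses each retrieved document into `n_compressed_tokens` value
  # tokens, and (ii) query/key transformer heads that share one `key_dim`
  # output projection (`compress_head`) for retrieval scoring.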
  def setup(self):
    super().setup()
    self.n_compressed_tokens = self.config.n_compressed_tokens

    self.value_perceiver = layers.PerceiverEncoder(
        **self.t5_config,
        num_fusion_layers=self.config.num_fusion_layers,
        perceiver_output_dim=self.n_compressed_tokens,
        name='value_perceiver',
    )

    self.compress_head = nn.Dense(
        features=self.key_dim, dtype=self.dtype, name='head_out', use_bias=False
    )
    self.query_head = layers.TransformerHead(
        **self.t5_config,
        num_head_layers=self.config.num_fusion_layers,
        out_head=self.compress_head,
        key_dim=self.key_dim,
        name='query_head',
    )
    self.key_head = layers.TransformerHead(
        **self.t5_config,
        num_head_layers=self.config.num_fusion_layers,
        out_head=self.compress_head,
        key_dim=self.key_dim,
        name='key_head',
    )
    self.att_transform = layers.AffineTransform()

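  # Keeps the first `n_compressed_tokens` positions as-is and average-pools
  # the rest with stride `self.n_stride` (expected to be provided via the
  # config). The mask is min-pooled via a negated max-pool, so a pooled
  # position stays valid only if every token in its window was valid.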
  def compress_and_pool_key(self, h, mask):
    window_size = self.n_stride
    pooled_tokens = nn.avg_pool(
        h[:, self.n_compressed_tokens:, :],
        window_shape=(window_size,),
        strides=(self.n_stride,),
    )
    pooled_tokens = jnp.concatenate(
        (h[:, :self.n_compressed_tokens, :], pooled_tokens), axis=1
    )
    pooled_mask = jnp.squeeze(
        -nn.max_pool(
            jnp.expand_dims(-mask[:, self.n_compressed_tokens:], axis=-1),
            window_shape=(window_size,),
            strides=(self.n_stride,),
        ),
        axis=-1,
    )
    pooled_mask = jnp.concatenate(
        (mask[:, :self.n_compressed_tokens], pooled_mask), axis=1
    )
    return pooled_tokens, pooled_mask

  def compress_key(self, h, mask):
    pooled_tokens = h[:, :self.n_compressed_tokens, :]
    pooled_mask = mask[:, :self.n_compressed_tokens]
    return pooled_tokens, pooled_mask

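  # Encodes each retrieved (text, image) document into one `key_dim` key
  # vector for retrieval scoring plus M compressed value tokens from the
  # perceiver.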
  def encode_knowledge(
      self,
      retr_texts,
      retr_images=None,
      bsz=None,
      train=False,
      random_drop_image=False,
      frozen_base=True,
  ):
    retr_tokens, retr_masks, retr_img_emb = self.get_base_encoded(
        bsz=bsz,
        image=retr_images,
        text_tokens=retr_texts,
        train=train,
        random_drop_image=random_drop_image,
        frozen_base=frozen_base,
    )
    retr_tokens = jnp.concatenate(retr_tokens, axis=1)
    retr_keys = self.key_head(
        encoded_emb=retr_tokens, encoder_mask=retr_masks, use_dropout=train
    )
    compressed_val, compressed_mask, disentangle_reg = self.value_perceiver(
        encoded=retr_tokens, encoded_mask=retr_masks, use_dropout=train
    )
    return (
        retr_keys,
        compressed_val,
        compressed_mask,
        retr_img_emb,
        disentangle_reg,
    )

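  # Encodes the query (image, text) pair and projects it to a single
  # `key_dim` query vector used to score retrieved documents.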
  def encode_query(
      self,
      encoder_input_image,
      encoder_input_tokens,
      train=False,
      frozen_base=True,
  ):
    bsz = encoder_input_image.shape[0]
    base_vals, base_masks, _ = self.get_base_encoded(
        bsz=bsz,
        image=encoder_input_image,
        text_tokens=encoder_input_tokens,
        train=train,
        frozen_base=frozen_base,
    )
    base_vals = self.dropout(
        jnp.concatenate(base_vals, axis=1), deterministic=not train
    )
    base_query = self.query_head(
        encoded_emb=base_vals, encoder_mask=base_masks, use_dropout=train
    )
    return base_vals, base_masks, base_query

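  # Flattens the top-K retrieved documents into the batch dimension (B*K),
  # encodes them in one pass, then restores the per-example [B, K, ...]
  # layout.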
  def encode_topk_knowledge(
      self,
      bsz,
      retr_texts,
      retr_images=None,
      train=False,
      random_drop_image=False,
      frozen_base=True,
  ):
    k, l = retr_texts.shape[1], retr_texts.shape[2]
    retr_texts = jnp.reshape(retr_texts, (bsz * k, l))
    if retr_images is not None:
      image_shape = (bsz * k,) + retr_images.shape[2:]
      retr_images = jnp.reshape(retr_images, image_shape)
    (
        retr_keys,
        compressed_val,
        compressed_mask,
        retr_img_emb,
        disentangle_reg,
    ) = self.encode_knowledge(
        retr_texts,
        retr_images,
        bsz=bsz * k,
        train=train,
        random_drop_image=random_drop_image,
        frozen_base=frozen_base,
    )
    n_tokens = compressed_val.shape[1]
    retr_keys = jnp.reshape(retr_keys, (bsz, k, self.key_dim))
    compressed_val = jnp.reshape(compressed_val, (bsz, k, n_tokens, self.ndim))
    compressed_mask = jnp.reshape(compressed_mask, (bsz, k, n_tokens))
    return (
        retr_keys,
        compressed_val,
        compressed_mask,
        retr_img_emb,
        disentangle_reg,
    )

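  # Flattens the ViT's 2D output grid into N image tokens (the fixed 4x
  # scaling of `logits_2d` is kept from the original code) and returns the
  # pooled pre-head embedding as the image-level representation.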
  def encode_image(self, image, train=False):
    _, out = self.img_encoder(image, train=train)
    img_query = jnp.asarray(out['logits_2d'] * 4, self.dtype)
    n_img_tokens = img_query.shape[1] * img_query.shape[2]
    img_query = jnp.reshape(img_query, [-1, n_img_tokens, self.ndim])
    img_emb = jnp.asarray(out['head_input'], self.dtype)
    return img_query, img_emb

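  # Scores each of the K candidates against the query,
  # a = softmax(affine(q . k_1..K)) * K, then repeats each score over that
  # document's n_tokens positions so the fusion encoder's cross-attention is
  # softly reweighted toward higher-scored documents. A minimal sketch of the
  # reweighting, with toy shapes (B=2, K=3, key_dim=8, n_tokens=4):
  #
  #   q = jnp.ones([2, 8]); keys = jnp.ones([2, 3, 8])
  #   scores = jax.nn.softmax(jnp.einsum('bd,bkd->bk', q, keys), axis=-1) * 3
  #   att = jnp.repeat(scores, repeats=4, axis=-1)   # [2, 12] per-token scores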
  def fuse_topk_knowledge(
      self,
      base_query,
      base_vals,
      base_masks,
      retr_keys,
      retr_vals,
      retr_masks,
      train=False,
  ):
    (bsz, k, n_tokens) = retr_vals.shape[:3]
    retr_vals = jnp.reshape(retr_vals, (bsz, k * n_tokens, self.ndim))
    retr_scores = jnp.einsum('bd,bkd->bk', base_query, retr_keys)
    retr_scores = jax.nn.softmax(self.att_transform(retr_scores), axis=-1) * k
    retr_masks = jnp.reshape(retr_masks, (bsz, k * n_tokens))
    att_mask = [
        jnp.ones([bsz, base_vals.shape[1]]),
        jnp.repeat(retr_scores, repeats=n_tokens, axis=-1),
    ]
    att_mask = jnp.expand_dims(jnp.concatenate(att_mask, axis=-1), axis=-1)
    fused_query, fused_mask, attn_weights_all_layers = self.fusion_encoder(
        encoder_input_embs=base_vals,
        fused_input_embs=retr_vals,
        encoder_mask=base_masks,
        fused_mask=retr_masks,
        att_mask=att_mask,
        use_dropout=train,
        output=True,
    )
    return fused_query, fused_mask, retr_scores, attn_weights_all_layers

  def __call__(
      self,
      decoder_input_tokens,
      decoder_target_tokens,
      encoder_input_image=None,
      encoder_input_tokens=None,
      retr_texts=None,
      retr_images=None,
      train=False,
      decode=False,
      fuse_retrieval=True,
      max_decode_length=None,
      debug: bool = False,
      in_batch_neg: bool = False,
      frozen_base=True,
      **args,
  ):
    """Conducts supervised retrieval-augmented training with retrieved documents.

    Args:
      decoder_input_tokens: B×O decoder input token ids.
      decoder_target_tokens: B×O decoder target token ids.
      encoder_input_image: B×W×H×3 input images.
      encoder_input_tokens: B×I input token ids.
      retr_texts: B×K×L retrieved document token ids.
      retr_images: B×K×W×H×3 retrieved document images.
      train: whether to run in train mode.
      decode: whether to run in decode mode.
      fuse_retrieval: whether to fuse the input retrieved documents.
      max_decode_length: maximum decode token length.
      debug: whether to run in debug mode.
      in_batch_neg: whether to use in-batch contrastive learning.
      frozen_base: whether to freeze the whole encoder.
      **args: other possible arguments.

    Returns:
      Output dictionary containing final and intermediate results.
    """
    bsz = decoder_input_tokens.shape[0]
    base_vals, base_masks, query_img_emb = self.get_base_encoded(
        bsz=bsz,
        image=encoder_input_image,
        text_tokens=encoder_input_tokens,
        train=train,
        frozen_base=frozen_base,
    )
    out_dict = {
        'query_img_emb': query_img_emb,
        'text_query': base_vals[0],
        'image_query': base_vals[1],
    }
    base_vals = jnp.concatenate(base_vals, axis=1)
    if retr_texts is not None:
      retr_keys, retr_vals, retr_masks, retr_img_emb, disentangle_reg = (
          self.encode_topk_knowledge(
              bsz=bsz,
              retr_images=retr_images,
              retr_texts=retr_texts,
              train=train,
              random_drop_image=True,
          )
      )
      base_query = self.query_head(
          encoded_emb=base_vals, encoder_mask=base_masks, use_dropout=train
      )
      out_dict['disentangle_reg'] = disentangle_reg
      out_dict['retr_img_emb'] = retr_img_emb
      out_dict['base_query'] = base_query
      out_dict['retr_keys'] = retr_keys
      out_dict['retr_vals'] = retr_vals

    if fuse_retrieval and retr_texts is not None:
      if in_batch_neg and retr_vals.shape[1] == 1:
        retr_vals = jnp.concatenate(
            (retr_vals, jnp.roll(retr_vals, shift=1, axis=0)), axis=1
        )
        retr_keys = jnp.concatenate(
            (retr_keys, jnp.roll(retr_keys, shift=1, axis=0)), axis=1
        )
        retr_masks = jnp.concatenate(
            (retr_masks, jnp.roll(retr_masks, shift=1, axis=0)), axis=1
        )

      fused_emb, fused_mask, retr_scores, attn_weights_all_layers = (
          self.fuse_topk_knowledge(
              base_query=base_query,
              base_vals=base_vals,
              base_masks=base_masks,
              retr_keys=retr_keys,
              retr_vals=retr_vals,
              retr_masks=retr_masks,
              train=train,
          )
      )
      out_dict['retr_scores'] = retr_scores
    else:
      fused_emb, fused_mask, attn_weights_all_layers = self.fusion_encoder(
          fused_input_embs=base_vals, fused_mask=base_masks, use_dropout=train
      )

    out_dict['attn_weights_all_layers'] = attn_weights_all_layers
    out_dict['predicted_logits'] = self.out_decoder(
        encoded=fused_emb,
        decoder_input_tokens=decoder_input_tokens,
        encoder_input_tokens=fused_mask,
        decoder_target_tokens=decoder_target_tokens,
        enable_dropout=train,
        decode=decode,
        max_decode_length=max_decode_length,
        encoder_segment_ids=None,
        decoder_segment_ids=None,
    )
    return out_dict


class FIDSoftModel(base_model.BaseModel):
  """FID model."""

  def build_flax_model(self) -> nn.Module:
    return FusionInDecoderSoftModule(self.config.model)

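  # A minimal construction sketch. The config values below are hypothetical
  # placeholders; `t5_name` must match a key in t5_pretrained.CONFIGS and
  # `vit_name` a scenic ViT variant:
  #
  #   model_cfg = ml_collections.ConfigDict({
  #       't5_name': 't5_1_1_base',         # hypothetical
  #       'vit_name': 'B/16',               # hypothetical
  #       'dropout_rate': 0.1,
  #       'key_dim': 64,
  #       'num_fusion_layers': 4,
  #       'n_compressed_tokens': 32,
  #   })
  #   flax_model = FusionInDecoderSoftModule(config=model_cfg)
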
  def loss_function_dict(
      self, output: constants.JTensorDict, batch: constants.JTensorDict
  ) -> Dict[str, Any]:
    """Returns the negative log-likelihood (NLL) of the target sentence.

    Args:
      output: Output of the model as an OrderedDict.
      batch: Batch of data with 'decoder_target_tokens' as ground-truth.

    Returns:
      Dictionary with the generation loss and, when retrieval supervision is
      available, the retrieval loss and accuracy.
    """
    gen_loss = losses.nll_loss(
        targets=batch['decoder_target_tokens'],
        pred=output['predicted_logits'],
        target_masks=batch['decoder_target_tokens'] > 0,
        label_smoothing=self.config.model.get('label_smoothing'),
    )
    loss_dict = {'gen_loss': gen_loss}
    if output['supervised_retrieval']:
      retr_loss, (retr_acc, s0, s1) = losses.contrastive_loss(
          query_emb=output['base_query'],
          key_emb=output['retr_keys'],
          temperature=self.config.model.get('temperature'),
      )
      loss_dict['retr_loss'] = retr_loss
      loss_dict['retr_acc'] = retr_acc
      loss_dict['s0'] = s0
      loss_dict['s1'] = s1
    else:
      loss_dict['retr_loss'] = -1
      loss_dict['retr_acc'] = -1
      loss_dict['s0'] = -1
      loss_dict['s1'] = -1
    return loss_dict

  def get_metrics_fn(self, split: Optional[str] = None) -> base_model.MetricFn:
    """Returns a callable metric function for the model.

    Args:
      split: The split for which we calculate the metrics. It should be one
        of ['train', 'validation', 'test'].

    Returns:
      A metric function with the API ```metrics_fn(outputs, batch)```.
    """
    return metrics.token_accuracy