|
|
import torch |
|
|
from einops import rearrange |
|
|
from torch import nn |
|
|
|
|
|
from .helpers import PerceiverResampler |
|
|
|
|
|
|
|
|
class Flamingo(nn.Module):
    """Flamingo-style vision-language model.

    Combines a vision encoder, a PerceiverResampler and a language model that
    has been augmented with cross-attention layers via ``init_flamingo``.
    Images are encoded once per call, and the resulting features are pushed
    into every decoder layer of the language model (``condition_vis_x``)
    before the text is processed.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        vis_dim: int,
        cross_attn_every_n_layers: int = 1,
        use_media_placement_augmentation: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): CLIP-style model. Only its ``.visual``
                tower is called, and element ``[1]`` of that call's output is
                used as the per-image token features. This presumably matches
                an open_clip model that returns ``(pooled, tokens)`` -- TODO
                confirm.
            lang_encoder (nn.Module): Causal language model. It must provide
                ``init_flamingo``, ``_get_decoder_layers``, ``is_conditioned``,
                ``clear_conditioned_layers`` and a Hugging Face style
                ``generate``.
            eoc_token_id (int): Token id for <|endofchunk|>. Used as the EOS
                token during generation.
            media_token_id (int): Token id for <image>.
            vis_dim (int): Size of the last dimension of the visual token
                features. Used as the PerceiverResampler width and as the vision
                hidden size of the language model's cross-attention layers.
            cross_attn_every_n_layers (int, optional): How often to apply cross
                attention after a transformer layer. Defaults to 1.
            use_media_placement_augmentation (bool, optional): Whether to
                randomly assign images to the preceding or following text in
                training. Defaults to False.
        """
        super().__init__()
        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.vis_dim = vis_dim
        self.vision_encoder = vision_encoder
        self.perceiver = PerceiverResampler(dim=self.vis_dim)
        self.lang_encoder = lang_encoder
        # Installs the cross-attention machinery inside the language model.
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            use_media_placement_augmentation=self.use_media_placement_augmentation,
        )

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input of shape
                (B, T_img, F, C, H, W) with F=1. Must be None when
                ``use_cached_vision_x`` is True.
            lang_x (torch.Tensor): Language input ids of shape (B, T_txt).
            attention_mask (torch.Tensor, optional): Attention mask.
                Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            use_cached_vision_x (bool, optional): If True, skip vision encoding
                and reuse the visual features already conditioned into the
                language model by an earlier call made with
                ``clear_conditioned_layers=False``. Defaults to False.
            clear_conditioned_layers (bool, optional): If True, clear the
                conditioned layers once the forward pass finishes, including
                when it raises. Set this to False if the same set of images will
                be reused in a subsequent forward pass. Defaults to True.
            past_key_values: Pre-computed values to pass to the language model.
                See ``past_key_values`` in the Hugging Face CausalLM
                documentation.
            use_cache (bool, optional): Whether to use cached key values. See
                ``use_cache`` in the Hugging Face CausalLM documentation.

        Returns:
            Whatever ``lang_encoder`` returns for these inputs.
        """
        assert (
            vision_x is not None
        ) or use_cached_vision_x, (
            "Must provide either vision_x or use_cached_vision_x to True."
        )

        if use_cached_vision_x:
            # Encoding is skipped, so an image passed here would be silently
            # ignored. The language model must already hold visual features.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()

        try:
            if not use_cached_vision_x:
                self._encode_vision_x(vision_x=vision_x)
            output = self.lang_encoder(
                input_ids=lang_x,
                attention_mask=attention_mask,
                labels=labels,
                past_key_values=past_key_values,
                use_cache=use_cache,
            )
        finally:
            # Clear even when encoding or the LM call fails, so stale visual
            # features cannot leak into the next call.
            if clear_conditioned_layers:
                self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        num_beams=1,
        max_new_tokens=None,
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        no_repeat_ngram_size=0,
        prefix_allowed_tokens_fn=None,
        length_penalty=1.0,
        num_return_sequences=1,
        do_sample=False,
        early_stopping=False,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Args:
            vision_x (torch.Tensor): Vision input of shape
                (B, T_img, F, C, H, W). Images in the same chunk are collated
                along T_img, and frames are collated along F. Currently only
                F=1 is supported (single-frame videos).
            lang_x (torch.Tensor): Language input of shape (B, T_txt).
            attention_mask (torch.Tensor, optional): Attention mask.
                Defaults to None.
            num_beams (int, optional): Number of beams. Defaults to 1.
            max_new_tokens (int, optional): Maximum new tokens.
                Defaults to None.
            temperature (float, optional): Temperature. Defaults to 1.0.
            top_k (int, optional): Top k. Defaults to 0.
            top_p (float, optional): Top p. Defaults to 1.0.
            no_repeat_ngram_size (int, optional): No repeat ngram size.
                Defaults to 0.
            prefix_allowed_tokens_fn (callable, optional): Forwarded unchanged
                to ``lang_encoder.generate``. Defaults to None.
            length_penalty (float, optional): Length penalty. Defaults to 1.0.
            num_return_sequences (int, optional): Number of return sequences.
                Defaults to 1.
            do_sample (bool, optional): Do sample. Defaults to False.
            early_stopping (bool, optional): Early stopping. Defaults to False.

        Returns:
            torch.Tensor: lang_x with generated tokens appended to it.
        """
        # Hugging Face `generate` repeat-interleaves the batch of `lang_x`
        # before decoding: by `num_beams` for beam search, and by
        # `num_return_sequences` when sampling with a single beam (greedy search
        # rejects num_return_sequences > 1). The conditioned visual features
        # must be expanded the same way so that the cross-attention batch sizes
        # match.
        # NOTE(review): some transformers versions expand beam *sampling* by
        # num_beams * num_return_sequences; that combination is not handled.
        expand_size = num_beams if num_beams > 1 else num_return_sequences
        if expand_size > 1:
            vision_x = vision_x.repeat_interleave(expand_size, dim=0)

        try:
            self._encode_vision_x(vision_x=vision_x)
            output = self.lang_encoder.generate(
                lang_x,
                attention_mask=attention_mask,
                # Generation stops at the end-of-chunk token.
                eos_token_id=self.eoc_token_id,
                num_beams=num_beams,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                no_repeat_ngram_size=no_repeat_ngram_size,
                length_penalty=length_penalty,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                early_stopping=early_stopping,
            )
        finally:
            # Always drop the per-call visual conditioning, even on failure.
            self.lang_encoder.clear_conditioned_layers()

        return output

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Encode the vision input and condition the language model on it.

        Pipeline:
        1. Flatten the batch, image and frame axes.
        2. Run the vision tower under ``torch.no_grad()``, so no gradients flow
           into the vision encoder from here.
        3. Restore those axes.
        4. Resample the features with the PerceiverResampler.
        5. Hand the result to every decoder layer via ``condition_vis_x``.

        Args:
            vision_x (torch.Tensor): Vision input of shape
                (B, T_img, F, C, H, W). Images in the same chunk are collated
                along T_img, and frames are collated along F. Currently only
                F=1 is supported (single-frame videos).

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            # Element [1] is assumed to hold the per-image token features of
            # shape (N, v, d), as required by the rearrange below -- confirm
            # against the vision encoder in use.
            vision_x = self.vision_encoder.visual(vision_x)[1]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)

        vision_x = self.perceiver(vision_x)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)
|
|
|