| | """ |
| | Based on: https://github.com/lucidrains/flamingo-pytorch |
| | """ |
| |
|
| | import re |
| | import torch |
| | import torch.nn.functional as F |
| | from einops import rearrange, repeat |
| | from einops_exts import rearrange_many |
| | from torch import einsum, nn |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from typing import Optional |
| | from dataclasses import dataclass |
| |
|
| |
|
| | @dataclass |
| | class VLMOutputWithPast(CausalLMOutputWithPast): |
| | """ |
| | VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes: |
| | past_media_locations: Optional[torch.Tensor] = None, |
| | past_vision_tokens: Optional[torch.Tensor] = None, |
| | """ |
| |
|
| | past_media_locations: Optional[torch.Tensor] = None |
| | past_vision_tokens: Optional[torch.Tensor] = None |
| |
|
| |
|
| | def exists(val): |
| | return val is not None |
| |
|
| |
|
| | def FeedForward(dim, mult=4): |
| | inner_dim = int(dim * mult) |
| | return nn.Sequential( |
| | nn.LayerNorm(dim), |
| | nn.Linear(dim, inner_dim, bias=False), |
| | nn.GELU(), |
| | nn.Linear(inner_dim, dim, bias=False), |
| | ) |
| |
|
| |
|
| | class VisionTokenizer(nn.Module): |
| | def __init__(self, dim_media, num_tokens_per_media): |
| | super().__init__() |
| | self.dim_media = dim_media |
| | self.num_tokens_per_media = num_tokens_per_media |
| |
|
| |
|
| | |
| | class MLPVisionProjector(VisionTokenizer): |
| | def __init__(self, *, dim, dim_inner, num_latents): |
| | super().__init__(dim_media=dim, num_tokens_per_media=num_latents) |
| | self.projector = nn.Sequential( |
| | nn.Linear(dim, dim_inner), |
| | nn.GELU(), |
| | nn.Linear(dim_inner, dim_inner), |
| | ) |
| | |
| | def forward(self, x): |
| | return self.projector(x) |
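
# --- Hypothetical usage sketch (not part of the original file) ---------------
# Shows that MLPVisionProjector keeps the token count and only changes the
# feature dimension; all sizes below are assumptions for illustration.
def _example_mlp_vision_projector():
    projector = MLPVisionProjector(dim=1024, dim_inner=4096, num_latents=256)
    vit_features = torch.randn(2, 1, 256, 1024)  # (batch, num_media, patches, vit_dim)
    lm_tokens = projector(vit_features)
    assert lm_tokens.shape == (2, 1, 256, 4096)  # same tokens, LM-sized features
    return lm_tokens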


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        h = self.heads

        q = self.to_q(latents)

        # the latents attend to both the media features and themselves
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
        q = q * self.scale

        # attention
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
        return self.to_out(out)
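
# --- Hypothetical shape check (illustrative only; sizes are assumptions) -----
# A single PerceiverAttention step: n2 latents attend over n1 media features
# (plus themselves) and come back with the same latent shape.
def _example_perceiver_attention():
    attn = PerceiverAttention(dim=1024, dim_head=64, heads=8)
    media = torch.randn(2, 3, 256, 1024)   # (b, T, n1, D)
    latents = torch.randn(2, 3, 64, 1024)  # (b, T, n2, D)
    out = attn(media, latents)
    assert out.shape == latents.shape
    return out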


class PerceiverResampler(VisionTokenizer):
    def __init__(
        self,
        *,
        dim,
        dim_inner=None,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Perceiver module which takes in image features and outputs image tokens.
        Args:
            dim (int): dimension of the incoming image features
            dim_inner (int, optional): final dimension to project the incoming image features to;
                also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
            depth (int, optional): number of layers. Defaults to 6.
            dim_head (int, optional): dimension of each head. Defaults to 64.
            heads (int, optional): number of heads. Defaults to 8.
            num_latents (int, optional): number of latent tokens to use in the Perceiver;
                also corresponds to number of tokens per sequence to output. Defaults to 64.
            max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            max_num_frames (int, optional): maximum number of frames to input into the Perceiver
                and keep positional embeddings for. If None, no positional embeddings are used.
            ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
        """
        if dim_inner is not None:
            projection = nn.Linear(dim, dim_inner)
        else:
            projection = None
            dim_inner = dim
        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        self.projection = projection
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # positional embeddings for frames (within a media) and media (within a sequence)
        self.frame_embs = (
            nn.Parameter(torch.randn(max_num_frames, dim))
            if exists(max_num_frames)
            else None
        )
        self.media_time_embs = (
            nn.Parameter(torch.randn(max_num_media, 1, dim))
            if exists(max_num_media)
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
        Returns:
            shape (b, T, n, D) where n is self.num_latents
        """
        b, T, F, v = x.shape[:4]

        # add frame and media time embeddings, if present
        if exists(self.frame_embs):
            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
            x = x + frame_embs
        x = rearrange(
            x, "b T F v d -> b T (F v) d"
        )  # flatten the frame and spatial dimensions
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:T]

        # attention / feedforward blocks
        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        if exists(self.projection):
            return self.projection(self.norm(latents))
        else:
            return self.norm(latents)
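
# --- Hypothetical usage sketch (not part of the original file) ---------------
# Resampling a variable number of patch features down to a fixed number of
# vision tokens per image; all sizes below are assumptions for illustration.
def _example_perceiver_resampler():
    resampler = PerceiverResampler(dim=1024, dim_inner=4096, depth=2, num_latents=64)
    vit_features = torch.randn(2, 3, 1, 256, 1024)  # (b, T_img, frames, patches, dim)
    vision_tokens = resampler(vit_features)
    assert vision_tokens.shape == (2, 3, 64, 4096)  # 64 tokens per image, projected
    return vision_tokens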


class MaskedCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # whether text attends only to its immediately preceding media, or to all previous media
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self, x, media, media_locations=None, use_cached_media=False):
        """
        Args:
            x (torch.Tensor): text features
                shape (B, T_txt, D_txt)
            media (torch.Tensor): image features
                shape (B, T_img, n, D_img) where n is the number of latents per image
            media_locations: boolean mask identifying the media tokens in x
                shape (B, T_txt)
            use_cached_media: bool
                If True, treat all of x as if it occurs after the last media
                registered in media_locations. T_txt does not need to exactly
                equal media_locations.shape[1] in this case.
        """

        if not use_cached_media:
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_locations.shape is {media_locations.shape} but x.shape is {x.shape}"

        T_txt = x.shape[1]
        _, T_img, n = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = einsum("... i d, ... j d -> ... i j", q, k)

        if exists(media_locations):
            media_time = torch.arange(T_img, device=x.device) + 1

            if use_cached_media:
                # text time is set to the last cached media location
                text_time = repeat(
                    torch.count_nonzero(media_locations, dim=1),
                    "b -> b i",
                    i=T_txt,
                )
            else:
                # at each True location, increment the text time counter (relative to media time)
                text_time = media_locations.cumsum(dim=-1)

            # text time must equal media time if only attending to the most recent media;
            # otherwise text time must be greater than or equal to media time (attend to all previous media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            text_to_media_mask = mask_op(
                rearrange(text_time, "b i -> b 1 i 1"),
                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
            )
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # any text without a preceding media needs to have its attention zeroed out
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(
                text_without_media_mask, "b i -> b 1 i 1"
            )
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
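
# --- Hypothetical usage sketch (not part of the original file) ---------------
# Text tokens cross-attend to the vision tokens of their preceding image;
# media_locations marks where each image placeholder sits in the text sequence.
# Sizes and the placement of the two media tokens below are assumptions.
def _example_masked_cross_attention():
    xattn = MaskedCrossAttention(dim=512, dim_visual=4096, only_attend_immediate_media=True)
    text = torch.randn(1, 6, 512)        # (B, T_txt, D_txt)
    media = torch.randn(1, 2, 64, 4096)  # (B, T_img, n, D_img)
    media_locations = torch.tensor([[True, False, False, True, False, False]])
    out = xattn(text, media, media_locations=media_locations)
    assert out.shape == text.shape
    return out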


class GatedCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        ff_mult=4,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            dim_visual=dim_visual,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x,
        media,
        media_locations=None,
        use_cached_media=False,
    ):
        x = (
            self.attn(
                x,
                media,
                media_locations=media_locations,
                use_cached_media=use_cached_media,
            )
            * self.attn_gate.tanh()
            + x
        )
        x = self.ff(x) * self.ff_gate.tanh() + x

        return x
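
# --- Hypothetical usage sketch (not part of the original file) ---------------
# The gates are initialized to 0, so tanh(0) = 0 and a freshly constructed
# block is an identity mapping over the text stream; sizes are assumptions.
def _example_gated_cross_attention_block():
    block = GatedCrossAttentionBlock(dim=512, dim_visual=4096)
    text = torch.randn(1, 6, 512)
    media = torch.randn(1, 2, 64, 4096)
    media_locations = torch.tensor([[True, False, False, True, False, False]])
    out = block(text, media, media_locations=media_locations)
    assert torch.allclose(out, text)  # gated residual acts as identity at init
    return out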


class DecoupledEmbedding(nn.Embedding):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        max_original_id: int,
        num_additional_embeddings: int = 0,
        _weight: torch.Tensor = None,
        num_original_embeddings: int = None,
        embedding_dim: int = None,
        partially_freeze=True,
        device=None,
        dtype=None,
        pad_token_id=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`):
                The largest token id that should be embedded using the regular embedding (regular `weight`).
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal self.weight.shape[0]
            num_additional_embeddings (`int`):
                Number of additional tokens to initialize an Embedding matrix for (`additional_embedding`).
            _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
                If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
            num_original_embeddings (`int`):
                self.weight.shape[0]
            embedding_dim (`int`):
                The size of each embedding vector
            partially_freeze (`bool`, *optional*, defaults to `True`):
                If `True`, the regular `weight` will be frozen. `additional_embedding` is never frozen.
            pad_token_id (`int`, *optional*):
                The id of the padding token (must be <= max_original_id).

        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
        `max_norm` or `norm_type`. We are not supporting these.
        """
        if pad_token_id is not None and pad_token_id > max_original_id:
            raise ValueError(
                f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
                " If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
            )
        if _weight is not None:
            assert (num_original_embeddings is None) or (
                _weight.shape[0] == num_original_embeddings
            ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
            assert (embedding_dim is None) or (
                _weight.shape[1] == embedding_dim
            ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
            num_original_embeddings = _weight.shape[0]
            embedding_dim = _weight.shape[1]
        else:
            assert (
                num_original_embeddings is not None
            ), "num_original_embeddings must be provided if _weight is not provided"
            assert (
                embedding_dim is not None
            ), "embedding_dim must be provided if _weight is not provided"

        super().__init__(
            num_embeddings=num_original_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=pad_token_id,
            _weight=_weight,
        )
        self.max_original_id = max_original_id
        self.padding_idx = pad_token_id
        self.num_additional_embeddings = num_additional_embeddings
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
        """
        self.weight.requires_grad_(require_regular_grad)
        # the additional embedding only exists when num_additional_embeddings > 0
        if self.num_additional_embeddings > 0:
            self.additional_embedding.requires_grad_(require_additional_grad)

    def forward(self, input_ids):
        """
        We have 2 embeddings, with different indices: one pretrained (self.weight) and another
        (self.additional_embedding.weight) that is being trained.

        In order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
           embedding starts from 0 and not num_embeddings
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup

        Note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
        i.e. not a simple concat - I haven't benchmarked whether the complex case is any faster; given that seqlens are
        usually relatively short it's probably not faster, or if faster not by much - but it might be a good idea to
        measure.
        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # clone so that the caller's input_ids are not modified in-place
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids > self.max_original_id)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(
            input_ids_additional_vocab - self.max_original_id - 1
        )

        # for a successful lookup in the regular embedding, replace out-of-range ids with a valid id (0)
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the rows that belong to the additional vocabulary
        full_vector[additional_vocab_indices] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.max_original_id + 1,
            self.num_additional_embeddings,
            self.embedding_dim,
            (not self.weight.requires_grad),
        )
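
# --- Hypothetical usage sketch (not part of the original file) ---------------
# A tokenizer with ids 0..9 plus 2 newly added special tokens (ids 10 and 11):
# the original embedding rows stay frozen while the 2 extra rows remain
# trainable. All sizes below are assumptions for illustration.
def _example_decoupled_embedding():
    embed = DecoupledEmbedding(
        max_original_id=9,
        num_additional_embeddings=2,
        num_original_embeddings=10,
        embedding_dim=16,
        partially_freeze=True,
    )
    input_ids = torch.tensor([[1, 4, 10, 11]])  # mixes original and additional ids
    out = embed(input_ids)
    assert out.shape == (1, 4, 16)
    assert not embed.weight.requires_grad
    assert embed.additional_embedding.weight.requires_grad
    return out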


class DecoupledLinear(nn.Linear):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
    then it will create `additional_out_features * in_features` additional parameters that are always trained. If
    `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        max_original_id: int,
        additional_out_features: int = 0,
        _weight: torch.Tensor = None,
        _bias: torch.Tensor = None,
        in_features: int = None,
        original_out_features: int = None,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`): The largest token id that should be extracted from the regular weight.
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal original_out_features - 1
            additional_out_features (`int`): Number of additional trainable dimensions.
            _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
                If provided, this sets the `in_features` and `original_out_features` parameters.
            _bias (`torch.Tensor`, *optional*, defaults to `None`): The regular bias tensor.
            in_features (`int`): Input hidden size.
            original_out_features (`int`): Original out_features of the language model's get_output_embeddings() function.
            bias (`bool`): Whether to include a bias term.
            partially_freeze (`bool`, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
        """
        if _weight is not None:
            assert (_weight.shape[0] == original_out_features) or (
                original_out_features is None
            ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
            assert (_weight.shape[1] == in_features) or (
                in_features is None
            ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
            in_features = _weight.shape[1]
            original_out_features = _weight.shape[0]
        else:
            assert (
                in_features is not None
            ), "in_features must be provided if _weight is not provided"
            assert (
                original_out_features is not None
            ), "original_out_features must be provided if _weight is not provided"

        if _bias is not None:
            assert bias is True, "bias must be True if _bias is provided"

        super().__init__(
            in_features,
            original_out_features,
            bias,
            device,
            dtype,
        )

        # overwrite the freshly initialized parameters with the provided ones, if any
        if _weight is not None:
            self.weight = nn.Parameter(_weight)
        if _bias is not None:
            self.bias = nn.Parameter(_bias)

        self.in_features = in_features
        self.original_out_features = original_out_features
        self.max_original_id = max_original_id

        # additional, always-trainable output dimensions
        self.additional_out_features = additional_out_features
        self.has_bias = bias
        if additional_out_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=additional_out_features,
                bias=self.has_bias,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
        """
        self.weight.requires_grad_(require_regular_grad)
        if self.has_bias:
            self.bias.requires_grad_(require_regular_grad)
        # the additional projection only exists when additional_out_features > 0
        if self.additional_out_features > 0:
            self.additional_fc.requires_grad_(require_additional_grad)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = F.linear(input, self.weight, self.bias)
        output = output[..., : self.max_original_id + 1]

        if self.additional_out_features > 0:
            additional_features = F.linear(
                input, self.additional_fc.weight, self.additional_fc.bias
            )
            output = torch.cat((output, additional_features), -1)
        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.max_original_id + 1,
            self.additional_out_features,
            self.bias is not None,
            (
                not self.weight.requires_grad
                or (self.bias is not None and not self.bias.requires_grad)
            ),
        )
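
# --- Hypothetical usage sketch (not part of the original file) ---------------
# Extends a frozen LM head over a vocabulary of 10 with 2 trainable output
# dimensions for newly added tokens; all sizes below are assumptions.
def _example_decoupled_linear():
    lm_head = DecoupledLinear(
        max_original_id=9,
        additional_out_features=2,
        in_features=16,
        original_out_features=10,
        partially_freeze=True,
    )
    hidden = torch.randn(1, 4, 16)  # (batch, seq, hidden)
    logits = lm_head(hidden)
    assert logits.shape == (1, 4, 12)  # 10 original + 2 additional logits
    assert not lm_head.weight.requires_grad
    assert lm_head.additional_fc.weight.requires_grad
    return logits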