import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from functools import partial
from dataclasses import dataclass
from typing import Callable, Dict, Optional

from timm.models.layers import to_2tuple

from fairseq.tasks import FairseqTask
from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
from examples.data2vec.data.modality import Modality

from .base import (
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_alibi_bias,
    MaskSeed,
)
from .modules import (
    BlockEncoder,
    Decoder2d,
    FixedPositionalEncoder,
    TransformerDecoder,
    EncDecTransformerDecoder,
)

@dataclass
class D2vImageConfig(D2vModalityConfig):
    type: Modality = Modality.IMAGE

    input_size: int = 224  # side length of the (square) input image, in pixels
    in_chans: int = 3  # number of input channels
    patch_size: int = 16  # side length of each square patch
    embed_dim: int = 768  # patch embedding dimension

    alibi_dims: int = 2  # compute ALiBi biases over the 2D patch grid
    alibi_distance: str = "manhattan"  # distance metric for the 2D biases

    fixed_positions: bool = True  # add fixed 2D sin-cos positional embeddings

    transformer_decoder: bool = False  # use a transformer decoder instead of Decoder2d
    enc_dec_transformer: bool = False  # use a cross-attention encoder-decoder decoder
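
# Image modality encoder for data2vec 2.0: embeds images as patch tokens
# (MAE-style PatchEmbed), adds fixed 2D sin-cos positions, runs a small
# "prenet" of transformer blocks, and optionally builds a decoder for
# reconstructing masked patches.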
class ImageEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vImageConfig

    def __init__(
        self,
        modality_cfg: D2vImageConfig,
        embed_dim: int,
        make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        img_size = to_2tuple(modality_cfg.input_size)
        patch_size = to_2tuple(modality_cfg.patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])

        local_encoder = PatchEmbed(
            modality_cfg.input_size,
            modality_cfg.patch_size,
            modality_cfg.in_chans,
            modality_cfg.embed_dim,
        )
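
        # Initialize the patch projection like nn.Linear rather than nn.Conv2d
        # (xavier_uniform_ over the flattened kernel), as in the MAE reference
        # implementation.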
        w = local_encoder.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Project patch embeddings to the shared model dimension if the
        # modality embedding dimension differs.
        if modality_cfg.embed_dim != embed_dim:
            local_encoder = nn.Sequential(
                local_encoder,
                nn.Linear(modality_cfg.embed_dim, embed_dim),
            )

        project_features = nn.Identity()
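
        # Fixed (non-learned) 2D sin-cos positional embeddings over the patch
        # grid; requires_grad=False keeps them out of the optimizer.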
        pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )

        side_n = int(num_patches ** 0.5)

        emb = get_2d_sincos_pos_embed(
            pos_embed.shape[-1],
            side_n,
            cls_token=False,
        )
        pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
        fixed_positional_encoder = (
            FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
        )
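
        # Drop-path (stochastic depth) rates increase linearly from the first
        # to the last prenet block.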
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )

        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )
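
        # Decoder variants: a cross-attention encoder-decoder transformer, a
        # transformer decoder built from regular blocks, or (when a decoder
        # config is present and transformer_decoder is off) a 2D decoder.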
        if modality_cfg.transformer_decoder:
            if modality_cfg.enc_dec_transformer:
                decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
            else:
                dec_enc = BlockEncoder(
                    nn.ModuleList(
                        make_block(0, modality_cfg.decoder.decoder_dim, 8)
                        for _ in range(modality_cfg.decoder.decoder_layers)
                    ),
                    None,
                    layer_norm_first,
                    0,
                    0,
                )
                decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
        else:
            decoder = (
                Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
                if modality_cfg.decoder is not None
                else None
            )
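
        # ALiBi-style relative position biases over the patch grid; the shared
        # alibi_biases dict memoizes computed bias tensors.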
        alibi_bias_fn = partial(
            get_alibi_bias,
            alibi_biases=alibi_biases,
            heads=modality_cfg.num_alibi_heads,
            dims=modality_cfg.alibi_dims,
            distance=modality_cfg.alibi_distance,
        )

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=fixed_positional_encoder,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        super().reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()
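
    # MAE-style patchification: with the defaults (224x224 images, 16x16
    # patches), imgs of shape (N, 3, 224, 224) become (N, 196, 768), since
    # 14 * 14 = 196 patches and 16 * 16 * 3 = 768 values per patch.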
    @torch.no_grad()
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 * 3)
        """
        p = self.modality_cfg.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))

        return x
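
    # Inverse of patchify: folds (N, L, patch_size**2 * 3) token sequences
    # back into (N, 3, H, W) images; assumes a square patch grid.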
    @torch.no_grad()
    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * 3)
        imgs: (N, 3, H, W)
        """
        p = self.modality_cfg.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, w * p))
        return imgs
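
    # For mask_length > 1, sample contiguous 2D blocks of masked patches over
    # the patch grid; otherwise fall back to the base class's generic masking.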
    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        shape=None,
        precomputed_mask=None,
    ):
        mlen = self.modality_cfg.mask_length
        if mlen <= 1:
            return super().compute_mask(
                x, padding_mask, mask_seed, apply, precomputed_mask
            )

        if precomputed_mask is not None:
            mask = precomputed_mask
        else:
            from fairseq.data.data_utils import compute_block_mask_2d

            if shape is not None:
                B, L, D = shape
            else:
                B, L, D = x.shape

            mask = compute_block_mask_2d(
                shape=(B, L),
                mask_prob=self.modality_cfg.mask_prob,
                mask_length=self.modality_cfg.mask_length,
                mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
                inverse_mask=self.modality_cfg.inverse_mask,
                require_same_masks=True,
                mask_dropout=self.modality_cfg.mask_dropout,
            )

        mask_info = self.make_maskinfo(x, mask, shape)
        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info
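
    # For the encoder-decoder transformer decoder, queries are the positional
    # embeddings of the masked patches and keys/values are the visible tokens;
    # all other decoder types use the base class's decoder_input.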
    def decoder_input(self, x, mask_info):
        if (
            not self.modality_cfg.transformer_decoder
            or not self.modality_cfg.enc_dec_transformer
        ):
            return super().decoder_input(x, mask_info)

        inp_drop = self.modality_cfg.decoder.input_dropout
        if inp_drop > 0:
            x = F.dropout(x, inp_drop, training=self.training, inplace=True)

        # Visible tokens (with any extra tokens such as CLS stripped) serve as
        # keys and values for cross-attention.
        kv = x[:, self.modality_cfg.num_extra_tokens :]

        assert self.fixed_positional_encoder is not None
        pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)

        mask = mask_info.mask.bool()
        if self.modality_cfg.decoder.add_positions_all:
            kv = kv + pos[~mask].view(kv.shape)

        # Queries: positional embeddings at the masked positions.
        q = pos[mask].view(x.size(0), -1, x.size(-1))

        return q, kv