PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
6789f6f verified
raw
history blame
7.91 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import partial
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from timm.models.layers import to_2tuple
from fairseq.tasks import FairseqTask
from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
from .base import (
D2vModalityConfig,
ModalitySpecificEncoder,
get_alibi_bias,
MaskSeed,
)
from .modules import (
BlockEncoder,
Decoder2d,
FixedPositionalEncoder,
TransformerDecoder,
EncDecTransformerDecoder,
)
from examples.data2vec.data.modality import Modality
@dataclass
class D2vImageConfig(D2vModalityConfig):
    """data2vec modality configuration for images.

    Extends the shared ``D2vModalityConfig`` with ViT-style patching
    parameters, 2D ALiBi settings, and decoder-variant switches used by
    ``ImageEncoder``.
    """

    # marks this config as the image modality
    type: Modality = Modality.IMAGE

    # input image side length in pixels (expanded to (H, W) via to_2tuple)
    input_size: int = 224
    # number of input channels (3 = RGB)
    in_chans: int = 3
    # side length of each square patch fed to PatchEmbed
    patch_size: int = 16
    # embedding dim produced by the patch projection; if it differs from the
    # model embed_dim, a Linear adapter is appended in ImageEncoder.__init__
    embed_dim: int = 768

    # ALiBi bias is computed over a 2D patch grid with this distance metric
    alibi_dims: int = 2
    alibi_distance: str = "manhattan"

    # use fixed (non-learned) 2D sin-cos positional embeddings
    fixed_positions: bool = True

    # decoder selection: transformer decoder instead of the default Decoder2d;
    # enc_dec_transformer further selects the encoder-decoder variant
    transformer_decoder: bool = False
    enc_dec_transformer: bool = False
class ImageEncoder(ModalitySpecificEncoder):
    """data2vec modality-specific encoder for images.

    Patchifies images with a ViT-style ``PatchEmbed`` local encoder, adds
    fixed 2D sin-cos positional embeddings, runs ``prenet_depth`` transformer
    blocks as the context encoder, and (optionally) builds a 2D conv decoder
    or a transformer decoder depending on the config.
    """

    modality_cfg: D2vImageConfig

    def __init__(
        self,
        modality_cfg: D2vImageConfig,
        embed_dim: int,
        make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        img_size = to_2tuple(modality_cfg.input_size)
        patch_size = to_2tuple(modality_cfg.patch_size)
        # number of non-overlapping patches tiling the image
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])

        local_encoder = PatchEmbed(
            modality_cfg.input_size,
            modality_cfg.patch_size,
            modality_cfg.in_chans,
            modality_cfg.embed_dim,
        )

        # initialize the patch projection like an nn.Linear (flattened view)
        # rather than with the nn.Conv2d default, following the MAE recipe
        w = local_encoder.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # adapt patch embeddings to the model width when they differ
        if modality_cfg.embed_dim != embed_dim:
            local_encoder = nn.Sequential(
                local_encoder,
                nn.Linear(modality_cfg.embed_dim, embed_dim),
            )

        project_features = nn.Identity()

        # fixed (frozen) positional table, filled with 2D sin-cos values below
        pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )

        # assumes a square patch grid (side_n x side_n) — holds for the
        # default square input/patch sizes
        side_n = int(num_patches ** 0.5)

        emb = get_2d_sincos_pos_embed(
            pos_embed.shape[-1],
            side_n,
            cls_token=False,
        )
        pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
        fixed_positional_encoder = (
            FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
        )

        # per-block stochastic-depth rates, linearly spaced over the prenet
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )

        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            # final norm only when blocks are post-norm (not layer_norm_first)
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        # decoder variants: enc-dec transformer, transformer-over-blocks,
        # or the default 2D conv decoder (None when no decoder is configured)
        if modality_cfg.transformer_decoder:
            if modality_cfg.enc_dec_transformer:
                decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
            else:
                dec_enc = BlockEncoder(
                    nn.ModuleList(
                        # drop_path=0 in the decoder; 8 heads per block
                        make_block(0, modality_cfg.decoder.decoder_dim, 8)
                        for _ in range(modality_cfg.decoder.decoder_layers)
                    ),
                    None,
                    layer_norm_first,
                    0,
                    0,
                )
                decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
        else:
            decoder = (
                Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
                if modality_cfg.decoder is not None
                else None
            )

        # ALiBi bias factory specialized for a 2D patch grid
        alibi_bias_fn = partial(
            get_alibi_bias,
            alibi_biases=alibi_biases,
            heads=modality_cfg.num_alibi_heads,
            dims=modality_cfg.alibi_dims,
            distance=modality_cfg.alibi_distance,
        )

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=fixed_positional_encoder,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        """Re-initialize encoder parameters, including the decoder's if present."""
        super().reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()

    @torch.no_grad()
    def patchify(self, imgs):
        """Convert images into flattened patch vectors (inverse of unpatchify).

        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.modality_cfg.patch_size
        # assumes H == W and divisible by p — TODO confirm at call sites
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        # move per-patch pixels together, channel-last
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    @torch.no_grad()
    def unpatchify(self, x):
        """Reassemble flattened patch vectors into images (inverse of patchify).

        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.modality_cfg.patch_size
        # requires a square grid of patches
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        # h == w here, so h*p is both output height and width
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        shape=None,
        precomputed_mask=None,
    ):
        """Compute (and optionally apply) a 2D block mask over patches.

        Falls back to the base-class 1D masking when mask_length <= 1;
        otherwise samples a 2D block mask over the patch grid. Returns
        (x, mask_info), with x masked in place of the selected patches
        when ``apply`` is true.
        """
        mlen = self.modality_cfg.mask_length
        if mlen <= 1:
            # single-token masks carry no 2D structure; use generic masking
            return super().compute_mask(
                x, padding_mask, mask_seed, apply, precomputed_mask
            )

        if precomputed_mask is not None:
            mask = precomputed_mask
        else:
            # local import keeps the dependency off the module import path
            from fairseq.data.data_utils import compute_block_mask_2d

            # ``shape`` overrides x.shape, e.g. when x is already compacted
            if shape is not None:
                B, L, D = shape
            else:
                B, L, D = x.shape

            mask = compute_block_mask_2d(
                shape=(B, L),
                mask_prob=self.modality_cfg.mask_prob,
                mask_length=self.modality_cfg.mask_length,
                mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
                inverse_mask=self.modality_cfg.inverse_mask,
                require_same_masks=True,
                mask_dropout=self.modality_cfg.mask_dropout,
            )

        mask_info = self.make_maskinfo(x, mask, shape)
        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info

    def decoder_input(self, x, mask_info):
        """Build decoder inputs.

        For the encoder-decoder transformer variant, returns (q, kv):
        queries are the fixed positional embeddings at masked positions,
        keys/values are the visible (unmasked) encoder outputs. All other
        decoder variants delegate to the base class.
        """
        # special-cased only when BOTH flags select the enc-dec transformer
        if (
            not self.modality_cfg.transformer_decoder
            or not self.modality_cfg.enc_dec_transformer
        ):
            return super().decoder_input(x, mask_info)

        inp_drop = self.modality_cfg.decoder.input_dropout
        if inp_drop > 0:
            x = F.dropout(x, inp_drop, training=self.training, inplace=True)

        # strip extra (e.g. CLS-style) tokens; kv holds only patch tokens
        kv = x[:, self.modality_cfg.num_extra_tokens :]

        assert self.fixed_positional_encoder is not None
        pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)

        mask = mask_info.mask.bool()
        if self.modality_cfg.decoder.add_positions_all:
            # also add positions to the visible tokens, not just the queries
            kv = kv + pos[~mask].view(kv.shape)

        # one positional query per masked position
        q = pos[mask].view(x.size(0), -1, x.size(-1))

        return q, kv