import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

from components import RMSNorm
from transformer import OptimizedTransformerBlock


class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch (CaiT-style)."""

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


class StochasticDepth(nn.Module):
    """Drop entire residual branches per sample (a.k.a. DropPath)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x

        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # 1 with prob keep_prob, else 0
        # Rescale survivors so the expected activation is unchanged.
        return x.div(keep_prob) * random_tensor
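
# A minimal usage sketch of the class above: in training mode each sample's
# branch is zeroed with probability drop_prob and survivors are rescaled by
# 1/keep_prob, so the expected output matches eval mode, where the module is
# an exact identity.
#
#   sd = StochasticDepth(drop_prob=0.25)
#   sd.train(); y = sd(torch.ones(8, 4))     # rows are 0.0 or 1/0.75
#   sd.eval();  assert torch.equal(sd(torch.ones(8, 4)), torch.ones(8, 4))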


class ImprovedPatchEmbedding(nn.Module):
    """Conv2d patchifier with optional overlap, followed by RMSNorm."""

    def __init__(
        self,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        overlap: int = 0
    ):
        super().__init__()
        self.patch_size = patch_size
        # overlap > 0 shrinks the stride so neighbouring patches share pixels.
        stride = patch_size - overlap
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=overlap // 2
        )
        self.norm = RMSNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        # x: (B, C, H, W)
        x = self.proj(x)                        # (B, embed_dim, H', W')
        grid_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)        # (B, H'*W', embed_dim)
        x = self.norm(x)
        return x, grid_size
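
# Worked shape example for the patch embedding (defaults above): a 224x224
# image with patch_size=14 and overlap=0 gives a 16x16 grid, i.e. 256 tokens.
#
#   pe = ImprovedPatchEmbedding(patch_size=14, embed_dim=2048)
#   tokens, grid = pe(torch.randn(2, 3, 224, 224))
#   # tokens: (2, 256, 2048), grid: (16, 16)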


class ImprovedVisionBlock(nn.Module):
    """Pre-norm ViT block with LayerScale, DropPath, and an optional adapter."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, batch_first=True
        )

        self.norm2 = RMSNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )

        self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()

        if use_layer_scale:
            self.ls1 = LayerScale(dim, layer_scale_init)
            self.ls2 = LayerScale(dim, layer_scale_init)
        else:
            self.ls1 = nn.Identity()
            self.ls2 = nn.Identity()

        # Bottleneck adapter for parameter-efficient fine-tuning.
        if use_adapter:
            self.adapter = nn.Sequential(
                nn.Linear(dim, adapter_dim),
                nn.GELU(),
                nn.Linear(adapter_dim, dim)
            )
        else:
            self.adapter = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention sub-block (pre-norm residual).
        normx = self.norm1(x)
        attn_out, _ = self.attn(normx, normx, normx, need_weights=False)
        x = x + self.drop_path(self.ls1(attn_out))

        # MLP sub-block.
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))

        if self.adapter is not None:
            x = x + self.adapter(x)

        return x
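
# The block applies, in order (LS = LayerScale, DP = stochastic depth):
#   x <- x + DP(LS1(Attn(RMSNorm(x))))
#   x <- x + DP(LS2(MLP(RMSNorm(x))))
#   x <- x + Adapter(x)    (only when use_adapter=True)
# Note the adapter branch is neither layer-scaled nor path-dropped.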


class ImprovedVisionTransformer(nn.Module):
    """ViT encoder with register tokens, LayerScale, stochastic depth, and
    positional-embedding interpolation for off-resolution inputs."""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        depth: int = 24,
        n_heads: int = 16,
        dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        use_register_tokens: bool = True,
        num_register_tokens: int = 4,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.use_register_tokens = use_register_tokens
        self.num_register_tokens = num_register_tokens if use_register_tokens else 0

        self.patch_embed = ImprovedPatchEmbedding(
            patch_size, in_channels, embed_dim, overlap=0
        )

        self.pretrain_img_size = img_size
        n_patches_pretrain = (img_size // patch_size) ** 2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Register tokens give attention a place to store global state instead
        # of repurposing patch tokens.
        if use_register_tokens:
            self.register_tokens = nn.Parameter(
                torch.zeros(1, num_register_tokens, embed_dim)
            )

        total_tokens = 1 + n_patches_pretrain + self.num_register_tokens
        self.pos_embed = nn.Parameter(torch.zeros(1, total_tokens, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Drop-path rate increases linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList([
            ImprovedVisionBlock(
                embed_dim,
                n_heads,
                dropout,
                drop_path=dpr[i],
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_layer_scale=use_layer_scale,
                layer_scale_init=layer_scale_init
            )
            for i in range(depth)
        ])

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.use_register_tokens:
            nn.init.trunc_normal_(self.register_tokens, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, RMSNorm):
            if hasattr(m, 'weight') and m.weight is not None:
                nn.init.ones_(m.weight)

    def _interpolate_pos_encoding(self, grid_size: Tuple[int, int]) -> torch.Tensor:
        """Resample patch position embeddings when the input grid differs from
        the pretraining grid; cls/register positions pass through unchanged."""
        pretrain_grid_h = self.pretrain_img_size // self.patch_size
        pretrain_grid_w = pretrain_grid_h

        if grid_size[0] == pretrain_grid_h and grid_size[1] == pretrain_grid_w:
            return self.pos_embed

        num_extra_tokens = 1 + self.num_register_tokens
        cls_register_pos = self.pos_embed[:, :num_extra_tokens, :]
        patch_pos_embed = self.pos_embed[:, num_extra_tokens:, :]

        # (1, N, D) -> (1, D, H, W) for spatial interpolation.
        patch_pos_embed = patch_pos_embed.reshape(
            1, pretrain_grid_h, pretrain_grid_w, -1
        ).permute(0, 3, 1, 2)

        patch_pos_embed = F.interpolate(
            patch_pos_embed,
            size=grid_size,
            mode='bicubic',
            align_corners=False
        )

        # (1, D, H', W') -> (1, H'*W', D)
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)

        return torch.cat([cls_register_pos, patch_pos_embed], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]

        x, grid_size = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.use_register_tokens:
            register_tokens = self.register_tokens.expand(B, -1, -1)
            x = torch.cat([cls_tokens, register_tokens, x], dim=1)
        else:
            x = torch.cat([cls_tokens, x], dim=1)

        pos_embed = self._interpolate_pos_encoding(grid_size)
        x = self.pos_drop(x + pos_embed)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return x
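
# Worked token count (defaults above): 224 / 14 = 16, so 16 * 16 = 256 patch
# tokens, plus 1 cls token and 4 register tokens -> forward() returns
# (B, 261, embed_dim). At other resolutions, e.g. 448x448 (a 32x32 grid),
# _interpolate_pos_encoding resamples the 16x16 position grid to match.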


class ImprovedAudioEncoder(nn.Module):
    """Patch-based transformer over log-mel spectrograms with an optional
    dual-stream (temporal + frequency) pooled summary."""

    def __init__(
        self,
        n_mels: int = 128,
        target_length: int = 1024,
        embed_dim: int = 2048,
        depth: int = 12,
        n_heads: int = 16,
        patch_size: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_dual_stream: bool = True
    ):
        super().__init__()
        self.use_dual_stream = use_dual_stream
        self.patch_size = patch_size

        # Treat the spectrogram as a one-channel image.
        self.patch_embed = nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size
        )

        self.n_patches_h = n_mels // patch_size
        self.n_patches_w = target_length // patch_size
        n_patches = self.n_patches_h * self.n_patches_w

        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(depth)
        ])

        # Separate pooled views along time and frequency, fused to one token.
        if use_dual_stream:
            self.temporal_pool = nn.AdaptiveAvgPool1d(1)
            self.frequency_pool = nn.AdaptiveAvgPool1d(1)
            self.temporal_proj = nn.Linear(embed_dim, embed_dim)
            self.frequency_proj = nn.Linear(embed_dim, embed_dim)
            self.fusion = nn.Linear(embed_dim * 2, embed_dim)

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        # Accept (B, n_mels, T) or (B, 1, n_mels, T).
        if mel_spec.ndim == 3:
            mel_spec = mel_spec.unsqueeze(1)

        x = self.patch_embed(mel_spec)          # (B, D, H', W')
        x = x.flatten(2).transpose(1, 2)        # (B, N, D)
        x = self.pos_drop(x + self.pos_embed)

        for block in self.blocks:
            x, _, _ = block(x)

        x = self.norm(x)

        if self.use_dual_stream:
            B, N, C = x.shape
            x_2d = x.transpose(1, 2).reshape(B, C, self.n_patches_h, self.n_patches_w)

            # Temporal stream: average over frequency, then pool over time.
            temporal = x_2d.mean(dim=2)                          # (B, C, W')
            temporal = self.temporal_pool(temporal).squeeze(-1)  # (B, C)
            temporal = self.temporal_proj(temporal).unsqueeze(1)

            # Frequency stream: average over time, then pool over frequency.
            frequency = x_2d.mean(dim=3)                            # (B, C, H')
            frequency = self.frequency_pool(frequency).squeeze(-1)  # (B, C)
            frequency = self.frequency_proj(frequency).unsqueeze(1)

            x = self.fusion(torch.cat([temporal, frequency], dim=-1))
        else:
            # Global average pooling to a single token.
            x = x.mean(dim=1, keepdim=True)

        return x
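
# Worked shape example (defaults above): a (B, 128, 1024) log-mel input
# patchifies to an 8x64 grid of 512 tokens; the dual-stream head averages
# the grid along each axis and fuses the two pooled views into a single
# (B, 1, embed_dim) summary token.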


class ImprovedVideoEncoder(nn.Module):
    """Factorised space-time video encoder: spatial blocks per frame (or per
    tubelet slice with a 3D conv stem), mean-pooled, then temporal blocks."""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        spatial_depth: int = 12,
        temporal_depth: int = 4,
        n_heads: int = 16,
        num_frames: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_3d_conv: bool = False
    ):
        super().__init__()
        self.num_frames = num_frames
        self.use_3d_conv = use_3d_conv
        self.patch_size = patch_size
        self.img_size = img_size

        if use_3d_conv:
            # Tubelet embedding: pairs of frames are merged by the stride-2
            # temporal kernel.
            self.patch_embed = nn.Conv3d(
                in_channels,
                embed_dim,
                kernel_size=(2, patch_size, patch_size),
                stride=(2, patch_size, patch_size)
            )
            self.n_temporal_patches = num_frames // 2
            self.n_spatial_patches = (img_size // patch_size) ** 2
        else:
            # Frame-wise 2D patch embedding.
            self.patch_embed = ImprovedPatchEmbedding(
                patch_size, in_channels, embed_dim
            )
            self.n_spatial_patches = (img_size // patch_size) ** 2

        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.n_spatial_patches, embed_dim)
        )
        self.spatial_pos_drop = nn.Dropout(dropout)

        self.spatial_blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(spatial_depth)
        ])

        if use_3d_conv:
            self.temporal_pos_embed = nn.Parameter(
                torch.zeros(1, self.n_temporal_patches, embed_dim)
            )
        else:
            self.temporal_pos_embed = nn.Parameter(
                torch.zeros(1, num_frames, embed_dim)
            )
        self.temporal_pos_drop = nn.Dropout(dropout)

        self.temporal_blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(temporal_depth)
        ])

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.temporal_pos_embed, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C, H, W = x.shape

        if self.use_3d_conv:
            x = x.transpose(1, 2)        # (B, C, T, H, W)
            x = self.patch_embed(x)      # (B, D, T', H', W')

            B, D, T_new, H_new, W_new = x.shape
            x = x.view(B, D, T_new, -1).permute(0, 2, 3, 1)  # (B, T', N, D)

            x = self.spatial_pos_drop(x + self.spatial_pos_embed.unsqueeze(1))

            # Spatial attention within each tubelet slice.
            x_flat = x.reshape(B * T_new, -1, D)
            for block in self.spatial_blocks:
                x_flat, _, _ = block(x_flat)

            # Mean-pool the spatial tokens to one token per time step.
            x = x_flat.view(B, T_new, -1, D)
            x = x.mean(dim=2)
        else:
            # Fold time into the batch and patchify each frame independently.
            x_flat = x.view(B * T, C, H, W)
            x_patched, grid_size = self.patch_embed(x_flat)

            x_patched = self.spatial_pos_drop(x_patched + self.spatial_pos_embed)

            for block in self.spatial_blocks:
                x_patched, _, _ = block(x_patched)

            # Mean-pool the spatial tokens to one token per frame.
            _, N, D = x_patched.shape
            x_spatial = x_patched.view(B, T, N, D)
            x = x_spatial.mean(dim=2)

        # Temporal attention over the per-frame tokens.
        x = self.temporal_pos_drop(x + self.temporal_pos_embed)

        for block in self.temporal_blocks:
            x, _, _ = block(x)

        return self.norm(x)
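
if __name__ == "__main__":
    # Minimal smoke test; a sketch that assumes the `components` and
    # `transformer` modules are importable. Tiny dims keep it cheap; the
    # defaults are far larger.
    vit = ImprovedVisionTransformer(
        img_size=224, patch_size=14, embed_dim=64, depth=2, n_heads=4
    )
    print(vit(torch.randn(2, 3, 224, 224)).shape)        # (2, 261, 64)

    audio = ImprovedAudioEncoder(embed_dim=64, depth=2, n_heads=4)
    print(audio(torch.randn(2, 128, 1024)).shape)        # (2, 1, 64)

    video = ImprovedVideoEncoder(
        embed_dim=64, spatial_depth=2, temporal_depth=1,
        n_heads=4, num_frames=4
    )
    print(video(torch.randn(2, 4, 3, 224, 224)).shape)   # (2, 4, 64)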