# MultiModal/encoders.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

from components import RMSNorm
from transformer import OptimizedTransformerBlock
class LayerScale(nn.Module):
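    """Learnable per-channel scaling of a residual branch (LayerScale, as in
    CaiT); initialized near zero so each block starts close to identity."""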
def __init__(self, dim: int, init_values: float = 1e-5):
super().__init__()
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * self.gamma
class StochasticDepth(nn.Module):
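    """Per-sample stochastic depth (drop path): during training, zeroes the
    residual branch for a random subset of samples and rescales the rest."""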
def __init__(self, drop_prob: float = 0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.training or self.drop_prob == 0.0:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
return x.div(keep_prob) * random_tensor
class ImprovedPatchEmbedding(nn.Module):
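    """Conv2d patchifier with optional overlapping patches (stride < kernel)
    and RMSNorm over the flattened tokens; also returns the (H', W') grid."""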
def __init__(
self,
patch_size: int = 14,
in_channels: int = 3,
embed_dim: int = 2048,
overlap: int = 0
):
super().__init__()
self.patch_size = patch_size
stride = patch_size - overlap
self.proj = nn.Conv2d(
in_channels,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=overlap // 2
)
self.norm = RMSNorm(embed_dim)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
B, C, H, W = x.shape
x = self.proj(x)
grid_size = (x.shape[2], x.shape[3])
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, grid_size
class ImprovedVisionBlock(nn.Module):
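    """Pre-norm ViT block: multi-head self-attention and a 4x GELU MLP, each
    wrapped in optional LayerScale and stochastic depth, plus an optional
    bottleneck adapter on the output."""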
def __init__(
self,
dim: int,
n_heads: int,
dropout: float = 0.0,
drop_path: float = 0.0,
use_adapter: bool = False,
adapter_dim: int = 64,
use_layer_scale: bool = True,
layer_scale_init: float = 1e-5
):
super().__init__()
self.norm1 = RMSNorm(dim)
self.attn = nn.MultiheadAttention(
dim, n_heads, dropout=dropout, batch_first=True
)
self.norm2 = RMSNorm(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(dim * 4, dim),
nn.Dropout(dropout)
)
self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()
if use_layer_scale:
self.ls1 = LayerScale(dim, layer_scale_init)
self.ls2 = LayerScale(dim, layer_scale_init)
else:
self.ls1 = nn.Identity()
self.ls2 = nn.Identity()
if use_adapter:
self.adapter = nn.Sequential(
nn.Linear(dim, adapter_dim),
nn.GELU(),
nn.Linear(adapter_dim, dim)
)
else:
self.adapter = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention (pre-norm)
        normx = self.norm1(x)
        attn_out, _ = self.attn(normx, normx, normx, need_weights=False)
x = x + self.drop_path(self.ls1(attn_out))
# MLP
x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
# Adapter
if self.adapter is not None:
x = x + self.adapter(x)
return x
class ImprovedVisionTransformer(nn.Module):
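    """ViT-style image encoder with a CLS token, optional register tokens,
    learned positional embeddings that are bicubically interpolated for new
    input resolutions, and a linear stochastic-depth schedule over blocks."""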
def __init__(
self,
img_size: int = 224,
patch_size: int = 14,
in_channels: int = 3,
embed_dim: int = 2048,
depth: int = 24,
n_heads: int = 16,
dropout: float = 0.0,
drop_path_rate: float = 0.1,
use_register_tokens: bool = True,
num_register_tokens: int = 4,
use_adapter: bool = False,
adapter_dim: int = 64,
use_layer_scale: bool = True,
layer_scale_init: float = 1e-5
):
super().__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
self.use_register_tokens = use_register_tokens
self.num_register_tokens = num_register_tokens if use_register_tokens else 0
# Patch embedding
self.patch_embed = ImprovedPatchEmbedding(
patch_size, in_channels, embed_dim, overlap=0
)
self.pretrain_img_size = img_size
n_patches_pretrain = (img_size // patch_size) ** 2
# CLS token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Register tokens
if use_register_tokens:
self.register_tokens = nn.Parameter(
torch.zeros(1, num_register_tokens, embed_dim)
)
total_tokens = 1 + n_patches_pretrain + self.num_register_tokens
self.pos_embed = nn.Parameter(
torch.zeros(1, total_tokens, embed_dim)
)
self.pos_drop = nn.Dropout(dropout)
# Stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# Transformer blocks
self.blocks = nn.ModuleList([
ImprovedVisionBlock(
embed_dim,
n_heads,
dropout,
drop_path=dpr[i],
use_adapter=use_adapter,
adapter_dim=adapter_dim,
use_layer_scale=use_layer_scale,
layer_scale_init=layer_scale_init
)
for i in range(depth)
])
self.norm = RMSNorm(embed_dim)
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.cls_token, std=0.02)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
if self.use_register_tokens:
nn.init.trunc_normal_(self.register_tokens, std=0.02)
self.apply(self._init_module_weights)
def _init_module_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, RMSNorm):
if hasattr(m, 'weight') and m.weight is not None:
nn.init.ones_(m.weight)
    def _interpolate_pos_encoding(
        self,
        grid_size: Tuple[int, int]
    ) -> torch.Tensor:
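        """Resize the patch positional embedding to ``grid_size`` with bicubic
        interpolation; CLS and register positions are passed through unchanged."""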
pretrain_grid_h = self.pretrain_img_size // self.patch_size
pretrain_grid_w = pretrain_grid_h
        # If the grid matches pretraining, return the positional embedding unchanged
if grid_size[0] == pretrain_grid_h and grid_size[1] == pretrain_grid_w:
return self.pos_embed
        # Split the positional embedding by token group
        # pos_embed layout: [CLS(1), register_tokens(n), patches(H*W)]
num_extra_tokens = 1 + self.num_register_tokens
cls_register_pos = self.pos_embed[:, :num_extra_tokens, :] # [1, 1+n, dim]
patch_pos_embed = self.pos_embed[:, num_extra_tokens:, :] # [1, H*W, dim]
        # 2D interpolation of the patch positional embedding
patch_pos_embed = patch_pos_embed.reshape(
1, pretrain_grid_h, pretrain_grid_w, -1
).permute(0, 3, 1, 2)
patch_pos_embed = F.interpolate(
patch_pos_embed,
size=grid_size,
mode='bicubic',
align_corners=False
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        # Concatenate the groups back together
return torch.cat([cls_register_pos, patch_pos_embed], dim=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
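        """Encode images [B, C, H, W] into tokens
        [B, 1 + num_register_tokens + N_patches, embed_dim]."""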
B = x.shape[0]
# Patch embedding
x, grid_size = self.patch_embed(x)
        # Prepend the CLS token
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.use_register_tokens:
register_tokens = self.register_tokens.expand(B, -1, -1)
            # Order: [CLS, register_tokens, patches]
x = torch.cat([cls_tokens, register_tokens, x], dim=1)
else:
x = torch.cat([cls_tokens, x], dim=1)
        # Positional encoding, interpolated if the grid differs from pretraining
        pos_embed = self._interpolate_pos_encoding(grid_size)
x = self.pos_drop(x + pos_embed)
# Transformer blocks
for block in self.blocks:
x = block(x)
x = self.norm(x)
return x
class ImprovedAudioEncoder(nn.Module):
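    """AST-style mel-spectrogram encoder: 2D patchification of the
    [n_mels, time] plane, a transformer stack, and an optional dual-stream
    head that fuses time-pooled and frequency-pooled summaries."""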
def __init__(
self,
n_mels: int = 128,
target_length: int = 1024,
embed_dim: int = 2048,
depth: int = 12,
n_heads: int = 16,
patch_size: int = 16,
dropout: float = 0.1,
use_adapter: bool = False,
adapter_dim: int = 64,
use_dual_stream: bool = True
):
super().__init__()
self.use_dual_stream = use_dual_stream
self.patch_size = patch_size
        # Main encoder: 2D patchification of the mel spectrogram
self.patch_embed = nn.Conv2d(
1, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.n_patches_h = n_mels // patch_size
self.n_patches_w = target_length // patch_size
n_patches = self.n_patches_h * self.n_patches_w
self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
self.pos_drop = nn.Dropout(dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
OptimizedTransformerBlock(
embed_dim, n_heads, None, None, dropout,
use_adapter=use_adapter, adapter_dim=adapter_dim
)
for _ in range(depth)
])
        # Dual stream: temporal and frequency summaries
if use_dual_stream:
self.temporal_pool = nn.AdaptiveAvgPool1d(1)
self.frequency_pool = nn.AdaptiveAvgPool1d(1)
self.temporal_proj = nn.Linear(embed_dim, embed_dim)
self.frequency_proj = nn.Linear(embed_dim, embed_dim)
self.fusion = nn.Linear(embed_dim * 2, embed_dim)
self.norm = RMSNorm(embed_dim)
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_module_weights)
def _init_module_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
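        """Encode a mel spectrogram [B, n_mels, T] (or [B, 1, n_mels, T])
        into a single summary token of shape [B, 1, embed_dim]."""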
if mel_spec.ndim == 3:
mel_spec = mel_spec.unsqueeze(1)
# Patch embedding
x = self.patch_embed(mel_spec) # [B, C, H, W]
x = x.flatten(2).transpose(1, 2) # [B, H*W, C]
x = self.pos_drop(x + self.pos_embed)
# Transformer encoding
for block in self.blocks:
x, _, _ = block(x)
x = self.norm(x)
if self.use_dual_stream:
B, N, C = x.shape
            # Reshape tokens back to the 2D (frequency, time) grid
x_2d = x.transpose(1, 2).reshape(B, C, self.n_patches_h, self.n_patches_w)
            # Temporal stream: average over frequency, then pool over time
temporal = x_2d.mean(dim=2) # [B, C, W]
temporal = self.temporal_pool(temporal).squeeze(-1) # [B, C]
temporal = self.temporal_proj(temporal).unsqueeze(1) # [B, 1, C]
            # Frequency stream: average over time, then pool over frequency
frequency = x_2d.mean(dim=3) # [B, C, H]
frequency = self.frequency_pool(frequency).squeeze(-1) # [B, C]
frequency = self.frequency_proj(frequency).unsqueeze(1) # [B, 1, C]
            # Fuse the two streams into a single token
x = self.fusion(torch.cat([temporal, frequency], dim=-1))
else:
            # Simple global average pooling
x = x.mean(dim=1, keepdim=True)
return x
class ImprovedVideoEncoder(nn.Module):
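    """Factorized space-time video encoder (ViViT-style): either a 3D-conv
    tubelet stem or a per-frame 2D patch stem, followed by frame-wise spatial
    blocks and temporal blocks over the frame axis."""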
def __init__(
self,
img_size: int = 224,
patch_size: int = 14,
in_channels: int = 3,
embed_dim: int = 2048,
spatial_depth: int = 12,
temporal_depth: int = 4,
n_heads: int = 16,
num_frames: int = 16,
dropout: float = 0.1,
use_adapter: bool = False,
adapter_dim: int = 64,
use_3d_conv: bool = False
):
super().__init__()
self.num_frames = num_frames
self.use_3d_conv = use_3d_conv
self.patch_size = patch_size
self.img_size = img_size
if use_3d_conv:
            # 3D convolution embeds spatio-temporal tubelets (2 frames each)
self.patch_embed = nn.Conv3d(
in_channels,
embed_dim,
kernel_size=(2, patch_size, patch_size),
stride=(2, patch_size, patch_size)
)
self.n_temporal_patches = num_frames // 2
self.n_spatial_patches = (img_size // patch_size) ** 2
else:
            # Per-frame 2D convolution + separate temporal modeling
self.patch_embed = ImprovedPatchEmbedding(
patch_size, in_channels, embed_dim
)
self.n_spatial_patches = (img_size // patch_size) ** 2
        # Spatial positional embedding
self.spatial_pos_embed = nn.Parameter(
torch.zeros(1, self.n_spatial_patches, embed_dim)
)
self.spatial_pos_drop = nn.Dropout(dropout)
        # Spatial encoder
self.spatial_blocks = nn.ModuleList([
OptimizedTransformerBlock(
embed_dim, n_heads, None, None, dropout,
use_adapter=use_adapter, adapter_dim=adapter_dim
)
for _ in range(spatial_depth)
])
        # Temporal positional embedding
if use_3d_conv:
self.temporal_pos_embed = nn.Parameter(
torch.zeros(1, self.n_temporal_patches, embed_dim)
)
else:
self.temporal_pos_embed = nn.Parameter(
torch.zeros(1, num_frames, embed_dim)
)
self.temporal_pos_drop = nn.Dropout(dropout)
        # Temporal encoder
self.temporal_blocks = nn.ModuleList([
OptimizedTransformerBlock(
embed_dim, n_heads, None, None, dropout,
use_adapter=use_adapter, adapter_dim=adapter_dim
)
for _ in range(temporal_depth)
])
self.norm = RMSNorm(embed_dim)
self._init_weights()
def _init_weights(self):
nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)
nn.init.trunc_normal_(self.temporal_pos_embed, std=0.02)
self.apply(self._init_module_weights)
def _init_module_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
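        """Encode a clip [B, T, C, H, W] into per-timestep tokens
        [B, T', embed_dim], where T' = T // 2 with the 3D stem, else T."""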
B, T, C, H, W = x.shape
if self.use_3d_conv:
x = x.transpose(1, 2) # [B, C, T, H, W]
x = self.patch_embed(x) # [B, embed_dim, T', H', W']
            # Reshape: [B, D, T', H', W'] -> [B, T', H'*W', D]
            B, D, T_new, H_new, W_new = x.shape
            x = x.view(B, D, T_new, -1).permute(0, 2, 3, 1)  # [B, T', H'*W', D]
            # Spatial positional embedding, shared across frames
            x = self.spatial_pos_drop(x + self.spatial_pos_embed.unsqueeze(1))
            # Frame-wise spatial encoding
x_flat = x.reshape(B * T_new, -1, D)
for block in self.spatial_blocks:
x_flat, _, _ = block(x_flat)
            # Restore the temporal dimension and pool over space
x = x_flat.view(B, T_new, -1, D)
x = x.mean(dim=2) # [B, T', D]
else:
            # 2D convolution + factorized space-time modeling
x_flat = x.view(B * T, C, H, W)
x_patched, grid_size = self.patch_embed(x_flat)
            # Spatial positional embedding
x_patched = self.spatial_pos_drop(x_patched + self.spatial_pos_embed)
            # Spatial encoding
for block in self.spatial_blocks:
x_patched, _, _ = block(x_patched)
_, N, D = x_patched.shape
x_spatial = x_patched.view(B, T, N, D)
x = x_spatial.mean(dim=2)
        # Temporal positional embedding
x = self.temporal_pos_drop(x + self.temporal_pos_embed)
        # Temporal encoding
for block in self.temporal_blocks:
x, _, _ = block(x)
return self.norm(x)
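if __name__ == "__main__":
    # Minimal smoke test: a sketch for local sanity-checking, not part of the
    # model code. It exercises only the vision encoder (the audio and video
    # encoders additionally depend on OptimizedTransformerBlock from
    # transformer.py). The tiny config below is an arbitrary assumption chosen
    # for speed, not a recommended setting.
    model = ImprovedVisionTransformer(
        img_size=56, patch_size=14, embed_dim=64, depth=2, n_heads=4
    )
    model.eval()
    with torch.no_grad():
        tokens = model(torch.randn(2, 3, 56, 56))
    print(tokens.shape)  # [2, 21, 64]: CLS(1) + registers(4) + patches(16)
    # A different input resolution exercises the bicubic interpolation of the
    # patch positional embedding.
    with torch.no_grad():
        tokens_hr = model(torch.randn(2, 3, 112, 112))
    print(tokens_hr.shape)  # [2, 69, 64]: CLS(1) + registers(4) + patches(64)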