import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

from components import RMSNorm
from transformer import OptimizedTransformerBlock


class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch, initialized near zero."""

    def __init__(self, dim: int, init_values: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


class StochasticDepth(nn.Module):
    """Randomly drops entire residual branches per sample during training."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class ImprovedPatchEmbedding(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        overlap: int = 0
    ):
        super().__init__()
        self.patch_size = patch_size
        stride = patch_size - overlap
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=stride,
            padding=overlap // 2
        )
        self.norm = RMSNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        B, C, H, W = x.shape
        x = self.proj(x)                  # [B, embed_dim, H', W']
        grid_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)  # [B, H'*W', embed_dim]
        x = self.norm(x)
        return x, grid_size


class ImprovedVisionBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        dropout: float = 0.0,
        drop_path: float = 0.0,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, n_heads, dropout=dropout, batch_first=True
        )
        self.norm2 = RMSNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout)
        )
        self.drop_path = StochasticDepth(drop_path) if drop_path > 0 else nn.Identity()
        if use_layer_scale:
            self.ls1 = LayerScale(dim, layer_scale_init)
            self.ls2 = LayerScale(dim, layer_scale_init)
        else:
            self.ls1 = nn.Identity()
            self.ls2 = nn.Identity()
        if use_adapter:
            self.adapter = nn.Sequential(
                nn.Linear(dim, adapter_dim),
                nn.GELU(),
                nn.Linear(adapter_dim, dim)
            )
        else:
            self.adapter = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention (pre-norm residual with LayerScale and stochastic depth)
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + self.drop_path(self.ls1(attn_out))
        # MLP
        x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
        # Adapter (extra residual branch, only when enabled)
        if self.adapter is not None:
            x = x + self.adapter(x)
        return x
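

# A minimal shape check for the residual block above (illustrative only; the
# small dims are placeholders, not model settings, and it assumes the RMSNorm
# import behaves like a standard RMS norm). LayerScale starts near zero
# (gamma = 1e-5), so a fresh block is close to identity, and StochasticDepth
# only drops branches in training mode, so eval output shape must match.
def _check_vision_block_shapes():
    block = ImprovedVisionBlock(dim=64, n_heads=4, drop_path=0.1)
    block.eval()
    tokens = torch.randn(2, 16, 64)  # [batch, seq_len, dim]
    out = block(tokens)
    assert out.shape == tokens.shape, "residual block must preserve shape"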


class ImprovedVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        depth: int = 24,
        n_heads: int = 16,
        dropout: float = 0.0,
        drop_path_rate: float = 0.1,
        use_register_tokens: bool = True,
        num_register_tokens: int = 4,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_layer_scale: bool = True,
        layer_scale_init: float = 1e-5
    ):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.use_register_tokens = use_register_tokens
        self.num_register_tokens = num_register_tokens if use_register_tokens else 0

        # Patch embedding
        self.patch_embed = ImprovedPatchEmbedding(
            patch_size, in_channels, embed_dim, overlap=0
        )
        self.pretrain_img_size = img_size
        n_patches_pretrain = (img_size // patch_size) ** 2

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Register tokens
        if use_register_tokens:
            self.register_tokens = nn.Parameter(
                torch.zeros(1, num_register_tokens, embed_dim)
            )

        total_tokens = 1 + n_patches_pretrain + self.num_register_tokens
        self.pos_embed = nn.Parameter(
            torch.zeros(1, total_tokens, embed_dim)
        )
        self.pos_drop = nn.Dropout(dropout)

        # Stochastic depth: drop rate increases linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Transformer blocks
        self.blocks = nn.ModuleList([
            ImprovedVisionBlock(
                embed_dim, n_heads, dropout,
                drop_path=dpr[i],
                use_adapter=use_adapter,
                adapter_dim=adapter_dim,
                use_layer_scale=use_layer_scale,
                layer_scale_init=layer_scale_init
            )
            for i in range(depth)
        ])
        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.use_register_tokens:
            nn.init.trunc_normal_(self.register_tokens, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, RMSNorm):
            if hasattr(m, 'weight') and m.weight is not None:
                nn.init.ones_(m.weight)

    def _interpolate_pos_encoding(
        self,
        patch_tokens: torch.Tensor,
        grid_size: Tuple[int, int]
    ) -> torch.Tensor:
        pretrain_grid_h = self.pretrain_img_size // self.patch_size
        pretrain_grid_w = pretrain_grid_h

        # If the grid matches the pretraining grid, return the stored embedding
        if grid_size[0] == pretrain_grid_h and grid_size[1] == pretrain_grid_w:
            return self.pos_embed

        # Split the positional embedding by token type.
        # pos_embed layout: [CLS(1), register_tokens(n), patches(H*W)]
        num_extra_tokens = 1 + self.num_register_tokens
        cls_register_pos = self.pos_embed[:, :num_extra_tokens, :]  # [1, 1+n, dim]
        patch_pos_embed = self.pos_embed[:, num_extra_tokens:, :]   # [1, H*W, dim]

        # 2D-interpolate the patch positional embedding to the new grid
        patch_pos_embed = patch_pos_embed.reshape(
            1, pretrain_grid_h, pretrain_grid_w, -1
        ).permute(0, 3, 1, 2)
        patch_pos_embed = F.interpolate(
            patch_pos_embed, size=grid_size, mode='bicubic', align_corners=False
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)

        # Reassemble: extra tokens first, then interpolated patch positions
        return torch.cat([cls_register_pos, patch_pos_embed], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]

        # Patch embedding
        x, grid_size = self.patch_embed(x)

        # Prepend CLS token (and register tokens when enabled)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.use_register_tokens:
            register_tokens = self.register_tokens.expand(B, -1, -1)
            # Token order: [CLS, register_tokens, patches]
            x = torch.cat([cls_tokens, register_tokens, x], dim=1)
        else:
            x = torch.cat([cls_tokens, x], dim=1)

        # Positional encoding (interpolated when the resolution differs)
        pos_embed = self._interpolate_pos_encoding(x, grid_size)
        x = self.pos_drop(x + pos_embed)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return x
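

# Illustrative check of variable-resolution inference (placeholder dims; a
# sketch, not a training configuration). With the defaults above the
# pretraining grid is 16x16 (224 / 14); a 448x448 input yields a 32x32 grid,
# and _interpolate_pos_encoding bicubically resizes the patch positional
# embeddings while carrying the CLS/register positions over unchanged.
def _check_variable_resolution():
    vit = ImprovedVisionTransformer(
        img_size=224, patch_size=14, embed_dim=256, depth=2, n_heads=4
    )
    vit.eval()
    out = vit(torch.randn(1, 3, 448, 448))
    # 1 CLS token + 4 register tokens + 32*32 patch tokens
    assert out.shape == (1, 1 + 4 + 32 * 32, 256)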


class ImprovedAudioEncoder(nn.Module):
    def __init__(
        self,
        n_mels: int = 128,
        target_length: int = 1024,
        embed_dim: int = 2048,
        depth: int = 12,
        n_heads: int = 16,
        patch_size: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_dual_stream: bool = True
    ):
        super().__init__()
        self.use_dual_stream = use_dual_stream
        self.patch_size = patch_size

        # Main encoder: non-overlapping spectrogram patches
        self.patch_embed = nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.n_patches_h = n_mels // patch_size
        self.n_patches_w = target_length // patch_size
        n_patches = self.n_patches_h * self.n_patches_w
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(depth)
        ])

        # Dual stream: a temporal stream and a frequency stream
        if use_dual_stream:
            self.temporal_pool = nn.AdaptiveAvgPool1d(1)
            self.frequency_pool = nn.AdaptiveAvgPool1d(1)
            self.temporal_proj = nn.Linear(embed_dim, embed_dim)
            self.frequency_proj = nn.Linear(embed_dim, embed_dim)
            self.fusion = nn.Linear(embed_dim * 2, embed_dim)

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        if mel_spec.ndim == 3:
            mel_spec = mel_spec.unsqueeze(1)  # [B, 1, n_mels, T]

        # Patch embedding
        x = self.patch_embed(mel_spec)    # [B, C, H, W]
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        x = self.pos_drop(x + self.pos_embed)

        # Transformer encoding
        for block in self.blocks:
            x, _, _ = block(x)
        x = self.norm(x)

        if self.use_dual_stream:
            B, N, C = x.shape
            # Reshape tokens back into the 2D patch grid
            x_2d = x.transpose(1, 2).reshape(B, C, self.n_patches_h, self.n_patches_w)

            # Temporal stream: pool over frequency (keep time)
            temporal = x_2d.mean(dim=2)                              # [B, C, W]
            temporal = self.temporal_pool(temporal).squeeze(-1)      # [B, C]
            temporal = self.temporal_proj(temporal).unsqueeze(1)     # [B, 1, C]

            # Frequency stream: pool over time (keep frequency)
            frequency = x_2d.mean(dim=3)                             # [B, C, H]
            frequency = self.frequency_pool(frequency).squeeze(-1)   # [B, C]
            frequency = self.frequency_proj(frequency).unsqueeze(1)  # [B, 1, C]

            # Fuse the two streams
            x = self.fusion(torch.cat([temporal, frequency], dim=-1))
        else:
            # Simple global average pooling
            x = x.mean(dim=1, keepdim=True)

        return x
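

# Usage sketch for the audio encoder (illustrative; it relies on the external
# OptimizedTransformerBlock exactly as the forward pass above does, and the
# small dims are placeholders). A log-mel spectrogram [B, n_mels, T] is
# patchified into an (n_mels / patch_size) x (T / patch_size) grid; with
# use_dual_stream=True the time- and frequency-pooled streams are fused into
# a single token per clip:
#
#     enc = ImprovedAudioEncoder(n_mels=128, target_length=1024,
#                                embed_dim=256, depth=2, n_heads=4)
#     pooled = enc(torch.randn(2, 128, 1024))  # -> [2, 1, 256]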


class ImprovedVideoEncoder(nn.Module):
    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 2048,
        spatial_depth: int = 12,
        temporal_depth: int = 4,
        n_heads: int = 16,
        num_frames: int = 16,
        dropout: float = 0.1,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_3d_conv: bool = False
    ):
        super().__init__()
        self.num_frames = num_frames
        self.use_3d_conv = use_3d_conv
        self.patch_size = patch_size
        self.img_size = img_size

        if use_3d_conv:
            # 3D convolution over space-time (tubelets of 2 frames)
            self.patch_embed = nn.Conv3d(
                in_channels, embed_dim,
                kernel_size=(2, patch_size, patch_size),
                stride=(2, patch_size, patch_size)
            )
            self.n_temporal_patches = num_frames // 2
            self.n_spatial_patches = (img_size // patch_size) ** 2
        else:
            # 2D convolution per frame + separate temporal modeling
            self.patch_embed = ImprovedPatchEmbedding(
                patch_size, in_channels, embed_dim
            )
            self.n_spatial_patches = (img_size // patch_size) ** 2

        # Spatial positional embedding
        self.spatial_pos_embed = nn.Parameter(
            torch.zeros(1, self.n_spatial_patches, embed_dim)
        )
        self.spatial_pos_drop = nn.Dropout(dropout)

        # Spatial encoder
        self.spatial_blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(spatial_depth)
        ])

        # Temporal positional embedding
        if use_3d_conv:
            self.temporal_pos_embed = nn.Parameter(
                torch.zeros(1, self.n_temporal_patches, embed_dim)
            )
        else:
            self.temporal_pos_embed = nn.Parameter(
                torch.zeros(1, num_frames, embed_dim)
            )
        self.temporal_pos_drop = nn.Dropout(dropout)

        # Temporal encoder
        self.temporal_blocks = nn.ModuleList([
            OptimizedTransformerBlock(
                embed_dim, n_heads, None, None, dropout,
                use_adapter=use_adapter, adapter_dim=adapter_dim
            )
            for _ in range(temporal_depth)
        ])

        self.norm = RMSNorm(embed_dim)
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.spatial_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.temporal_pos_embed, std=0.02)
        self.apply(self._init_module_weights)

    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C, H, W = x.shape

        if self.use_3d_conv:
            x = x.transpose(1, 2)    # [B, C, T, H, W]
            x = self.patch_embed(x)  # [B, embed_dim, T', H', W']
            # Flatten the spatial grid, channels last:
            # [B, D, T', H', W'] -> [B, T', H'*W', D]
            B, D, T_new, H_new, W_new = x.shape
            x = x.view(B, D, T_new, -1).permute(0, 2, 3, 1)

            # Spatial positional embedding, shared across frames
            x = x + self.spatial_pos_embed.unsqueeze(1)

            # Per-frame spatial encoding
            x_flat = x.reshape(B * T_new, -1, D)
            for block in self.spatial_blocks:
                x_flat, _, _ = block(x_flat)

            # Restore the temporal axis, then pool over spatial tokens
            x = x_flat.view(B, T_new, -1, D)
            x = x.mean(dim=2)  # [B, T', D]
        else:
            # 2D convolution + factorized space-time modeling
            x_flat = x.view(B * T, C, H, W)
            x_patched, grid_size = self.patch_embed(x_flat)

            # Spatial positional embedding
            x_patched = self.spatial_pos_drop(x_patched + self.spatial_pos_embed)

            # Spatial encoding
            for block in self.spatial_blocks:
                x_patched, _, _ = block(x_patched)

            _, N, D = x_patched.shape
            x_spatial = x_patched.view(B, T, N, D)
            x = x_spatial.mean(dim=2)  # [B, T, D]

        # Temporal positional embedding
        x = self.temporal_pos_drop(x + self.temporal_pos_embed)

        # Temporal encoding
        for block in self.temporal_blocks:
            x, _, _ = block(x)

        return self.norm(x)
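

# Minimal smoke-test entry point (illustrative). It only exercises the pieces
# defined entirely in this file; the audio and video encoders additionally
# require the OptimizedTransformerBlock import to be available and to return
# (hidden_states, _, _) tuples, as assumed throughout this module.
if __name__ == "__main__":
    _check_vision_block_shapes()
    _check_variable_resolution()
    print("vision block and ViT shape checks passed")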