import torch
import torch.nn as nn
from torch.nn import functional as F

class PatchEmbed(nn.Module):
    """Split a spectrogram into (possibly overlapping) patches and project to embed_dim."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=10,
        in_chans=1,
        embed_dim=768
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.stride = stride
        self.projection = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=stride
        )

    def forward(self, x):
        # x: [B, C, freq, time]. Unpack into `freq`/`time`, not `F`, which
        # would shadow the torch.nn.functional import above.
        B, C, freq, time = x.shape
        x = self.projection(x)   # [B, embed_dim, f_patches, t_patches]
        x = x.flatten(2)         # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)    # [B, num_patches, embed_dim]
        return x

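# With the AST defaults used below (patch_size=16, stride=10) the Conv2d
# windows overlap by 6 bins/frames, so neighbouring patches share content.
# Illustrative shape check (values assume those defaults):
#   pe = PatchEmbed(patch_size=16, stride=10, in_chans=1, embed_dim=768)
#   pe(torch.randn(2, 1, 128, 500)).shape  # -> torch.Size([2, 588, 768])
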
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=12,
        qkv_bias=True,
        attn_drop=0.0,
        proj_drop=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.projection = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, D = x.shape
        # Project to q, k, v in one pass, then split out the heads.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)   # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)     # [B, N, heads, head_dim]
        x = x.reshape(B, N, D)
        x = self.projection(x)
        x = self.proj_drop(x)
        return x

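# The matmul/softmax lines above are plain scaled dot-product attention,
# computed per head: softmax(q @ k^T / sqrt(head_dim)) @ v. On PyTorch >= 2.0
# they could be replaced by F.scaled_dot_product_attention(q, k, v,
# dropout_p=...), which fuses the same computation; the explicit form also
# keeps the weights in `attn` inspectable (see get_attention_maps below).
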
class MLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop=0.0,
        attn_drop=0.0
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop
        )
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            drop=drop
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

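# TransformerBlock uses the ViT-style pre-norm layout: LayerNorm runs before
# each sub-layer and the residual is added afterwards,
#   x = x + Attention(LN(x)); x = x + MLP(LN(x)),
# which tends to train more stably in deep stacks than post-norm.
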
class AudioSpectrogramTransformer(nn.Module):
    def __init__(
        self,
        num_classes=50,
        input_fdim=128,
        input_tdim=500,
        patch_size=16,
        stride=10,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            stride=stride,
            in_chans=in_chans,
            embed_dim=embed_dim
        )
        self.num_patches = self._calculate_num_patches(input_fdim, input_tdim, patch_size, stride)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate
            )
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self._init_weights()

    def _calculate_num_patches(self, fdim, tdim, patch_size, stride):
        f_patches = (fdim - patch_size) // stride + 1
        t_patches = (tdim - patch_size) // stride + 1
        return f_patches * t_patches

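    # Worked example with the defaults (input_fdim=128, input_tdim=500,
    # patch_size=16, stride=10): f_patches = (128 - 16) // 10 + 1 = 12 and
    # t_patches = (500 - 16) // 10 + 1 = 49, so num_patches = 12 * 49 = 588.
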
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_layer_weights)

    def _init_layer_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Accept either [B, T, F] spectrograms or pre-shaped [B, 1, F, T] input.
        if x.dim() == 3:
            x = x.transpose(1, 2).unsqueeze(1)
        elif x.dim() == 4 and x.shape[1] == 1:
            pass
        else:
            raise ValueError(f"Expected input shape [B, T, F] or [B, 1, F, T], got {x.shape}")
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        cls_output = x[:, 0]   # classify from the [CLS] token
        logits = self.head(cls_output)
        return logits

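    # Note: pos_embed is sized for exactly num_patches + 1 tokens, so inputs
    # must match the input_fdim/input_tdim the model was built with; other
    # sizes would require positional-embedding interpolation.
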
    def get_attention_maps(self, x, block_idx=-1):
        """Return the softmax attention weights of one block: [B, num_heads, N, N]."""
        if x.dim() == 3:
            x = x.transpose(1, 2).unsqueeze(1)
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        # Normalize negative indices, then run every block before the target.
        block_idx = block_idx % len(self.blocks)
        target_block = self.blocks[block_idx]
        for block in self.blocks[:block_idx]:
            x = block(x)
        # Recompute the target block's attention weights explicitly, since
        # MultiHeadAttention.forward does not expose them.
        x_norm = target_block.norm1(x)
        B, N, D = x_norm.shape
        attn_mod = target_block.attn
        qkv = attn_mod.qkv(x_norm).reshape(B, N, 3, attn_mod.num_heads, attn_mod.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * attn_mod.scale
        attn = attn.softmax(dim=-1)
        return attn

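# Illustrative use of get_attention_maps (assumes the 'base' preset defined
# below and the default 128 x 500 input geometry):
#   model = create_ast_model(num_classes=50, model_size='base')
#   attn = model.get_attention_maps(torch.randn(2, 500, 128), block_idx=-1)
#   # attn.shape -> [2, 12, 589, 589]: batch, heads, tokens, tokens,
#   # where tokens = num_patches + 1 for the prepended [CLS] token.
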
def create_ast_model(
    num_classes=50,
    model_size='base',
    input_fdim=128,
    input_tdim=500
):
    """
    Create an AST model with preset configurations.

    Args:
        num_classes: Number of output classes.
        model_size: 'tiny', 'small', 'base', or 'large'.
        input_fdim: Frequency dimension of the input spectrogram.
        input_tdim: Time dimension of the input spectrogram.

    Returns:
        AudioSpectrogramTransformer model.
    """
    configs = {
        'tiny':  {'embed_dim': 192,  'depth': 12, 'num_heads': 3},
        'small': {'embed_dim': 384,  'depth': 12, 'num_heads': 6},
        'base':  {'embed_dim': 768,  'depth': 12, 'num_heads': 12},
        'large': {'embed_dim': 1024, 'depth': 24, 'num_heads': 16},
    }
    if model_size not in configs:
        raise ValueError(f"Model size must be one of {list(configs.keys())}")
    config = configs[model_size]
    model = AudioSpectrogramTransformer(
        num_classes=num_classes,
        input_fdim=input_fdim,
        input_tdim=input_tdim,
        embed_dim=config['embed_dim'],
        depth=config['depth'],
        num_heads=config['num_heads'],
        patch_size=16,
        stride=10,
        mlp_ratio=4.0,
        drop_rate=0.0,
        attn_drop_rate=0.0
    )
    return model

if __name__ == '__main__':
    # Smoke-test the model.
    print("Testing AST model...")

    model = create_ast_model(num_classes=50, model_size='base')
    print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters")

    # Forward pass on a random spectrogram batch.
    batch_size = 4
    test_input = torch.randn(batch_size, 500, 128)  # [B, T, F]
    model.eval()
    with torch.no_grad():
        output = model(test_input)
    print(f"Input shape: {test_input.shape}")
    print(f"Output shape: {output.shape}")
| print("✓ Model works correctly!") | |