| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from __future__ import annotations |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
|
|
| import os |
| import io |
| import sys |
| import math |
| import random |
| import collections |
| import collections.abc |
| import re |
| from itertools import repeat |
| from pathlib import Path |
| from typing import Optional, Tuple, Union, List, Dict |
|
|
| import csv |
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
| import seaborn as sns |
| import matplotlib.pyplot as plt |
| from tqdm import trange, tqdm |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.nn.init import _calculate_fan_in_and_fan_out |
| import torch.utils.checkpoint as checkpoint |
|
|
| import torchvision as tv |
| from torchvision.transforms import v2 |
| from torch.utils.tensorboard import SummaryWriter |
| |
|
|
# Restrict the process to GPU index 1.
# NOTE(review): this is set after `import torch`; it only takes effect because
# CUDA is initialized lazily — confirm no CUDA call happens before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Select the compute device: CUDA if available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
|
|
| import torchaudio |
| import torchaudio.transforms as T |
| from torchlibrosa.stft import Spectrogram, LogmelFilterBank |
| from torchlibrosa.augmentation import SpecAugmentation |
|
|
| from transformers import AutoModel, AutoTokenizer, logging |
| from huggingface_hub.file_download import hf_hub_download |
| from huggingface_hub.file_download import hf_hub_download |
| from peft import get_peft_config, get_peft_model |
| from transformers import CLIPVisionModel, AutoProcessor |
|
|
| from watermark import watermark |
# Print a reproducibility banner at startup: timestamp, Python/conda
# environment, host/machine details, imported-package versions and GPU info.
print(watermark(
    author='Ashish',

    current_date=True,
    datename=True,
    current_time=True,
    iso8601=True,
    timezone=True,
    updated=True,
    custom_time=None,
    python=True,

    conda=True,
    hostname=True,
    machine=True,
    watermark=False,
    iversions=True,
    gpu=True,
    globals_=globals()
))
|
|
|
|
| |
| |
| |
class HTSATConfig:
    """Static hyperparameter/configuration container for HTSAT training,
    evaluation and inference.

    All values are plain class attributes; the model reads them through a
    ``config`` object (see ``HTSAT_Swin_Transformer.__init__``).
    """

    # --- experiment bookkeeping and data locations ---
    exp_name = "exp_htsat_pretrain"
    workspace = "/home/kechen/Research/HTSAT"
    dataset_path = "/home/Research/audioset"
    desed_folder = "/home/Research/DESED"

    dataset_type = "audioset"
    index_type = "full_train"
    balanced_data = True

    # Loss selector; forward_features applies sigmoid to clipwise outputs
    # unless this is "clip_ce" (in which case raw logits are returned).
    loss_type = "clip_bce"

    # Checkpoint to resume training from (None = start from scratch).
    resume_checkpoint = None

    # Fold index, presumably for ESC-50 cross-validation — confirm in trainer.
    esc_fold = 0

    debug = False

    # --- optimization ---
    random_seed = 970131
    batch_size = 32 * 4
    learning_rate = 1e-3
    max_epoch = 100
    num_workers = 3

    # Epochs at which the LR schedule changes, and the matching multipliers.
    lr_scheduler_epoch = [10,20,30]
    lr_rate = [0.02, 0.05, 0.1]

    # --- dataset / labeling options ---
    enable_token_label = False
    class_map_path = "class_hier_map.npy"
    class_filter = None
    # Fixed sample indices, presumably for retrieval evaluation — confirm usage.
    retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
        9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
    token_label_range = [0.2,0.6]
    enable_time_shift = False
    enable_label_enhance = False
    enable_repeat_mode = False

    # Token-semantic CAM head producing framewise outputs
    # (see HTSAT_Swin_Transformer.forward_features).
    enable_tscam = True

    # --- audio frontend (STFT / log-mel) ---
    sample_rate = 32000
    clip_samples = sample_rate * 10  # 10-second clips
    window_size = 1024
    hop_size = 320
    mel_bins = 64
    fmin = 50
    fmax = 14000
    shift_max = int(clip_samples * 0.5)

    # --- classifier / input geometry ---
    classes_num = 527  # AudioSet has 527 classes
    patch_size = (25, 4)
    crop_size = None

    # --- HTSAT (Swin) architecture hyperparameters ---
    htsat_window_size = 8
    htsat_spec_size = 256
    htsat_patch_size = 4
    htsat_stride = (4, 4)
    htsat_num_head = [4,8,16,32]
    htsat_dim = 96
    htsat_depth = [2,2,6,2]

    # Optional ImageNet-Swin checkpoint to initialize the backbone from.
    swin_pretrain_path = None

    # Inference-time output switches.
    htsat_attn_heatmap = False
    htsat_hier_output = False
    htsat_use_max = False

    # --- ensembling / weight averaging ---
    ensemble_checkpoints = []
    ensemble_strides = []

    # Folder of checkpoints used for weight averaging, and the output path.
    wa_folder = "/home/version_0/checkpoints/"
    wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"

    esm_model_pathes = [
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
    ]

    # --- sound event detection / localization (DESED) ---
    heatmap_dir = "/home/Research/heatmap_output"
    test_file = "htsat-test-ensemble"
    fl_local = False
    fl_dataset = "/home/Research/desed/desedim_embval.npy"
    fl_class_num = [
        "Speech", "Frying", "Dishes", "Running_water",
        "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
        "Cat", "Dog", "Vacuum_cleaner"
    ]

    # Maps each DESED class (by position in fl_class_num) to its
    # corresponding AudioSet label indices.
    fl_audioset_mapping = [
        [0,1,2,3,4,5,6,7],
        [366, 367, 368],
        [364],
        [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
        [369],
        [382],
        [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
        [81, 82, 83, 84, 85],
        [74, 75, 76, 77, 78, 79],
        [377]
    ]
|
|
|
|
|
|
def _ntuple(n):
    """Return a converter that maps a scalar to an ``n``-tuple.

    Iterable inputs are passed through unchanged; any other value is
    duplicated ``n`` times into a tuple.
    """
    def parse(value):
        is_iterable = isinstance(value, collections.abc.Iterable)
        return value if is_iterable else tuple(repeat(value, n))
    return parse
|
|
# Convenience converters: scalar -> fixed-length tuple (iterables pass through).
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
# Factory form for arbitrary arity.
to_ntuple = _ntuple
|
|
def do_mixup(x, mixup_lambda):
    """Mix consecutive sample pairs: even-index samples (0, 2, 4, ...) are
    blended with odd-index samples (1, 3, 5, ...) using per-sample weights.

    Args:
        x: (batch_size * 2, ...)
        mixup_lambda: (batch_size * 2,)
    Returns:
        out: (batch_size, ...)
    """
    # Transposing moves the batch axis last so the (batch,) weights broadcast.
    even_part = x[0::2].transpose(0, -1) * mixup_lambda[0::2]
    odd_part = x[1::2].transpose(0, -1) * mixup_lambda[1::2]
    return (even_part + odd_part).transpose(0, -1)
|
|
def interpolate(x, ratio):
    """Upsample along the time axis by an integer factor (nearest-neighbour
    repetition). Used to compensate the temporal resolution lost in a CNN's
    downsampling stages.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, upsampling factor
    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    batch, steps, n_cls = x.shape
    # Insert a repeat axis, broadcast it to `ratio`, then fold it into time.
    expanded = x.unsqueeze(2).expand(batch, steps, ratio, n_cls)
    return expanded.reshape(batch, steps * ratio, n_cls)
|
|
|
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic Depth: randomly zero whole samples' residual branches.

    Each sample in the batch is kept with probability ``1 - drop_prob``; kept
    samples are scaled by ``1 / keep_prob`` so the expectation is unchanged.
    A no-op outside training or when ``drop_prob`` is zero. (Often called
    "DropConnect" elsewhere, but that is a different technique — see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.)
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
    return x.div(keep_prob) * mask
|
|
|
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    Args:
        drop_prob: probability of zeroing a sample's residual branch.
            ``None`` is treated as 0.0 — previously, leaving the default and
            calling forward in training mode crashed with ``TypeError``
            (``1 - None``) inside drop_path().
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob


    def forward(self, x):
        # Treat an unset probability as "no drop" instead of failing downstream.
        p = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, p, self.training)
|
|
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding.

    A strided convolution projects each patch to ``embed_dim`` channels.
    When ``patch_stride`` is smaller than ``patch_size`` the patches overlap,
    and symmetric padding keeps the grid at ``img_size // patch_stride``.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
        super().__init__()
        # Inline pair conversion (same behavior as the module-level to_2tuple).
        def _pair(v):
            return v if isinstance(v, collections.abc.Iterable) else (v, v)
        img_size = _pair(img_size)
        patch_size = _pair(patch_size)
        patch_stride = _pair(patch_stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Symmetric padding so overlapping patches still tile the full input.
        padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()


    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        out = self.proj(x)
        if self.flatten:
            # BCHW -> BNC token sequence.
            out = out.flatten(2).transpose(1, 2)
        return self.norm(out)
|
|
class Mlp(nn.Module):
    """Two-layer feed-forward block (Linear -> activation -> Linear), with
    dropout applied after each linear layer, as used in ViT / MLP-Mixer.
    Hidden and output widths default to ``in_features``."""
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)


    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):
    """In-place truncated-normal fill using inverse-CDF sampling.

    Draws from N(mean, std^2) restricted to [a, b]: sample uniformly in CDF
    space, then map back through the inverse error function. Returns the
    (mutated) tensor.
    """
    def norm_cdf(x):
        # CDF of the standard normal distribution.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the truncation bounds (in standardized units).
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Sample uniformly in [2*lower-1, 2*upper-1], the erfinv-domain image
        # of the truncated CDF interval.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Invert: erfinv gives a standard truncated normal ...
        tensor.erfinv_()

        # ... which is then scaled and shifted to the requested mean/std.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to guard against numerical overshoot at the boundaries.
        tensor.clamp_(min=a, max=b)
    return tensor
|
|
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill *tensor* in-place with samples from a truncated normal.

    Values are effectively drawn from
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with anything outside
    :math:`[a, b]` redrawn until it lies within the bounds (implemented via
    inverse-CDF sampling). Works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)
|
|
|
|
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialize *tensor* in-place with fan-scaled variance.

    Args:
        tensor: an n-dimensional `torch.Tensor` (at least 2-D, so fans exist).
        scale: multiplicative factor applied to the variance.
        mode: 'fan_in', 'fan_out' or 'fan_avg' — which fan count normalizes
            the variance.
        distribution: 'truncated_normal', 'normal' or 'uniform'.

    Raises:
        ValueError: if *mode* or *distribution* is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # BUGFIX: an unknown mode previously fell through and crashed below
        # with a confusing NameError on `denom`; fail fast instead.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # Constant is the stddev of a standard normal truncated to (-2, 2),
        # so the post-truncation stddev matches sqrt(variance).
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
|
|
|
|
def lecun_normal_(tensor):
    """LeCun-normal init: truncated normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='truncated_normal')
|
|
|
|
| |
| |
|
|
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    n_h, n_w = H // window_size, W // window_size
    # (B, n_h, ws, n_w, ws, C) -> (B, n_h, n_w, ws, ws, C) -> flat windows.
    tiles = x.view(B, n_h, window_size, n_w, window_size, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
|
|
|
|
def window_reverse(windows, window_size, H, W):
    """
    Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    B = int(windows.shape[0] / windows_per_image)
    # Inverse of window_partition: regroup windows onto the (H, W) grid.
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
|
|
|
|
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0

    Note:
        Unlike the reference Swin implementation, ``forward`` returns both the
        projected features and the attention map: ``(x, attn)``.
    """


    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Learnable bias per head, one entry per relative (dh, dw) offset:
        # offsets range over (2*Wh-1) x (2*Ww-1).
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Precompute, for every pair of positions inside a window, the index
        # into the bias table for their relative offset.
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift offsets to start at 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row-major flattening of (dh, dw)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)


    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        Returns:
            (x, attn): projected features (num_windows*B, N, C) and the
            post-softmax, post-dropout attention weights
            (num_windows*B, num_heads, N, N).
        """
        B_, N, C = x.shape
        # One projection yields q, k, v: each (B_, num_heads, N, head_dim).
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Add the learned relative position bias for every position pair.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Shifted-window mask: disallowed pairs carry -100 pre-softmax.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


    def extra_repr(self):
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
|
|
|
| |
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm

    ``forward`` returns ``(x, attn)`` — the transformed tokens plus the
    window-attention map from this block.
    """


    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        if min(self.input_resolution) <= self.window_size:
            # Window would cover the whole input: disable shifting and shrink
            # the window to the input size.
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"


        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)


        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.norm_before_mlp == 'ln':
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == 'bn':
            # NOTE(review): this lambda constructs a *new* BatchNorm1d on every
            # forward call, so its affine parameters/running stats are freshly
            # initialized each time and never trained. Looks unintentional —
            # confirm before relying on the 'bn' option.
            self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
        else:
            raise NotImplementedError
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)


        if self.shift_size > 0:
            # Build the SW-MSA attention mask: after torch.roll, tokens that
            # originate from different image regions share a window but must
            # not attend to each other; their pairwise entries get -100.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1, H, W, 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1


            mask_windows = window_partition(img_mask, self.window_size)  # nW, ws, ws, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None


        self.register_buffer("attn_mask", attn_mask)


    def forward(self, x):
        # x: (B, H*W, C) token sequence at this stage's resolution.
        H, W = self.input_resolution

        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)


        # Cyclic shift so the shifted windows tile the feature map.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x


        # Partition into windows: (nW*B, ws*ws, C).
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)


        # Window attention (also returns the attention map).
        attn_windows, attn = self.attn(x_windows, mask=self.attn_mask)


        # Merge windows back into the (B, H, W, C) layout.
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B, H', W', C


        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)


        # Residual + stochastic depth, then the MLP sub-block.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))


        return x, attn


    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
|
|
|
|
|
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Concatenates each 2x2 neighborhood of patches channel-wise (C -> 4C),
    normalizes, and projects down to 2C — halving both spatial dimensions.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """


    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)


    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # The four interleaved sub-grids of every 2x2 block, stacked on the
        # channel axis in the order (0,0), (1,0), (0,1), (1,1).
        corners = [grid[:, dh::2, dw::2, :] for dh, dw in ((0, 0), (1, 0), (0, 1), (1, 1))]
        merged = torch.cat(corners, -1)  # B, H/2, W/2, 4*C
        merged = merged.view(B, -1, 4 * C)

        return self.reduction(self.norm(merged))


    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
|
|
|
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.

    ``forward`` returns ``(x, attn)``: in training mode ``attn`` is the last
    block's attention map; in eval mode it is the mean over all blocks.
    """


    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 norm_before_mlp='ln'):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Alternate regular (shift 0) and shifted (window_size // 2) blocks.
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
            for i in range(depth)])

        # Optional patch-merging downsample at the end of the stage.
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None


    def forward(self, x):
        attns = []
        for blk in self.blocks:
            if self.use_checkpoint:
                # BUGFIX: each block returns an (x, attn) tuple. The old code
                # did `x = checkpoint.checkpoint(blk, x)`, which assigned the
                # whole tuple to x (breaking the next block) and left `attn`
                # unbound at the return below.
                x, attn = checkpoint.checkpoint(blk, x)
            else:
                x, attn = blk(x)
            if not self.training:
                attns.append(attn.unsqueeze(0))
        if self.downsample is not None:
            x = self.downsample(x)
        if not self.training:
            # Average attention maps across all blocks (used for heatmaps).
            attn = torch.cat(attns, dim = 0)
            attn = torch.mean(attn, dim = 0)
        return x, attn


    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
|
|
|
| |
| class HTSAT_Swin_Transformer(nn.Module): |
| r"""HTSAT based on the Swin Transformer |
| Args: |
| spec_size (int | tuple(int)): Input Spectrogram size. Default 256 |
| patch_size (int | tuple(int)): Patch size. Default: 4 |
| path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 |
| in_chans (int): Number of input image channels. Default: 1 (mono) |
| num_classes (int): Number of classes for classification head. Default: 527 |
| embed_dim (int): Patch embedding dimension. Default: 96 |
| depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. |
| num_heads (tuple(int)): Number of attention heads in different layers. |
| window_size (int): Window size. Default: 8 |
| mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 |
| qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True |
| qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None |
| drop_rate (float): Dropout rate. Default: 0 |
| attn_drop_rate (float): Attention dropout rate. Default: 0 |
| drop_path_rate (float): Stochastic depth rate. Default: 0.1 |
| norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. |
| ape (bool): If True, add absolute position embedding to the patch embedding. Default: False |
| patch_norm (bool): If True, add normalization after patch embedding. Default: True |
| use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False |
| config (module): The configuration Module from config.py (HTSATConfig Class) |
| """ |
|
|
    def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
                 in_chans=1, num_classes=527,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
                 window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False, patch_norm=True,
                 use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
        """Build the HTSAT model: an audio frontend (STFT -> log-mel ->
        SpecAugment -> BatchNorm) followed by a Swin Transformer encoder and a
        classification head (TSCAM conv head when config.enable_tscam).

        Parameters mirror the class docstring; ``config`` supplies the audio
        hyperparameters (sample_rate, mel_bins, fmin/fmax, loss_type,
        enable_tscam, ...) and must not be None.
        """
        # NOTE(review): depths/num_heads are mutable list defaults — safe only
        # because they are never mutated here; tuples would be cleaner.
        super(HTSAT_Swin_Transformer, self).__init__()

        self.config = config
        self.spec_size = spec_size
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.window_size = window_size
        self.embed_dim = embed_dim
        self.depths = depths
        self.ape = ape
        self.in_chans = in_chans
        self.num_classes = num_classes
        self.num_heads = num_heads
        self.num_layers = len(self.depths)
        # Channel width after the final stage: embed_dim doubles per merge.
        self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))

        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate

        self.qkv_bias = qkv_bias
        # NOTE(review): the qk_scale argument is discarded — self.qk_scale is
        # hard-set to None here. Confirm whether that is intentional.
        self.qk_scale = None

        self.patch_norm = patch_norm
        self.norm_layer = norm_layer if self.patch_norm else None
        self.norm_before_mlp = norm_before_mlp
        self.mlp_ratio = mlp_ratio

        self.use_checkpoint = use_checkpoint

        # How many mel "rows" get folded to make the square spec_size input.
        self.freq_ratio = self.spec_size // self.config.mel_bins
        window = 'hann'
        center = True
        pad_mode = 'reflect'
        ref = 1.0
        amin = 1e-10
        top_db = None
        self.interpolate_ratio = 32

        # STFT spectrogram extractor (parameters frozen / non-trainable).
        self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
            win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
            freeze_parameters=True)
        # Log-mel filterbank (frozen).
        self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
            n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
            freeze_parameters=True)
        # SpecAugment: random time/frequency stripe masking.
        self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
            freq_drop_width=8, freq_stripes_num=2)
        self.bn0 = nn.BatchNorm2d(self.config.mel_bins)


        # Split the (folded) spectrogram into patch embeddings.
        self.patch_embed = PatchEmbed(
            img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
            embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)


        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.grid_size
        self.patches_resolution = patches_resolution


        # Optional learned absolute position embedding.
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)


        self.pos_drop = nn.Dropout(p=self.drop_rate)


        # Stochastic depth: drop-path rate grows linearly across all blocks.
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))]


        # Swin stages; every stage except the last ends with PatchMerging.
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=self.depths[i_layer],
                               num_heads=self.num_heads[i_layer],
                               window_size=self.window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
                               drop=self.drop_rate, attn_drop=self.attn_drop_rate,
                               drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
                               norm_layer=self.norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint,
                               norm_before_mlp=self.norm_before_mlp)
            self.layers.append(layer)


        self.norm = self.norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.maxpool = nn.AdaptiveMaxPool1d(1)


        if self.config.enable_tscam:
            # Token-semantic CAM head: a conv spanning the full (folded)
            # frequency axis maps features to per-class framewise activations.
            SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
            self.tscam_conv = nn.Conv2d(
                in_channels = self.num_features,
                out_channels = self.num_classes,
                kernel_size = (SF,3),
                padding = (0,1)
            )
            self.head = nn.Linear(num_classes, num_classes)
        else:
            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()


        self.apply(self._init_weights)
|
|
| def _init_weights(self, m): |
| if isinstance(m, nn.Linear): |
| trunc_normal_(m.weight, std=.02) |
| if isinstance(m, nn.Linear) and m.bias is not None: |
| nn.init.constant_(m.bias, 0) |
| elif isinstance(m, nn.LayerNorm): |
| nn.init.constant_(m.bias, 0) |
| nn.init.constant_(m.weight, 1.0) |
|
|
| @torch.jit.ignore |
| def no_weight_decay(self): |
| return {'absolute_pos_embed'} |
|
|
| @torch.jit.ignore |
| def no_weight_decay_keywords(self): |
| return {'relative_position_bias_table'} |
|
|
    def forward_features(self, x):
        """Run the swin backbone on a spectrogram image and build the output dict.

        Returns a dict with 'framewise_output', 'clipwise_output' and (in TSCAM
        mode) 'latent_output'.  `x` is expected to be the (B, C, T, F) image
        produced by reshape_wav2img / repeat_wat2img.
        """
        frames_num = x.shape[2]
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        # `attn` keeps the attention map of the LAST layer only.
        for i, layer in enumerate(self.layers):
            x, attn = layer(x)

        if self.config.enable_tscam:
            # Token-semantic CAM head: reshape tokens back to a 2-D map.
            x = self.norm(x)
            B, N, C = x.shape
            # Spatial extent after the (len(depths)-1) patch-merging stages.
            SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
            ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
            x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
            B, C, F, T = x.shape
            # Unfold the freq_ratio time-groups stacked on the frequency axis
            # back onto the time axis (inverse of reshape_wav2img).
            c_freq_bin = F // self.freq_ratio
            x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
            x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)

            # Global average over all spatial positions -> latent embedding.
            latent_output = self.avgpool(torch.flatten(x,2))
            latent_output = torch.flatten(latent_output, 1)

            if self.config.htsat_attn_heatmap:
                # Collapse heads and queries, then min-max rescale the
                # attention into a per-frame weighting in roughly [0.85, 1].
                attn = torch.mean(attn, dim = 1)
                attn = torch.mean(attn, dim = 1)
                attn = attn.reshape(B, SF, ST)
                c_freq_bin = SF // self.freq_ratio
                attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
                attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
                attn = attn.mean(dim = 1)
                attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
                attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
                attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
                attn = attn.unsqueeze(dim = 2)

            # Conv over the full frequency extent -> per-class, per-time logits.
            x = self.tscam_conv(x)
            x = torch.flatten(x, 2)

            # Upsample frame-wise predictions back to spectrogram frame rate.
            if self.config.htsat_attn_heatmap:
                fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
            else:
                fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
            # Average over time -> clip-level logits.
            x = self.avgpool(x)
            x = torch.flatten(x, 1)

            if self.config.loss_type == "clip_ce":
                # Raw logits for cross-entropy training.
                output_dict = {
                    'framewise_output': fpx,
                    'clipwise_output': x,
                    'latent_output': latent_output
                }
            else:
                output_dict = {
                    'framewise_output': fpx,
                    'clipwise_output': torch.sigmoid(x),
                    'latent_output': latent_output
                }
        else:
            # Plain classification head path (no TSCAM).
            x = self.norm(x)
            B, N, C = x.shape
            # NOTE(review): `+ 1` here vs `- 1` in the TSCAM branch — looks
            # asymmetric; confirm against the upstream HTSAT implementation.
            fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
            B, C, F, T = fpx.shape
            c_freq_bin = F // self.freq_ratio
            fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
            fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
            fpx = torch.sum(fpx, dim = 2)
            fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
            x = self.avgpool(x.transpose(1, 2))
            x = torch.flatten(x, 1)
            if self.num_classes > 0:
                x = self.head(x)
                fpx = self.head(fpx)
            output_dict = {'framewise_output': torch.sigmoid(fpx),
                'clipwise_output': torch.sigmoid(x)}
        return output_dict
|
|
| def crop_wav(self, x, crop_size, spe_pos = None): |
| time_steps = x.shape[2] |
| tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) |
| for i in range(len(x)): |
| if spe_pos is None: |
| crop_pos = random.randint(0, time_steps - crop_size - 1) |
| else: |
| crop_pos = spe_pos |
| tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:] |
| return tx |
|
|
| |
| def reshape_wav2img(self, x): |
| B, C, T, F = x.shape |
| target_T = int(self.spec_size * self.freq_ratio) |
| target_F = self.spec_size // self.freq_ratio |
| assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" |
| |
| if T < target_T: |
| x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) |
| if F < target_F: |
| x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) |
| x = x.permute(0,1,3,2).contiguous() |
| x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio) |
| |
| x = x.permute(0,1,3,2,4).contiguous() |
| x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) |
| return x |
| |
| |
| def repeat_wat2img(self, x, cur_pos): |
| B, C, T, F = x.shape |
| target_T = int(self.spec_size * self.freq_ratio) |
| target_F = self.spec_size // self.freq_ratio |
| assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" |
| |
| if T < target_T: |
| x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) |
| if F < target_F: |
| x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) |
| x = x.permute(0,1,3,2).contiguous() |
| x = x[:,:,:,cur_pos:cur_pos + self.spec_size] |
| x = x.repeat(repeats = (1,1,4,1)) |
| return x |
|
|
    def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):
        """Full pipeline: waveform -> log-mel -> (augment) -> swin features.

        mixup_lambda: per-sample mixup weights, applied only in training.
        infer_mode: tile short clips up to the model's expected length.
        Returns the output dict of forward_features (possibly averaged over
        several crops / positions in eval mode).
        """
        x = self.spectrogram_extractor(x)
        x = self.logmel_extractor(x)
        # BatchNorm over the mel-bin axis (transpose so bins sit at dim 1).
        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)
        if self.training:
            x = self.spec_augmenter(x)
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        if infer_mode:
            # Clip is shorter than the training length: tile it to fit.
            frame_num = x.shape[2]
            target_T = int(self.spec_size * self.freq_ratio)
            repeat_ratio = math.floor(target_T / frame_num)
            x = x.repeat(repeats=(1,1,repeat_ratio,1))
            x = self.reshape_wav2img(x)
            output_dict = self.forward_features(x)
        elif self.config.enable_repeat_mode:
            if self.training:
                # Random window per step.
                cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
                x = self.repeat_wat2img(x, cur_pos)
                output_dict = self.forward_features(x)
            else:
                # Eval: average predictions over evenly spaced windows.
                output_dicts = []
                for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
                    tx = x.clone()
                    tx = self.repeat_wat2img(tx, cur_pos)
                    output_dicts.append(self.forward_features(tx))
                clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
                framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
                for d in output_dicts:
                    clipwise_output += d["clipwise_output"]
                    framewise_output += d["framewise_output"]
                clipwise_output = clipwise_output / len(output_dicts)
                framewise_output = framewise_output / len(output_dicts)

                output_dict = {
                    'framewise_output': framewise_output,
                    'clipwise_output': clipwise_output
                }
        else:
            if x.shape[2] > self.freq_ratio * self.spec_size:
                if self.training:
                    # Long clip: random crop to the model's input length.
                    x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
                    x = self.reshape_wav2img(x)
                    output_dict = self.forward_features(x)
                else:
                    # Eval on long clips: average over overlapping crops.
                    # NOTE(review): 344/689 are hard-coded frame counts that
                    # appear tuned for a specific sample rate / hop size —
                    # confirm they match the current front-end settings.
                    overlap_size = 344
                    output_dicts = []
                    crop_size = 689
                    for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
                        tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
                        tx = self.reshape_wav2img(tx)
                        output_dicts.append(self.forward_features(tx))
                    clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
                    framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
                    latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
                    for d in output_dicts:
                        clipwise_output += d["clipwise_output"]
                        framewise_output += d["framewise_output"]
                        latent_output += d["latent_output"]
                    clipwise_output = clipwise_output / len(output_dicts)
                    framewise_output = framewise_output / len(output_dicts)
                    latent_output = latent_output / len(output_dicts)
                    output_dict = {
                        'framewise_output': framewise_output,
                        'clipwise_output': clipwise_output,
                        'latent_output': latent_output,
                    }
            else:
                x = self.reshape_wav2img(x)
                output_dict = self.forward_features(x)
        return output_dict
|
|
class HTSATWrapper(nn.Module):
    """Adapter exposing HTSAT under the CLAP audio-encoder interface.

    NOTE(review): every constructor argument is accepted only for interface
    compatibility and ignored — the HTSAT backbone is built purely from
    HTSATConfig defaults.
    """

    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num, out_emb):
        super().__init__()
        self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())

    def forward(self, x):
        # Alias the latent vector under the key CLAP's AudioEncoder expects.
        result = self.htsat(x)
        result['embedding'] = result['latent_output']
        return result
|
|
|
|
def get_audio_encoder(name: str):
    """Resolve an audio-encoder class by name; only 'HTSAT' is supported."""
    if name != "HTSAT":
        raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))
    return HTSATWrapper
|
|
class Projection(nn.Module):
    """Residual two-layer projection head: LayerNorm(W1 x + Drop(W2 gelu(W1 x)))."""

    def __init__(self, dim_imgn: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.linear1(x)
        residual = self.drop(self.linear2(F.gelu(projected)))
        return self.layer_norm(projected + residual)
|
|
class AudioEncoder(nn.Module):
    """Named audio backbone followed by a Projection head into the joint space."""

    def __init__(self, audioenc_name: str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,
                 hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
        super().__init__()
        encoder_cls = get_audio_encoder(audioenc_name)
        self.base = encoder_cls(
            sample_rate, window_size,
            hop_size, mel_bins, fmin, fmax,
            classes_num, dim_imgn)
        self.projection = Projection(dim_imgn, d_out)

    def forward(self, x):
        out_dict = self.base(x)
        embedding = out_dict['embedding']
        classification = out_dict['clipwise_output']
        # Return (projected embedding, clip-level class scores).
        return self.projection(embedding), classification
|
|
class TextEncoder(nn.Module):
    """Wraps a HuggingFace text model and projects a pooled sentence vector
    to the joint embedding space of size d_out."""

    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
        super().__init__()
        self.text_model = text_model
        self.base = AutoModel.from_pretrained(text_model)

        if 'clip' in text_model:
            # Keep CLIP's own text projection; drop down to the text sub-model.
            self.clip_text_projection = self.base.text_projection
            self.base = self.base.text_model
            if 'base' in text_model:
                # CLIP "base" variants use a 512-dim text tower.
                transformer_embed_dim = 512

        self.projection = Projection(transformer_embed_dim, d_out)

    def forward(self, x):
        if 'clip' in self.text_model:
            # Index 1 is the pooled output of the CLIP text model.
            pooled_output = self.base(**x)[1]
            out = self.clip_text_projection(pooled_output)
        elif 'gpt' in self.text_model:
            batch_size = x['input_ids'].shape[0]
            hidden_states = self.base(**x)[0]

            # Take the hidden state of the LAST non-pad token per sequence.
            # NOTE(review): assumes the pad token id is 0 — confirm this
            # matches the tokenizer used with this GPT checkpoint.
            sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1
            out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths]
        else:
            # BERT-style models: use the [CLS] token representation.
            out = self.base(**x)[0]
            out = out[:, 0, :]

        projected_vec = self.projection(out)

        return projected_vec
|
|
class CLAP(nn.Module):
    """Contrastive language-audio model: audio tower, text tower and a
    learnable temperature (stored as its log, CLIP-style)."""

    def __init__(self,
                 # audio front-end / backbone
                 audioenc_name: str,
                 sample_rate: int,
                 window_size: int,
                 hop_size: int,
                 mel_bins: int,
                 fmin: int,
                 fmax: int,
                 classes_num: int,
                 out_emb: int,
                 # text tower
                 text_model: str,
                 transformer_embed_dim: int,
                 # joint projection size
                 d_proj: int,
                 ):
        super().__init__()

        self.audio_encoder = AudioEncoder(
            audioenc_name, out_emb, d_proj,
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)

        self.caption_encoder = TextEncoder(d_proj, text_model, transformer_embed_dim)

        # Initialised to log(1/0.07) as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, audio, text):
        audio_embed, _ = self.audio_encoder(audio)
        caption_embed = self.caption_encoder(text)
        return caption_embed, audio_embed, self.logit_scale.exp()
| |
| |
| |
| |
| |
| |
def read_audio(audio_path, resample=True):
    r"""Load an audio file, optionally resampling to clapConfig.sample_rate.

    Returns (waveform, clapConfig.sample_rate).
    NOTE(review): the config rate is returned even when resample=False and no
    resampling happened — the caller then sees a rate the data may not have.
    """
    waveform, native_rate = torchaudio.load(audio_path)

    target_rate = clapConfig.sample_rate
    if resample and target_rate != native_rate:
        waveform = T.Resample(native_rate, target_rate)(waveform)
    return waveform, target_rate
|
|
def load_audio_into_tensor(audio_path, audio_duration, resample=False):
    r"""Load a file and return exactly `audio_duration` seconds as a 1-D FloatTensor.

    Clips shorter than the target are tiled until long enough, then trimmed;
    longer clips get a random crop (keeps training stochastic).
    """
    waveform, sample_rate = read_audio(audio_path, resample)
    waveform = waveform.reshape(-1)

    wanted = audio_duration * sample_rate
    if wanted >= waveform.shape[0]:
        # Tile the short clip, then cut to the exact sample count.
        n_tiles = int(np.ceil(wanted / waveform.shape[0]))
        waveform = waveform.repeat(n_tiles)[0:wanted]
    else:
        # Random crop of a long clip.
        start = random.randrange(waveform.shape[0] - wanted)
        waveform = waveform[start:start + wanted]
    return torch.FloatTensor(waveform)
|
|
# Numpy dtype kinds default_collate cannot stack into tensors:
# S/a (bytes), U (unicode), O (object).
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")
|
|
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size.

    Vendored from PyTorch's classic default_collate: recursively stacks
    tensors, converts numpy arrays/scalars, and descends into mappings,
    namedtuples and sequences.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a DataLoader worker: stack straight into shared memory so
            # the batch crosses to the main process without an extra copy.
            # NOTE(review): storage()._new_shared is a private/deprecated API
            # in recent torch — confirm against the installed version.
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # String/object arrays cannot become tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # numpy scalar
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # All sequences must share a length before transposing.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
|
|
def preprocess_audio(audio_files, resample):
    r"""Load each file as a (1, samples) clip of clapConfig.duration seconds
    and collate the clips into one batch tensor."""
    clips = []
    for path in audio_files:
        clip = load_audio_into_tensor(path, clapConfig.duration, resample)
        clips.append(clip.reshape(1, -1))
    return default_collate(clips)
|
|
|
|
|
|
| |
| |
| |
def CLAPAudioProcessor(audio_files: List[str], resample=True):
    """Preprocess audio files into a (batch, samples) float tensor for CLAP."""
    batch = preprocess_audio(audio_files, resample)
    # Drop the singleton channel dimension: (B, 1, S) -> (B, S).
    return batch.reshape(batch.shape[0], batch.shape[2])
|
|
def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
    """Return projected audio embeddings for a list of files (no gradients)."""
    with torch.no_grad():
        batch = CLAPAudioProcessor(audio_files, resample)
        # AudioEncoder returns (projected_embedding, class_scores); keep the former.
        return audio_encoder(batch)[0]
|
|
|
|
| |
| |
| |
class ClapConfig:
    """Hyper-parameters for the CLAP model and its audio front-end."""

    # --- text tower ---
    text_model = 'gpt2'
    text_len = 77
    transformer_embed_dim = 768
    freeze_text_encoder_weights = True

    # --- audio tower / spectrogram front-end ---
    audioenc_name = 'HTSAT'
    out_emb = 768
    sample_rate = 44100
    duration = 7          # seconds per training clip
    fmin = 50
    fmax = 8000
    n_fft = 1024
    hop_size = 320
    mel_bins = 64
    window_size = 1024

    # --- joint projection space ---
    d_proj = 1024
    temperature = 0.003

    # --- training ---
    num_classes = 527     # AudioSet classes
    batch_size = 1024
    demo = False
| |
|
|
# Build CLAP from the config above.
clapConfig = ClapConfig()
clap = CLAP(
    audioenc_name=clapConfig.audioenc_name,
    sample_rate=clapConfig.sample_rate,
    window_size=clapConfig.window_size,
    hop_size=clapConfig.hop_size,
    mel_bins=clapConfig.mel_bins,
    fmin=clapConfig.fmin,
    fmax=clapConfig.fmax,
    classes_num=clapConfig.num_classes,
    out_emb=clapConfig.out_emb,
    text_model=clapConfig.text_model,
    transformer_embed_dim=clapConfig.transformer_embed_dim,
    d_proj=clapConfig.d_proj
)

# Pretrained MSCLAP checkpoints hosted on the HuggingFace Hub.
model_repo = "microsoft/msclap"
model_name = {
    '2022': 'CLAP_weights_2022.pth',
    '2023': 'CLAP_weights_2023.pth',
    'clapcap': 'clapcap_weights_2023.pth'
}

version = '2023'
model_fp = hf_hub_download(model_repo, model_name[version])

# NOTE(review): strict=False silently ignores missing/unexpected keys —
# consider inspecting the returned IncompatibleKeys to verify the load.
model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
clap.load_state_dict(model_state_dict, strict=False)

# Only the audio tower is used downstream.
clap_audio_encoder = clap.audio_encoder.to(device)
|
|
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
# LoRA adapter configuration (PEFT) for fine-tuning the CLAP audio encoder.
LoRAconfig = {
    "peft_type": "LORA",
    "task_type": "FEATURE_EXTRACTION",
    "inference_mode": False,
    "r": 16,  # low-rank update dimension
    "target_modules": ["qkv", "fc1", "fc2", "proj", "linear1", "linear2"],
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "fan_in_fan_out": False,
    "bias": "all",
}
peft_config = get_peft_config(LoRAconfig)

peft_model = get_peft_model(clap_audio_encoder, peft_config)

peft_model.print_trainable_parameters()

# Unwrap to the LoRA-augmented encoder used by CSIP below.
peft_clap_audio_encoder = peft_model.base_model
| |
| |
|
|
|
|
|
|
| |
| |
| |
# NOTE(review): mid-file import — consider moving to the top with the others.
import open_clip
# create_model_and_transforms returns (model, preprocess_train, preprocess_val);
# the middle (train) transform is bound to `open_clip_imgaug` and used in collate_fn.
open_clip_model, open_clip_imgaug, open_clip_preprocess = open_clip.create_model_and_transforms(
    model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device
)
|
|
|
|
| |
| |
| |
class CSIP(nn.Module):
    """CLIP-style image-audio alignment model.

    Frozen image tower + trainable linear projection on the audio tower, with
    a learnable temperature and a contrastive (audio->image) loss.

    Args:
        image_encoder: module mapping images -> (B, dim_emb) features (frozen).
        audio_encoder: module whose output[0] is a (B, dim_audio) feature tensor.
        dim_img: unused; kept for interface compatibility.
        dim_audio: dimensionality of the audio features.
        dim_emb: shared embedding dimensionality (must match image feature size).
    """
    def __init__(self, image_encoder, audio_encoder,
                 dim_img=None, dim_audio=1024, dim_emb=1024):
        super(CSIP, self).__init__()
        self.image_encoder = image_encoder
        self.audio_encoder = audio_encoder

        # The image tower stays frozen; only the audio side adapts.
        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.audio_proj = nn.Linear(dim_audio, dim_emb)

        # Learnable temperature, stored as log for unconstrained optimisation.
        self.log_temp = nn.Parameter(torch.tensor(0.07).log())

    def forward(self, images, audios):
        """Return (loss, logits, probs) for a batch of paired images/audios."""
        # BUG FIX: the original called `.norm(dim=-1, keepdim=True)` on the
        # encoder outputs, which REPLACED the features with their scalar L2
        # norms (shape (B, 1)), collapsing every embedding to the same point.
        # Keep the raw features and let F.normalize produce unit vectors.
        image_features = self.image_encoder(images)
        audio_features = self.audio_encoder(audios)[0]

        image_embeds = F.normalize(image_features, dim=-1)
        audio_embeds = F.normalize(self.audio_proj(audio_features), dim=-1)

        # Cosine-similarity logits scaled by the learned inverse temperature.
        logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp()
        probs = logits.softmax(dim=1)

        # Matched pairs sit on the diagonal.
        labels = torch.arange(len(images), device=images.device)
        # NOTE(review): only the audio->image direction is optimised; the
        # symmetric CLIP loss would also average in F.cross_entropy(logits, labels).
        loss_a = F.cross_entropy(logits.T, labels)
        loss = loss_a
        return loss, logits, probs
|
|
|
|
| |
| |
| |
class VaaniImageAudioDataset(torch.utils.data.Dataset):
    """Dataset of paired (image_path, audio_path) strings from a dataframe
    with `image_path` and `audio_path` columns."""

    def __init__(self, df):
        self.image_paths = df.image_path.tolist()
        self.audio_paths = df.audio_path.tolist()

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        return {'image_path': self.image_paths[idx],
                'audio_path': self.audio_paths[idx]}
|
|
|
|
def collate_fn(batch):
    # Build one training batch from a list of {'image_path', 'audio_path'} dicts.
    # NOTE(review): indexing the transform result with ['pixel_values'] is
    # HF-processor style; open_clip's torchvision preprocess normally returns
    # a tensor, not a mapping — confirm open_clip_imgaug really yields a dict
    # and accepts a list of PIL images.
    image_tensor = open_clip_imgaug([Image.open(item['image_path']) for item in batch])['pixel_values']
    audio_tensor = CLAPAudioProcessor([item['audio_path'] for item in batch], resample=True)
    return {'image_tensor': torch.stack(image_tensor), 'audio_tensor': audio_tensor}
|
|
|
|
| |
| |
|
|
# Train/test splits of paired image/audio file paths.
train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv")
test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv")
train_dataset = VaaniImageAudioDataset(train_df)
test_dataset = VaaniImageAudioDataset(test_df)
BATCH_SIZE = 64

print('Train Dataset:', len(train_dataset))
print('Test Dataset:', len(test_dataset))

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=48,  # NOTE(review): heavy worker count — confirm host CPU count
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=48,
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)

# Smoke-test one batch and report shapes.
batch = next(iter(train_dataloader))
image_tensor_batch = batch['image_tensor']
audio_tensor_batch = batch['audio_tensor']
print("Image batch shape:", image_tensor_batch.shape)
print("Audio batch shape:", audio_tensor_batch.shape)

# Frozen open-CLIP visual tower + LoRA-wrapped CLAP audio tower.
csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device)

# NOTE(review): mid-file imports — consider moving to the top of the file.
from torchinfo import summary
import subprocess
summary(model=csip_model,
    input_data=((image_tensor_batch.to(device)), (audio_tensor_batch.to(device))),
    dtypes=[torch.long],
    col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"],
    col_width=20,
    row_settings=["var_names"],
    depth = 2,
)
|
|
| |
| |
|
|
|
|
def train_batch(model, images, audio, optimizer):
    """Run a single optimisation step on one batch.

    Puts the model in train mode, computes the loss, backpropagates and
    applies one optimiser update.  Returns (loss_value, logits, probs).
    """
    model.train()
    optimizer.zero_grad()
    batch_loss, batch_logits, batch_probs = model(images, audio)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item(), batch_logits, batch_probs
|
|
@torch.no_grad()
def evaluate_batch(model, images, audio):
    """Evaluate one batch in eval mode without gradients.

    Returns (loss_value, logits, probs)."""
    model.eval()
    eval_loss, eval_logits, eval_probs = model(images, audio)
    return eval_loss.item(), eval_logits, eval_probs
|
|
def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):
    """Persist `state` as csip_best_epoch_{epoch+1}.pt and prune old files.

    At most `max_checkpoints` checkpoints are kept; the oldest (by epoch
    number parsed from the filename) are deleted first.
    """
    torch.save(state, os.path.join(checkpoint_dir, f"csip_best_epoch_{epoch+1}.pt"))
    kept = sorted(
        (f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")),
        key=lambda fname: int(fname.split("_")[-1].split(".")[0])
    )
    while len(kept) > max_checkpoints:
        os.remove(os.path.join(checkpoint_dir, kept.pop(0)))
|
|
|
|
def load_checkpoint(checkpoint_dir, model, optimizer, scheduler):
    """Restore the newest csip_best_epoch_* checkpoint, if one exists.

    Loads model/optimizer/scheduler state in place and returns
    (start_epoch, best_loss); (0, inf) when there is nothing to resume.
    """
    candidates = sorted(
        [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
        key=lambda fname: int(fname.split("_")[-1].split(".")[0])
    )
    if not candidates:
        print("No checkpoint found to resume from.")
        return 0, float("inf")

    ckpt_path = os.path.join(checkpoint_dir, candidates[-1])
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    scheduler.load_state_dict(checkpoint['scheduler_state'])
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
    return start_epoch, best_loss
|
|
|
|
def fig_to_tensor(fig):
    """Render a Matplotlib figure to a CHW float tensor for TensorBoard,
    closing the figure afterwards."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        buf.seek(0)
        rgb = Image.open(buf).convert("RGB")
        tensor = tv.transforms.functional.to_tensor(rgb)
    plt.close(fig)
    return tensor
|
|
def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):
    """Save raw-logit and softmax-probability heatmaps of the image/audio
    similarity matrix to disk and log both images to TensorBoard."""
    os.makedirs(os.path.join(save_dir, 'logits'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'probs'), exist_ok=True)

    # Raw similarity logits (rows: images, cols: audios).
    logits_np = logits.detach().cpu().numpy()
    fig_logits = plt.figure(figsize=(8, 6))
    sns.heatmap(logits_np, square=True, cmap="Blues", cbar=True, annot=False)
    plt.title(f"Raw Logits Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
    plt.xlabel("Audio Index")
    plt.ylabel("Image Index")
    raw_path = os.path.join(save_dir, 'logits', f"raw_logits_epoch_{epoch+1}_loss_{loss:.4f}.png")
    fig_logits.savefig(raw_path)
    # fig_to_tensor closes the figure after converting it.
    writer.add_image("Heatmap/RawLogits", fig_to_tensor(fig_logits), global_step=epoch+1)

    # Row-wise softmax probabilities of the same matrix.
    probs_np = logits.softmax(dim=1).cpu().numpy()
    fig_probs = plt.figure(figsize=(8, 6))
    sns.heatmap(probs_np, square=True, cmap="Blues", cbar=True, annot=False)
    plt.title(f"Softmax Probabilities Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
    plt.xlabel("Audio Index")
    plt.ylabel("Image Index")
    prob_path = os.path.join(save_dir, "probs", f"probs_epoch_{epoch+1}_loss_{loss:.4f}.png")
    fig_probs.savefig(prob_path)
    writer.add_image("Heatmap/SoftmaxProbs", fig_to_tensor(fig_probs), global_step=epoch+1)
|
|
|
|
|
|
def train_model(model, train_loader, test_loader,
                optimizer, scheduler, device, log_dir,
                checkpoint_dir, resume=False, epochs=10):
    """Train/eval loop with TensorBoard + CSV logging and best-loss checkpoints.

    Checkpoints (at most 2, via save_checkpoint) are written whenever the
    average test loss improves.  With resume=True the newest checkpoint is
    restored and the CSV log is appended to rather than rewritten.
    NOTE(review): scheduler.step() is called unconditionally, so despite the
    `if scheduler else None` guard in the checkpoint dict, a scheduler is
    effectively required.
    """
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    csv_path = os.path.join(log_dir, "training_log.csv")

    writer = SummaryWriter(log_dir=log_dir)

    start_epoch = 0
    best_loss = float("inf")
    best_epoch = -1

    if resume:
        start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer, scheduler)

    # Fresh run (or missing log): write the CSV header once.
    if not (resume and os.path.exists(csv_path)):
        with open(csv_path, mode='w', newline='') as f:
            writer_csv = csv.writer(f)
            writer_csv.writerow(["Epoch", "Best Epoch", "Train Loss", "Test Loss", "Best Loss", "Learning Rate"])

    for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
        train_losses = []
        test_losses = []

        train_loop = tqdm(train_loader, desc=f"[TrainEp {epoch+1}]", colour='blue', dynamic_ncols=True)
        for batch in train_loop:
            images = batch['image_tensor'].to(device)
            audios = batch['audio_tensor'].to(device)
            loss, logits, probs = train_batch(model, images, audios, optimizer)
            train_losses.append(loss)
            train_loop.set_postfix(trainLoss=loss)

        test_loop = tqdm(test_loader, desc=f"[TestEp {epoch+1}]", colour='red', dynamic_ncols=True)
        for batch in test_loop:
            images = batch['image_tensor'].to(device)
            audios = batch['audio_tensor'].to(device)
            loss, logits, probs = evaluate_batch(model, images, audios)
            test_losses.append(loss)
            test_loop.set_postfix(testLoss=loss)

        avg_train_loss = sum(train_losses) / len(train_losses)
        avg_test_loss = sum(test_losses) / len(test_losses)
        current_lr = optimizer.param_groups[0]['lr']

        writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
        writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
        writer.add_scalar("Learning Rate", current_lr, epoch + 1)

        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | \
Test Loss: {avg_test_loss:.4f} | LR: {current_lr:.2e}")

        if avg_test_loss < best_loss:
            # NOTE(review): `logits` here is from the LAST eval batch only.
            save_similarity_heatmaps(logits, epoch, avg_test_loss, checkpoint_dir, writer)
            best_loss = avg_test_loss
            best_epoch = epoch + 1
            save_checkpoint({
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'best_loss': best_loss,
                'scheduler_state': scheduler.state_dict() if scheduler else None
            }, checkpoint_dir, epoch)
            print(f">>> Saved new best model at epoch {epoch+1}")

        scheduler.step()
        # Append one CSV row per epoch.
        with open(csv_path, mode='a', newline='') as f:
            writer_csv = csv.writer(f)
            writer_csv.writerow([epoch + 1, best_epoch, avg_train_loss, avg_test_loss, best_loss, current_lr])

    writer.close()
|
|
|
|
|
|
# Training hyper-parameters.
# NOTE(review): this rebinds `model_name`, shadowing the checkpoint-filename
# dict of the same name defined earlier — consider renaming one of them.
model_name = "csip_model_openClip_CLAP"
learning_rate = 1e-4
epochs = 100
optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-10)

# Launch training; resume=True restores the newest checkpoint if present.
train_model(
    model=csip_model,
    train_loader=train_dataloader,
    test_loader=test_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    log_dir=f"{model_name}/runs/csip",
    checkpoint_dir=f"{model_name}/checkpoints/csip",
    resume=True,
    epochs=epochs
)
|
|
|
|
| |
| |
|
|
| |
|
|