# Source: Vaani-Audio2Img-LDM / Vaani / Img_Audio_Alignment / _2.1_Train_OpenCLIP.py
# Hugging Face user: alpha31476
# Commit message: "Image Audio Alignment Train OpenClip"
# Commit: 0733b34 (verified)
# ==================================================================
# L A T E N T D I F F U S I O N M O D E L
# ==================================================================
# Author : Ashish Kumar Uchadiya
# Created : May 11, 2025
# Description: This script implements the training of a VQ-VAE model for
# image reconstruction, integrated with Latent Diffusion Models (LDMs) and
# audio conditioning. The VQ-VAE maps images to a discrete latent space,
# which is then modeled by the LDM to learn a diffusion process over the
# compressed representation. Audio features are used as conditioning inputs
# to guide the generation process. The training minimizes a combination of
# LPIPS (Learned Perceptual Image Patch Similarity) loss for perceptual
# fidelity and PatchGAN loss to enforce local realism. This setup enables
# efficient, semantically aware generation of high-quality images driven
# by audio cues.
# NOTE(review): this description appears to be carried over from a VQ-VAE
# script. This file (_2.1_Train_OpenCLIP.py) defines an HTS-AT audio encoder
# and imports a CLIP vision model, which suggests image-audio alignment
# training instead. Confirm and update.
# ==================================================================
# I M P O R T S
# ==================================================================
from __future__ import annotations
import warnings
warnings.filterwarnings("ignore")
import os
import io
import sys
import math
import random
import collections
import collections.abc
import re
from itertools import repeat
from pathlib import Path
from typing import Optional, Tuple, Union, List, Dict
import csv
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out
import torch.utils.checkpoint as checkpoint
import torchvision as tv
from torchvision.transforms import v2
from torch.utils.tensorboard import SummaryWriter
# from tensorboardX import SummaryWriter
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
import torchaudio
import torchaudio.transforms as T
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from transformers import AutoModel, AutoTokenizer, logging
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.file_download import hf_hub_download
from peft import get_peft_config, get_peft_model
from transformers import CLIPVisionModel, AutoProcessor
from watermark import watermark
# Print a provenance / environment header at start-up via the `watermark`
# package: author, date/time stamps, host and machine, Python/conda, GPU, and
# the versions of packages imported into this module's globals.
print(
    watermark(
        author='Ashish',
        # date / time stamps
        current_date=True,
        datename=True,
        current_time=True,
        iso8601=True,
        timezone=True,
        updated=True,
        custom_time=None,
        # runtime / host information
        python=True,
        conda=True,
        hostname=True,
        machine=True,
        gpu=True,
        # versions of the modules imported into this file
        iversions=True,
        globals_=globals(),
        watermark=False,
    )
)
# ==================================================================
# H T S - A T
# ==================================================================
class HTSATConfig:
    """Class-level configuration namespace for the HTS-AT audio encoder.

    Attributes are read as ``config.<name>`` by ``HTSAT_Swin_Transformer`` in
    this file (e.g. ``mel_bins``, ``sample_rate``, ``window_size``,
    ``hop_size``, ``fmin``, ``fmax``, ``enable_tscam``, ``htsat_attn_heatmap``).

    NOTE(review): most path and training fields below are the upstream HTS-AT
    defaults. They are presumably unused by this alignment script; confirm
    before relying on them.
    """
    # Ke Chen
    # knutchen@ucsd.edu
    # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
    # The configuration for training the model
    exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
    workspace = "/home/kechen/Research/HTSAT" # the folder of your code
    dataset_path = "/home/Research/audioset" # the dataset path
    desed_folder = "/home/Research/DESED" # the desed file
    dataset_type = "audioset" # "audioset" "esc-50" "scv2"
    index_type = "full_train" # only works for audioset
    balanced_data = True # only works for audioset
    loss_type = "clip_bce" #
    # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
    # trained from a checkpoint, or evaluate a single model
    resume_checkpoint = None
    # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
    esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation
    debug = False
    random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
    batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128
    learning_rate = 1e-3 # 1e-4 also workable
    max_epoch = 100
    num_workers = 3
    lr_scheduler_epoch = [10,20,30]
    lr_rate = [0.02, 0.05, 0.1]
    # these data preparation optimizations do not bring many improvements, so deprecated
    enable_token_label = False # token label
    class_map_path = "class_hier_map.npy"
    class_filter = None
    # NOTE(review): 8260 appears twice in this list; confirm that is intended.
    retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
        9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
    token_label_range = [0.2,0.6]
    enable_time_shift = False # shift time
    enable_label_enhance = False # enhance hierarchical label
    enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
    # for model's design
    enable_tscam = True # enable the token-semantic layer
    # for signal processing
    sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
    clip_samples = sample_rate * 10 # audio_set 10-sec clip
    window_size = 1024
    hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
    mel_bins = 64
    fmin = 50
    fmax = 14000
    shift_max = int(clip_samples * 0.5)
    # for data collection
    classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35
    patch_size = (25, 4) # deprecated
    crop_size = None # int(clip_samples * 0.5) deprecated
    # for htsat hyperparameters
    htsat_window_size = 8
    htsat_spec_size = 256
    htsat_patch_size = 4
    htsat_stride = (4, 4)
    htsat_num_head = [4,8,16,32]
    htsat_dim = 96
    htsat_depth = [2,2,6,2]
    swin_pretrain_path = None
    # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
    # Some Deprecated Optimization in the model design, check the model code for details
    htsat_attn_heatmap = False
    htsat_hier_output = False
    htsat_use_max = False
    # for ensemble test
    ensemble_checkpoints = []
    ensemble_strides = []
    # weight average folder
    wa_folder = "/home/version_0/checkpoints/"
    # weight average output filename
    wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
    esm_model_pathes = [
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt",
        "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt"
    ]
    # for framewise localization
    heatmap_dir = "/home/Research/heatmap_output"
    test_file = "htsat-test-ensemble"
    fl_local = False # indicate if we need to use this dataset for the framewise detection
    fl_dataset = "/home/Research/desed/desedim_embval.npy"
    fl_class_num = [
        "Speech", "Frying", "Dishes", "Running_water",
        "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
        "Cat", "Dog", "Vacuum_cleaner"
    ]
    # map 527 classes into 10 classes (one list of AudioSet indices per entry of fl_class_num)
    fl_audioset_mapping = [
        [0,1,2,3,4,5,6,7],
        [366, 367, 368],
        [364],
        [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
        [369],
        [382],
        [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
        [81, 82, 83, 84, 85],
        [74, 75, 76, 77, 78, 79],
        [377]
    ]
def _ntuple(n):
    """Return a converter that broadcasts a non-iterable value to an ``n``-tuple.

    Iterables (including strings and lists) are returned unchanged, and their
    length is not checked.
    """
    def parse(x):
        if not isinstance(x, collections.abc.Iterable):
            return (x,) * n
        return x
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def do_mixup(x, mixup_lambda):
    """Mix each even-indexed sample with the following odd-indexed sample.

    Output row ``i`` is ``x[2i] * mixup_lambda[2i] + x[2i+1] * mixup_lambda[2i+1]``.

    Args:
        x: (batch_size * 2, ...)
        mixup_lambda: (batch_size * 2,)
    Returns:
        out: (batch_size, ...)
    """
    even_x, odd_x = x[0::2], x[1::2]
    even_lam, odd_lam = mixup_lambda[0::2], mixup_lambda[1::2]
    # Move the batch axis last so the 1-D weights broadcast over it.
    mixed = even_x.transpose(0, -1) * even_lam + odd_x.transpose(0, -1) * odd_lam
    return mixed.transpose(0, -1)
def interpolate(x, ratio):
    """Upsample along the time axis by repeating every frame ``ratio`` times.

    Used to compensate for the time-resolution reduction of downsampling.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, ratio to interpolate
    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    batch, steps, n_classes = x.shape
    tiled = x.unsqueeze(2).repeat(1, 1, ratio, 1)  # (batch, steps, ratio, n_classes)
    return tiled.reshape(batch, steps * ratio, n_classes)
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Each sample in the batch (dim 0) is either zeroed entirely, with
    probability ``drop_prob``, or kept and rescaled by ``1 / (1 - drop_prob)``,
    so the expected value is unchanged. This is the same as the DropConnect
    impl created for EfficientNet etc., but the name 'drop path' is used
    because 'Drop Connect' is a different form of dropout. See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: input tensor; dim 0 is treated as the sample axis.
        drop_prob: probability of dropping a sample. ``None`` or ``0`` disables it.
            Values ``>= 1`` drop every sample.
        training: the op is a no-op unless True.

    Returns:
        ``x`` itself when disabled, otherwise a new tensor of the same shape.
    """
    # `not drop_prob` also covers None (DropPath's default), which previously
    # crashed on `1 - None` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0:
        # Every sample is dropped; dividing by keep_prob would give 0/0 = NaN.
        return x * 0.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    ``drop_prob`` is stored as given and forwarded to ``drop_path`` together
    with the module's ``training`` flag.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
class PatchEmbed(nn.Module):
    """Split a 2-D input into patches with a strided conv and embed each patch.

    Args:
        img_size: expected (H, W) of the input (int is broadcast to both).
        patch_size: conv kernel size (int or pair).
        in_chans: number of input channels.
        embed_dim: output embedding size (conv out-channels).
        norm_layer: optional norm class applied to the embeddings.
        flatten: if True, return (B, N, C) tokens instead of (B, C, H', W').
        patch_stride: conv stride (int or pair); may be smaller than patch_size
            for overlapping patches.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.patch_stride = to_2tuple(patch_stride)
        size, kernel, stride = self.img_size, self.patch_size, self.patch_stride
        self.grid_size = (size[0] // stride[0], size[1] // stride[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        # Pad by (kernel - stride) // 2 per side so the conv output grid equals
        # img_size // patch_stride when kernel - stride is even.
        pad = ((kernel[0] - stride[0]) // 2, (kernel[1] - stride[1]) // 2)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel, stride=stride, padding=pad)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, height, width = x.shape
        assert height == self.img_size[0] and width == self.img_size[1], \
            f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        out = self.proj(x)
        if self.flatten:
            out = out.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(out)
class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    Used as in Vision Transformer, MLP-Mixer and related networks.

    Args:
        in_features: input width.
        hidden_features: hidden width (defaults to ``in_features``).
        out_features: output width (defaults to ``in_features``).
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        width_out = out_features or in_features
        # Attribute names/order are kept: they define the state_dict keys.
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. They are generated by inverse-CDF sampling rather than
    rejection, and the method works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor``, filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_gradim_audiorunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialise ``tensor`` in place with variance ``scale / denom`` (TF-style variance scaling).

    Args:
        tensor: weight tensor. ``_calculate_fan_in_and_fan_out`` requires at least 2 dims.
        scale: numerator of the target variance.
        mode: ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'``; selects ``denom``.
        distribution: ``'truncated_normal'``, ``'normal'`` or ``'uniform'``.

    Raises:
        ValueError: for an unknown ``mode`` or ``distribution``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # Previously fell through and failed later with UnboundLocalError on `denom`.
        raise ValueError(f"invalid mode {mode}")
    variance = scale / denom
    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        # no_grad so in-place init also works on leaf tensors that require grad (nn.Parameter).
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
    """In-place LeCun-normal init: truncated normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, distribution='truncated_normal', mode='fan_in')
# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
def window_partition(x, window_size):
    """Cut a (B, H, W, C) map into non-overlapping square windows.

    Windows are ordered batch-major, then by window row, then window column.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    ws = window_size
    grid = x.view(batch, height // ws, ws, width // ws, ws, channels)
    # (B, nH, ws, nW, ws, C) -> (B, nH, nW, ws, ws, C): window indices first.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, ws, ws, channels)
def window_reverse(windows, window_size, H, W):
    """Stitch windows produced by ``window_partition`` back into a (B, H, W, C) map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    tiles = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # (B, nH, nW, ws, ws, C) -> (B, nH, ws, nW, ws, C) -> (B, H, W, C)
    return tiles.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, H, W, -1)
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.

    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        # define a parameter table of relative position bias: one row per
        # possible (dy, dx) offset inside a window, one column per head
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # indexing="ij" is the historical default of torch.meshgrid; passing it
        # explicitly keeps the same result and avoids the deprecation warning.
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # row-major flattening of (dy, dx)
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None;
                added to the attention logits before the softmax
        Returns:
            (x, attn): projected output (num_windows*B, N, C) and the attention
            weights (num_windows*B, nH, N, N) after softmax and attention dropout.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn

    def extra_repr(self):
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (H, W) of the token grid.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer used for ``norm1``. Default: nn.LayerNorm
        norm_before_mlp (str, optional): Norm applied before the MLP: ``'ln'``
            (nn.LayerNorm) or ``'bn'`` (BatchNorm1d over the channel axis). Default: ``'ln'``

    ``forward`` takes (B, H*W, C) tokens and returns ``(x, attn)``, where
    ``attn`` is the attention map returned by ``WindowAttention``.
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.norm_before_mlp == 'ln':
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == 'bn':
            # A registered module, so its affine params and running stats are
            # trained, saved and moved with the model. The former lambda built
            # a fresh, unregistered BatchNorm1d on every forward call.
            self.norm2 = _ChannelLastBatchNorm1d(dim)
        else:
            raise NotImplementedError(f"norm_before_mlp must be 'ln' or 'bn', got {norm_before_mlp!r}")
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA: label the 3x3 regions created
            # by the cyclic shift, then block attention (-100 logit offset)
            # between tokens of different regions inside the same window
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)  # requires L == H * W
        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"


class _ChannelLastBatchNorm1d(nn.BatchNorm1d):
    """BatchNorm1d applied to the channel (last) axis of a (B, L, C) tensor."""

    def forward(self, x):
        # BatchNorm1d normalises dim 1 of (N, C, L) input, so move channels there and back.
        return super().forward(x.transpose(1, 2)).transpose(1, 2)
class PatchMerging(nn.Module):
    r""" Patch Merging Layer: 2x2 spatial downsampling that doubles the channels.

    Each 2x2 neighbourhood of tokens is concatenated (4*C), normalised, and
    projected linearly to 2*C.

    Args:
        input_resolution (tuple[int]): Resolution (H, W) of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    # (row, col) offsets of the four tokens in each 2x2 neighbourhood, in concat order.
    _OFFSETS = ((0, 0), (1, 0), (0, 1), (1, 1))

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C  ->  B, H/2*W/2, 2*C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        grid = x.view(B, H, W, C)
        quads = [grid[:, r::2, c::2, :] for r, c in self._OFFSETS]  # each B H/2 W/2 C
        merged = torch.cat(quads, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C
        return self.reduction(self.norm(merged))

    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate, or one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        norm_before_mlp (str): Forwarded to each SwinTransformerBlock. Default: 'ln'

    ``forward`` returns ``(x, attn)``. In eval mode ``attn`` is the mean of
    the blocks' attention maps; in training mode it is the last block's map.
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 norm_before_mlp='ln'):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        # build blocks; odd-indexed blocks use shifted windows (SW-MSA)
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
            for i in range(depth)])
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        attns = []
        for blk in self.blocks:
            if self.use_checkpoint:
                # Blocks return (x, attn). The result used to be assigned whole
                # to `x`, feeding a tuple into the next block. Non-reentrant
                # checkpointing is the mode recommended by PyTorch.
                x, attn = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x, attn = blk(x)
            if not self.training:
                attns.append(attn.unsqueeze(0))
        if self.downsample is not None:
            x = self.downsample(x)
        if not self.training:
            attn = torch.cat(attns, dim=0)
            attn = torch.mean(attn, dim=0)
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
# The Core of HTSAT
class HTSAT_Swin_Transformer(nn.Module):
r"""HTSAT based on the Swin Transformer
Args:
spec_size (int | tuple(int)): Input Spectrogram size. Default 256
patch_size (int | tuple(int)): Patch size. Default: 4
path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
in_chans (int): Number of input image channels. Default: 1 (mono)
num_classes (int): Number of classes for classification head. Default: 527
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 8
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
config (module): The configuration Module from config.py (HTSATConfig Class)
"""
def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
             in_chans=1, num_classes=527,
             embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
             window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
             drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
             norm_layer=nn.LayerNorm,
             ape=False, patch_norm=True,
             use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
    """Build the HTS-AT backbone (argument meanings are in the class docstring).

    ``config`` must provide the fields read here: ``mel_bins``,
    ``window_size``, ``hop_size``, ``sample_rate``, ``fmin``, ``fmax`` and
    ``enable_tscam`` (e.g. ``HTSATConfig``). Extra ``**kwargs`` are accepted
    and ignored.
    """
    super(HTSAT_Swin_Transformer, self).__init__()
    self.config = config
    self.spec_size = spec_size
    self.patch_stride = patch_stride
    self.patch_size = patch_size
    self.window_size = window_size
    self.embed_dim = embed_dim
    self.depths = depths
    self.ape = ape
    self.in_chans = in_chans
    self.num_classes = num_classes
    self.num_heads = num_heads
    self.num_layers = len(self.depths)
    self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
    self.drop_rate = drop_rate
    self.attn_drop_rate = attn_drop_rate
    self.drop_path_rate = drop_path_rate
    self.qkv_bias = qkv_bias
    # Honour the caller's qk_scale (it was previously discarded). The default
    # None keeps WindowAttention's head_dim ** -0.5 scaling.
    self.qk_scale = qk_scale
    self.patch_norm = patch_norm
    # `self.norm_layer` is the patch-embedding norm only. patch_norm=False must
    # not remove the stage / final norms (those previously called None and crashed).
    self.norm_layer = norm_layer if self.patch_norm else None
    self.norm_before_mlp = norm_before_mlp
    self.mlp_ratio = mlp_ratio
    self.use_checkpoint = use_checkpoint
    # process mel-spec ; used only once
    # number of mel-bin groups folded into the square spec (e.g. 256 // 64 = 4)
    self.freq_ratio = self.spec_size // self.config.mel_bins
    window = 'hann'
    center = True
    pad_mode = 'reflect'
    ref = 1.0
    amin = 1e-10
    top_db = None
    self.interpolate_ratio = 32  # Downsampled ratio
    # Spectrogram extractor
    self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
        win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
        freeze_parameters=True)
    # Logmel feature extractor
    self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
        n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
        freeze_parameters=True)
    # Spec augmenter
    self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
        freq_drop_width=8, freq_stripes_num=2)  # 2 2
    self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
    # split spectrogram into non-overlapping patches
    self.patch_embed = PatchEmbed(
        img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
        embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
    num_patches = self.patch_embed.num_patches
    patches_resolution = self.patch_embed.grid_size
    self.patches_resolution = patches_resolution
    # absolute position embedding
    if self.ape:
        self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
        trunc_normal_(self.absolute_pos_embed, std=.02)
    self.pos_drop = nn.Dropout(p=self.drop_rate)
    # stochastic depth: rates rise linearly from 0 to drop_path_rate over all blocks
    dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))]  # stochastic depth decay rule
    # build layers; every stage but the last ends with a 2x PatchMerging
    self.layers = nn.ModuleList()
    for i_layer in range(self.num_layers):
        layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
            input_resolution=(patches_resolution[0] // (2 ** i_layer),
                              patches_resolution[1] // (2 ** i_layer)),
            depth=self.depths[i_layer],
            num_heads=self.num_heads[i_layer],
            window_size=self.window_size,
            mlp_ratio=self.mlp_ratio,
            qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
            drop=self.drop_rate, attn_drop=self.attn_drop_rate,
            drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
            norm_layer=norm_layer,
            downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
            use_checkpoint=use_checkpoint,
            norm_before_mlp=self.norm_before_mlp)
        self.layers.append(layer)
    self.norm = norm_layer(self.num_features)
    self.avgpool = nn.AdaptiveAvgPool1d(1)
    self.maxpool = nn.AdaptiveMaxPool1d(1)
    if self.config.enable_tscam:
        # frequency extent of the last feature map after patching, the
        # (num_layers - 1) 2x merges, and folding by freq_ratio
        SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
        self.tscam_conv = nn.Conv2d(
            in_channels = self.num_features,
            out_channels = self.num_classes,
            kernel_size = (SF,3),
            padding = (0,1)
        )
        self.head = nn.Linear(num_classes, num_classes)
    else:
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
    self.apply(self._init_weights)
def _init_weights(self, m):
    """Initialise a single sub-module in place (intended for ``self.apply``).

    * ``nn.Linear``: weight drawn with ``trunc_normal_`` (std 0.02), bias
      (when present) set to 0.
    * ``nn.LayerNorm``: weight set to 1, bias set to 0.
    * Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        return
    if isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
    """Names of parameters that should be excluded from weight decay."""
    excluded = {'absolute_pos_embed'}
    return excluded
@torch.jit.ignore
def no_weight_decay_keywords(self):
    """Name fragments identifying parameters excluded from weight decay."""
    keywords = {'relative_position_bias_table'}
    return keywords
def forward_features(self, x):
    """Run the Swin backbone on an image-shaped input and build the output dict.

    Args:
        x: 4-D input as produced by ``reshape_wav2img`` / ``repeat_wat2img``;
            presumably (B, C, H, W) -- TODO confirm against ``patch_embed``.

    Returns:
        dict with 'framewise_output' and 'clipwise_output'. Only the
        token-semantic (``config.enable_tscam``) path also returns
        'latent_output' (pooled features taken before ``tscam_conv``); the
        other path does not, so callers reading that key need tscam enabled.
    """
    frames_num = x.shape[2]
    x = self.patch_embed(x)
    if self.ape:
        # optional learned absolute position embedding
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)
    for i, layer in enumerate(self.layers):
        # each layer returns (features, attention); only the last layer's attention survives the loop
        x, attn = layer(x)
    if self.config.enable_tscam:
        # for x
        x = self.norm(x)
        # x is a token sequence here: (batch, tokens, channels)
        B, N, C = x.shape
        # presumably the token-grid height/width after all patch-merging stages -- TODO confirm
        SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
        ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
        x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
        # NOTE: this local `F` shadows the module-level `torch.nn.functional as F`
        # for the whole function; the functional module is not used here.
        B, C, F, T = x.shape
        # group 2D CNN
        # Split the frequency rows into freq_ratio groups and lay them out along
        # time -- presumably undoing reshape_wav2img's folding at token resolution.
        c_freq_bin = F // self.freq_ratio
        x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
        x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
        # get latent_output
        # average over all spatial positions -> (B, C) embedding
        latent_output = self.avgpool(torch.flatten(x,2))
        latent_output = torch.flatten(latent_output, 1)
        # display the attention map, if needed
        if self.config.htsat_attn_heatmap:
            # for attn
            # average the last layer's attention over two leading axes (heads / queries presumably)
            attn = torch.mean(attn, dim = 1)
            attn = torch.mean(attn, dim = 1)
            attn = attn.reshape(B, SF, ST)
            c_freq_bin = SF // self.freq_ratio
            attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
            attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
            attn = attn.mean(dim = 1)
            attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
            attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
            # Linear rescale mapping the per-sample min to 0.85 and max to 1.0
            # (a soft weighting, not a plain [0, 1] min-max normalisation).
            attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
            attn = attn.unsqueeze(dim = 2)
        x = self.tscam_conv(x)
        x = torch.flatten(x, 2) # B, C, T
        # Framewise predictions upsampled along time by 8 * patch_stride[1]
        # via the module-level `interpolate` helper (defined elsewhere).
        if self.config.htsat_attn_heatmap:
            fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
        else:
            fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        # With "clip_ce" the clip-level output stays as raw logits; otherwise it is sigmoided.
        if self.config.loss_type == "clip_ce":
            output_dict = {
                'framewise_output': fpx, # already sigmoided
                'clipwise_output': x,
                'latent_output': latent_output
            }
        else:
            output_dict = {
                'framewise_output': fpx, # already sigmoided
                'clipwise_output': torch.sigmoid(x),
                'latent_output': latent_output
            }
    else:
        x = self.norm(x) # B N C
        B, N, C = x.shape
        # NOTE(review): assumes a square token grid of side frames_num // 2**(len(depths)+1) -- confirm
        fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
        B, C, F, T = fpx.shape
        c_freq_bin = F // self.freq_ratio
        fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
        fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
        # collapse the remaining frequency axis, then upsample along time
        fpx = torch.sum(fpx, dim = 2)
        fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
        x = self.avgpool(x.transpose(1, 2)) # B C 1
        x = torch.flatten(x, 1)
        if self.num_classes > 0:
            x = self.head(x)
            fpx = self.head(fpx)
        # No 'latent_output' on this path.
        output_dict = {'framewise_output': torch.sigmoid(fpx),
            'clipwise_output': torch.sigmoid(x)}
    return output_dict
def crop_wav(self, x, crop_size, spe_pos = None):
    """Crop a ``crop_size``-long window along the time axis of a batch.

    Args:
        x: 4-D tensor indexed as (batch, channel, time, freq).
        crop_size: number of time steps to keep.
        spe_pos: fixed start index used for every item; when ``None`` a start
            is drawn independently for each item, uniformly over all valid
            positions ``0 .. time - crop_size`` (inclusive).

    Returns:
        Tensor of shape (batch, channel, crop_size, freq) with the same dtype
        and device as ``x``. All channels are cropped.

    Raises:
        ValueError: if ``crop_size`` is larger than the time dimension.
    """
    time_steps = x.shape[2]
    max_start = time_steps - crop_size
    if max_start < 0:
        raise ValueError(
            f"crop_size ({crop_size}) exceeds the number of time steps ({time_steps})")
    # new_zeros keeps x's dtype/device (previously always float32).
    tx = x.new_zeros((x.shape[0], x.shape[1], crop_size, x.shape[3]))
    for i in range(x.shape[0]):
        # randint is inclusive on both ends, so max_start itself is a valid
        # start: the window then ends exactly on the last frame.
        crop_pos = random.randint(0, max_start) if spe_pos is None else spe_pos
        # Copy every channel, not just channel 0.
        tx[i] = x[i, :, crop_pos:crop_pos + crop_size, :]
    return tx
# Reshape the waveform features to the image size expected by the pretrained Swin Transformer model
def reshape_wav2img(self, x):
    """Fold a (B, C, T, F) spectrogram into the Swin input layout.

    Each axis that is shorter than its target is first stretched with bicubic
    interpolation: time to ``int(spec_size * freq_ratio)`` and frequency to
    ``spec_size // freq_ratio``. The time axis is then cut into
    ``freq_ratio`` equal chunks that are stacked along frequency, giving
    (B, C, freq_ratio * F', T' // freq_ratio).
    """
    _, _, n_time, n_freq = x.shape
    want_time = int(self.spec_size * self.freq_ratio)
    want_freq = self.spec_size // self.freq_ratio
    assert n_time <= want_time and n_freq <= want_freq, "the wav size should less than or equal to the swin input size"
    # Stretch each axis separately (only when it is short) to avoid bicubic errors on zero-size targets.
    if n_time < want_time:
        x = nn.functional.interpolate(x, (want_time, x.shape[3]), mode="bicubic", align_corners=True)
    if n_freq < want_freq:
        x = nn.functional.interpolate(x, (x.shape[2], want_freq), mode="bicubic", align_corners=True)
    batch, chans, time_len, freq_len = x.shape
    chunk = time_len // self.freq_ratio
    # (B, C, T, F) -> (B, C, F, T) -> (B, C, F, r, T/r) -> (B, C, r, F, T/r) -> (B, C, r*F, T/r)
    folded = x.permute(0, 1, 3, 2).contiguous()
    folded = folded.reshape(batch, chans, freq_len, self.freq_ratio, chunk)
    folded = folded.permute(0, 1, 3, 2, 4).contiguous()
    return folded.reshape(batch, chans, self.freq_ratio * freq_len, chunk)
# Repeat the waveform features to fill the image size expected by the pretrained Swin Transformer model
def repeat_wat2img(self, x, cur_pos):
    """Cut a ``spec_size``-frame time window and tile it along frequency.

    ``x`` (indexed as (B, C, T, F)) is first stretched with bicubic
    interpolation, when smaller, to ``int(spec_size * freq_ratio)`` time
    steps and ``spec_size // freq_ratio`` frequency bins, and permuted to
    (B, C, F, T). The window ``[cur_pos, cur_pos + spec_size)`` is taken along
    time and repeated ``freq_ratio`` times along frequency, so the frequency
    axis ends up with ``freq_ratio * (spec_size // freq_ratio)`` rows.

    Args:
        x: spectrogram batch indexed as (B, C, T, F).
        cur_pos: first time frame of the window.
    """
    _, _, n_time, n_freq = x.shape
    target_T = int(self.spec_size * self.freq_ratio)
    target_F = self.spec_size // self.freq_ratio
    assert n_time <= target_T and n_freq <= target_F, "the wav size should less than or equal to the swin input size"
    # to avoid bicubic zero error
    if n_time < target_T:
        x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
    if n_freq < target_F:
        x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
    x = x.permute(0,1,3,2).contiguous() # B C F T
    x = x[:, :, :, cur_pos:cur_pos + self.spec_size]
    # Tile freq_ratio copies so the frequency axis grows back from
    # spec_size // freq_ratio to spec_size. This used to be a hard-coded 4,
    # which is only correct when freq_ratio == 4.
    x = x.repeat(repeats = (1, 1, self.freq_ratio, 1))
    return x
def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
    """Raw audio -> log-mel spectrogram -> Swin features -> output dict.

    Args:
        x: audio batch fed to ``spectrogram_extractor``; presumably
            (batch, samples) -- TODO confirm against the extractor.
        mixup_lambda: coefficients passed to ``do_mixup``; only applied while
            training.
        infer_mode: tile the spectrogram along time (to handle inputs of
            varying length) before folding it into the Swin layout.

    Returns:
        dict produced by ``forward_features`` or, on the evaluation branches
        that use several windows, the per-key average of several such dicts.
    """
    x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
    x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
    # Move mel bins into dim 1 for bn0 and back (presumably a per-mel-bin BatchNorm2d).
    x = x.transpose(1, 3)
    x = self.bn0(x)
    x = x.transpose(1, 3)
    if self.training:
        x = self.spec_augmenter(x)
    if self.training and mixup_lambda is not None:
        x = do_mixup(x, mixup_lambda)
    if infer_mode:
        # in infer mode. we need to handle different length audio input
        frame_num = x.shape[2]
        target_T = int(self.spec_size * self.freq_ratio)
        # NOTE(review): repeat_ratio is 0 when frame_num > target_T, which
        # would empty x -- confirm inputs are never longer than target_T here.
        repeat_ratio = math.floor(target_T / frame_num)
        x = x.repeat(repeats=(1,1,repeat_ratio,1))
        x = self.reshape_wav2img(x)
        output_dict = self.forward_features(x)
    elif self.config.enable_repeat_mode:
        if self.training:
            # one random spec_size-long window per batch
            cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
            x = self.repeat_wat2img(x, cur_pos)
            output_dict = self.forward_features(x)
        else:
            # evaluate non-overlapping windows (stride spec_size) and average their outputs
            output_dicts = []
            for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
                tx = x.clone()
                tx = self.repeat_wat2img(tx, cur_pos)
                output_dicts.append(self.forward_features(tx))
            clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
            framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
            for d in output_dicts:
                clipwise_output += d["clipwise_output"]
                framewise_output += d["framewise_output"]
            clipwise_output = clipwise_output / len(output_dicts)
            framewise_output = framewise_output / len(output_dicts)
            # NOTE(review): unlike the other branches this dict has no
            # 'latent_output' key.
            output_dict = {
                'framewise_output': framewise_output,
                'clipwise_output': clipwise_output
            }
    else:
        if x.shape[2] > self.freq_ratio * self.spec_size:
            if self.training:
                # training on long inputs: one random crop of the Swin-sized length
                x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
                x = self.reshape_wav2img(x)
                output_dict = self.forward_features(x)
            else:
                # Change: Hard code here
                # Fixed window (689 frames) and stride (344 frames); the trailing
                # comments show the original length-relative formulas.
                # NOTE(review): the range stops before x.shape[2] - crop_size - 1,
                # so windows starting at the last two valid positions are skipped,
                # and there are no windows at all if x.shape[2] <= crop_size + 1.
                overlap_size = 344 #(x.shape[2] - 1) // 4
                output_dicts = []
                crop_size = 689 #(x.shape[2] - 1) // 2
                for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
                    tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
                    tx = self.reshape_wav2img(tx)
                    output_dicts.append(self.forward_features(tx))
                # average every output key over the windows
                clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
                framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
                latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device)
                for d in output_dicts:
                    clipwise_output += d["clipwise_output"]
                    framewise_output += d["framewise_output"]
                    latent_output += d["latent_output"]
                clipwise_output = clipwise_output / len(output_dicts)
                framewise_output = framewise_output / len(output_dicts)
                latent_output = latent_output / len(output_dicts)
                output_dict = {
                    'framewise_output': framewise_output,
                    'clipwise_output': clipwise_output,
                    'latent_output': latent_output,
                }
        else: # this part is typically used, and most easy one
            x = self.reshape_wav2img(x)
            output_dict = self.forward_features(x)
    # x = self.head(x)
    return output_dict
class HTSATWrapper(nn.Module):
    """Adapter that exposes HTSAT's 'latent_output' under the key 'embedding'.

    The constructor accepts the generic audio-encoder arguments that
    ``AudioEncoder`` passes in, but none of them are forwarded: the backbone
    is always built from a default ``HTSATConfig()``.
    """
    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num, out_emb):
        super().__init__()
        # TODO: every argument above is currently ignored; decide which of them
        # should override the HTSATConfig defaults.
        self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())

    def forward(self, x):
        """Run HTSAT and alias its 'latent_output' entry as 'embedding'."""
        outputs = self.htsat(x)
        outputs['embedding'] = outputs['latent_output']
        return outputs
def get_audio_encoder(name: str):
    """Return the audio-encoder class registered under ``name``.

    Only ``"HTSAT"`` is supported; it maps to ``HTSATWrapper``.

    Raises:
        ValueError: for any other name. ``ValueError`` subclasses
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    if name == "HTSAT":
        return HTSATWrapper
    raise ValueError('The audio encoder name {} is incorrect or not supported'.format(name))
class Projection(nn.Module):
    """Residual two-layer projection head.

    ``out = LayerNorm(W1 x + Dropout(W2 GELU(W1 x)))``; both linear layers are
    bias-free.

    Args:
        dim_imgn: input feature size.
        d_out: output (projection) size.
        p: dropout probability on the second branch.
    """
    def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:
        super().__init__()
        # Attribute names are state_dict keys (CLAP weights are loaded by key),
        # so they must not change.
        self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = self.linear1(x)
        residual = self.drop(self.linear2(F.gelu(base)))
        return self.layer_norm(base + residual)
class AudioEncoder(nn.Module):
    """Audio backbone (selected by name) followed by a ``Projection`` head.

    ``forward`` returns ``(projected_embedding, clipwise_output)``, where the
    embedding is the backbone's 'embedding' entry after projection.
    """
    def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,
                 hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:
        super().__init__()
        encoder_cls = get_audio_encoder(audioenc_name)
        # Backbone is built before the projection head (keeps parameter-init order).
        self.base = encoder_cls(sample_rate, window_size, hop_size, mel_bins,
                                fmin, fmax, classes_num, dim_imgn)
        self.projection = Projection(dim_imgn, d_out)

    def forward(self, x):
        outputs = self.base(x)
        embedding, clipwise = outputs['embedding'], outputs['clipwise_output']
        return self.projection(embedding), clipwise
class TextEncoder(nn.Module):
    """Pretrained text backbone (``AutoModel.from_pretrained``) plus a projection head.

    Pooling depends on the model name:

    * ``'clip'``: the text tower's pooled output, passed through CLIP's own
      text projection (for ``'base'`` CLIP variants the width is forced to 512).
    * ``'gpt'``: hidden state at index (number of non-zero input ids - 1);
      presumably the last real token, assuming right padding with id 0.
    * anything else: the first token's hidden state.
    """
    def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
        super().__init__()
        self.text_model = text_model
        self.base = AutoModel.from_pretrained(text_model)
        if 'clip' in text_model:
            # keep CLIP's projection, then drop down to the bare text tower
            self.clip_text_projection = self.base.text_projection
            self.base = self.base.text_model
            if 'base' in text_model:
                transformer_embed_dim = 512
        self.projection = Projection(transformer_embed_dim, d_out)

    def forward(self, x):
        if 'clip' in self.text_model:
            pooled = self.base(**x)[1]
            features = self.clip_text_projection(pooled)
        elif 'gpt' in self.text_model:
            n_rows = x['input_ids'].shape[0]
            hidden = self.base(**x)[0]
            last_idx = torch.ne(x['input_ids'], 0).sum(-1) - 1
            rows = torch.arange(n_rows, device=hidden.device)
            features = hidden[rows, last_idx]
        else:
            features = self.base(**x)[0][:, 0, :]
        return self.projection(features)
class CLAP(nn.Module):
    """Contrastive audio-text model: audio encoder, caption encoder and a learnable logit scale.

    ``forward(audio, text)`` returns ``(caption_embed, audio_embed, logit_scale)``,
    where ``logit_scale`` is ``exp`` of the stored parameter.
    """
    def __init__(self,
                 # audio
                 audioenc_name: str,
                 sample_rate: int,
                 window_size: int,
                 hop_size: int,
                 mel_bins: int,
                 fmin: int,
                 fmax: int,
                 classes_num: int,
                 out_emb: int,
                 # text
                 text_model: str,
                 transformer_embed_dim: int,
                 # common
                 d_proj: int,
                 ):
        super().__init__()
        self.audio_encoder = AudioEncoder(
            audioenc_name, out_emb, d_proj,
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num,
        )
        self.caption_encoder = TextEncoder(d_proj, text_model, transformer_embed_dim)
        # stored in log space, initialised to log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, audio, text):
        audio_embed = self.audio_encoder(audio)[0]
        caption_embed = self.caption_encoder(text)
        return caption_embed, audio_embed, self.logit_scale.exp()
# ==================================================================
# A U D I O - P R E - P R O C E S S I N G
# ==================================================================
def read_audio(audio_path, resample=True):
    r"""Load an audio file, optionally resampling it to ``clapConfig.sample_rate``.

    Args:
        audio_path: file passed to ``torchaudio.load``.
        resample: when True and the file's rate differs from
            ``clapConfig.sample_rate``, resample to that rate.

    Returns:
        ``(waveform, sample_rate)``. ``sample_rate`` is the rate of the
        returned waveform: the configured rate after resampling, otherwise the
        file's own rate. (Previously the configured rate was returned even
        when no resampling happened, which made callers compute wrong
        durations for ``resample=False``.)
    """
    audio_time_series, sample_rate = torchaudio.load(audio_path)
    resample_rate = clapConfig.sample_rate
    if resample and resample_rate != sample_rate:
        resampler = T.Resample(sample_rate, resample_rate)
        audio_time_series = resampler(audio_time_series)
        sample_rate = resample_rate
    return audio_time_series, sample_rate
def load_audio_into_tensor(audio_path, audio_duration, resample=False):
    r"""Load an audio file as a flat float tensor covering exactly ``audio_duration`` seconds.

    Clips that are too short are tiled (repeated) and then truncated; clips
    that are too long are cut to a window of the target length whose start
    is drawn uniformly over every valid position.

    Args:
        audio_path: file passed to ``read_audio``.
        audio_duration: target length in seconds (int or float).
        resample: forwarded to ``read_audio``.

    Returns:
        1-D float32 tensor with ``int(audio_duration * sample_rate)`` samples.
        NOTE(review): the loaded waveform is flattened with ``reshape(-1)``, so
        multi-channel files have their channels concatenated rather than mixed.

    Raises:
        ValueError: if the file contains no samples.
    """
    audio_time_series, sample_rate = read_audio(audio_path, resample)
    audio_time_series = audio_time_series.reshape(-1)
    num_samples = audio_time_series.shape[0]
    # int() so fractional durations still give a valid slice length.
    target_len = int(audio_duration * sample_rate)
    if num_samples == 0:
        raise ValueError(f"Audio file {audio_path} contains no samples")
    if target_len >= num_samples:
        # Too short (or exact): tile until the clip covers target_len, then trim the excess.
        repeat_factor = math.ceil(target_len / num_samples)
        audio_time_series = audio_time_series.repeat(repeat_factor)[:target_len]
    else:
        # Too long: random window. randrange excludes its bound, hence + 1 so the
        # window ending on the final sample is also eligible.
        start_index = random.randrange(num_samples - target_len + 1)
        audio_time_series = audio_time_series[start_index:start_index + target_len]
    return audio_time_series.float()
# numpy dtype kinds that cannot be turned into tensors: bytes, str/unicode, object
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    r"""Recursively merge a list of samples into a batch (vendored from PyTorch).

    * tensors are stacked along a new leading dimension;
    * numpy arrays and numpy scalars are converted to tensors first;
    * Python floats -> float64 tensor, ints -> integer tensor;
    * strings are returned unchanged as the input list;
    * mappings, namedtuples and other sequences are collated field by field.

    Raises:
        TypeError: for unsupported element types or numpy arrays of
            strings/objects.
        RuntimeError: if sequence elements have inconsistent lengths.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a DataLoader worker: stack straight into a shared-memory tensor
            # to avoid an extra copy when the batch is sent to the main process.
            # Uses the typed-storage accessor (``Tensor.storage()`` is
            # deprecated) and gives ``out`` its final shape up front, so
            # torch.stack does not have to resize a non-empty output.
            numel = sum(x.numel() for x in batch)
            storage = elem._typed_storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
def preprocess_audio(audio_files, resample):
    r"""Load each file as a fixed-length clip and collate them into one batch.

    Every clip is loaded for ``clapConfig.duration`` seconds and reshaped to
    ``(1, num_samples)`` before collation, so the result has shape
    ``(len(audio_files), 1, num_samples)``.
    """
    clips = [
        load_audio_into_tensor(path, clapConfig.duration, resample).reshape(1, -1)
        for path in audio_files
    ]
    return default_collate(clips)
# ==================================================================
# A U D I O - E M B E D D I N G S - H E L P E R
# ==================================================================
def CLAPAudioProcessor(audio_files: List[str], resample=True):
    """Load audio files into a 2-D ``(num_files, num_samples)`` batch for the CLAP audio encoder."""
    batch = preprocess_audio(audio_files, resample)
    # preprocess_audio yields (N, 1, L); drop the singleton middle axis.
    n_files, n_samples = batch.shape[0], batch.shape[2]
    return batch.reshape(n_files, n_samples)
def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):
    """Load audio files and embed them with ``audio_encoder`` without tracking gradients.

    Returns the first element of the encoder's output (for ``AudioEncoder``
    that is the projected embedding).
    """
    with torch.no_grad():
        waveforms = CLAPAudioProcessor(audio_files, resample)
        encoded = audio_encoder(waveforms)
        return encoded[0]
# ==================================================================
# C L A P
# ==================================================================
class ClapConfig:
    """Hyper-parameters of the CLAP model used as the audio tower (plain class attributes)."""

    # ---- text encoder ----
    text_model = 'gpt2'
    text_len = 77
    transformer_embed_dim = 768
    freeze_text_encoder_weights = True

    # ---- audio encoder ----
    audioenc_name = 'HTSAT'
    out_emb = 768
    sample_rate = 44100
    duration = 7  # clip length in seconds (load_audio_into_tensor keeps duration * sample_rate samples)
    fmin = 50
    fmax = 8000  # alternative: 14000
    n_fft = 1024  # alternative: 1028
    hop_size = 320
    mel_bins = 64
    window_size = 1024

    # ---- shared projection space ----
    d_proj = 1024
    temperature = 0.003

    # ---- training / evaluation ----
    num_classes = 527
    batch_size = 1024
    demo = False
clapConfig = ClapConfig()
# Full CLAP model (audio + text towers); only its audio encoder is used below.
clap = CLAP(
    audioenc_name=clapConfig.audioenc_name,
    sample_rate=clapConfig.sample_rate,
    window_size=clapConfig.window_size,
    hop_size=clapConfig.hop_size,
    mel_bins=clapConfig.mel_bins,
    fmin=clapConfig.fmin,
    fmax=clapConfig.fmax,
    classes_num=clapConfig.num_classes,
    out_emb=clapConfig.out_emb,
    text_model=clapConfig.text_model,
    transformer_embed_dim=clapConfig.transformer_embed_dim,
    d_proj=clapConfig.d_proj
)
model_repo = "microsoft/msclap"
model_name = {
    '2022': 'CLAP_weights_2022.pth',
    '2023': 'CLAP_weights_2023.pth',
    'clapcap': 'clapcap_weights_2023.pth'
}
version = '2023'
model_fp = hf_hub_download(model_repo, model_name[version])
model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']
# strict=False tolerates key mismatches between the checkpoint and this
# re-implementation; report them rather than ignoring them silently, so a
# partially-loaded encoder does not go unnoticed.
load_result = clap.load_state_dict(model_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"CLAP checkpoint loaded with {len(load_result.missing_keys)} missing "
          f"and {len(load_result.unexpected_keys)} unexpected keys")
    print("  missing (first 10):", load_result.missing_keys[:10])
    print("  unexpected (first 10):", load_result.unexpected_keys[:10])
# clap.eval()
clap_audio_encoder = clap.audio_encoder.to(device)
# ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English"
# audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")]
# audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)
# print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
# ==================================================================
# C L A P - L o R A - M O D E L
# ==================================================================
# PEFT LoRA settings for the CLAP audio encoder. `target_modules` lists the
# sub-module names to adapt (qkv/fc1/fc2/proj come from the Swin backbone,
# linear1/linear2 from the Projection head).
LoRAconfig = dict(
    peft_type="LORA",
    task_type="FEATURE_EXTRACTION",
    inference_mode=False,
    r=16,
    target_modules=["qkv", "fc1", "fc2", "proj", "linear1", "linear2"],
    lora_alpha=32,
    lora_dropout=0.05,
    fan_in_fan_out=False,
    bias="all",
)
peft_config = get_peft_config(LoRAconfig)
peft_model = get_peft_model(clap_audio_encoder, peft_config)
peft_model.print_trainable_parameters()
# The LoRA-injected inner model (not the PeftModel wrapper) is used as the audio encoder.
peft_clap_audio_encoder = peft_model.base_model
# audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)
# print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024]
# ==================================================================
# O P E N - C L I P - M O D E L
# ==================================================================
import open_clip

# Pretrained OpenCLIP model; its `.visual` tower is the image encoder for CSIP.
OPEN_CLIP_ARCH = 'ViT-H-14'
OPEN_CLIP_PRETRAINED = 'laion2b_s32b_b79k'
# Returns (model, train-time transform, eval-time transform).
open_clip_model, open_clip_imgaug, open_clip_preprocess = open_clip.create_model_and_transforms(
    model_name=OPEN_CLIP_ARCH,
    pretrained=OPEN_CLIP_PRETRAINED,
    device=device,
)
# ==================================================================
# C S I P - M O D U L E
# ==================================================================
class CSIP(nn.Module):
    """CLIP-style contrastive alignment of images and audio.

    The image encoder is frozen and its output is used directly as the image
    embedding, so it must produce ``dim_emb`` features. The first output of
    the audio encoder goes through a trainable linear projection
    ``dim_audio -> dim_emb``. Both embeddings are L2-normalised and compared
    with a learnable logit scale.

    Args:
        image_encoder: module mapping images to ``[n, dim_emb]`` features.
        audio_encoder: module whose output's first item is ``[n, dim_audio]``.
        dim_img: unused; kept for backward compatibility.
        dim_audio: size of the audio features fed to ``audio_proj``.
        dim_emb: shared embedding size.
        symmetric_loss: if True, average the image->audio and audio->image
            cross-entropies; the default keeps the audio->image-only objective.
    """
    def __init__(self, image_encoder, audio_encoder,
                 dim_img=None, dim_audio=1024, dim_emb=1024, symmetric_loss=False):
        super(CSIP, self).__init__()
        self.image_encoder = image_encoder  # e.g. OpenCLIP visual tower
        self.audio_encoder = audio_encoder  # e.g. (LoRA-wrapped) CLAP audio encoder
        for param in self.image_encoder.parameters():
            param.requires_grad = False
        # self.image_proj = nn.Linear(dim_img, dim_emb)
        self.audio_proj = nn.Linear(dim_audio, dim_emb)
        # Learnable logit scale in log space: logits are *multiplied* by
        # exp(log_temp), so it starts at log(1/0.07), i.e. a temperature of 0.07.
        # Starting at log(0.07) kept every logit within +-0.07 and made the
        # softmax almost uniform. The attribute name is unchanged so saved
        # state_dicts still load.
        self.log_temp = nn.Parameter(torch.tensor(1 / 0.07).log())
        self.symmetric_loss = symmetric_loss

    def forward(self, images, audios):
        """Return ``(loss, logits, probs)`` for a batch of paired images/audios.

        ``logits[i, j]`` is the scaled cosine similarity of image ``i`` and
        audio ``j``; matching pairs lie on the diagonal. ``probs`` is the
        row-wise softmax of ``logits``.
        """
        # Step 1: feature extraction. The full feature vectors are kept; taking
        # ``.norm(dim=-1, keepdim=True)`` here would reduce each to its scalar
        # length, destroying the embedding and feeding 1 feature to audio_proj.
        image_features = self.image_encoder(images)          # [n, dim_emb]
        audio_features = self.audio_encoder(audios)[0]       # [n, dim_audio]
        # Step 2: project audio into the shared space; L2-normalise both sides.
        image_embeds = F.normalize(image_features, dim=-1)
        audio_embeds = F.normalize(self.audio_proj(audio_features), dim=-1)
        # Step 3: cosine similarity scaled by the learnable logit scale.
        logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp()  # [n, n]
        probs = logits.softmax(dim=1)
        # Step 4: contrastive cross-entropy; the correct pair is the diagonal.
        labels = torch.arange(len(images), device=images.device)
        loss_a = F.cross_entropy(logits.T, labels)  # audio -> image
        if self.symmetric_loss:
            loss_i = F.cross_entropy(logits, labels)  # image -> audio
            loss = (loss_i + loss_a) / 2
        else:
            loss = loss_a
        return loss, logits, probs
# ==================================================================
# I M A G E - A U D I O - D A T A S E T
# ==================================================================
class VaaniImageAudioDataset(torch.utils.data.Dataset):
    """Map-style dataset of (image path, audio path) pairs taken from a DataFrame.

    ``df`` must provide ``image_path`` and ``audio_path`` columns. Items are
    dicts of paths; decoding happens later in the collate function.
    """
    def __init__(self, df):
        self.image_paths = df['image_path'].tolist()
        self.audio_paths = df['audio_path'].tolist()

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        item = dict(image_path=self.image_paths[idx])
        item['audio_path'] = self.audio_paths[idx]
        return item
def collate_fn(batch):
    """Decode one list of dataset items into model-ready tensors.

    Args:
        batch: list of dicts with ``image_path`` and ``audio_path`` keys (as
            returned by ``VaaniImageAudioDataset``).

    Returns:
        dict with ``image_tensor`` (transformed images stacked along dim 0) and
        ``audio_tensor`` (output of ``CLAPAudioProcessor``).
    """
    # open_clip's transforms are torchvision pipelines applied to ONE PIL image
    # at a time and returning a tensor. Passing a list and indexing
    # ['pixel_values'] is the Hugging Face CLIPProcessor interface, which
    # does not apply to open_clip_imgaug.
    images = []
    for item in batch:
        # Close each file promptly; convert() forces the pixels to load first.
        with Image.open(item['image_path']) as img:
            images.append(open_clip_imgaug(img.convert("RGB")))
    audio_tensor = CLAPAudioProcessor([item['audio_path'] for item in batch], resample=True)
    return {'image_tensor': torch.stack(images), 'audio_tensor': audio_tensor}
# preprocessed_audio = CLAPAudioProcessor(audio_files, resample=True)
# clip_vision_processor = clip_vision_processor
# ---- Datasets / loaders --------------------------------------------------
# Each CSV must provide `image_path` and `audio_path` columns (read by VaaniImageAudioDataset).
train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN.csv")
test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST.csv")
train_dataset = VaaniImageAudioDataset(train_df)
test_dataset = VaaniImageAudioDataset(test_df)
BATCH_SIZE = 64
print('Train Dataset:', len(train_dataset))
print('Test Dataset:', len(test_dataset))
# Both loaders use collate_fn, so image/audio decoding runs in the 48 worker processes.
# NOTE(review): collate_fn applies the same image transform to train and test batches.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=48,
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=48,
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=False,
    persistent_workers=True
)
# ---- Sanity check: pull one batch and print its shapes --------------------
batch = next(iter(train_dataloader))
image_tensor_batch = batch['image_tensor']
audio_tensor_batch = batch['audio_tensor']
print("Image batch shape:", image_tensor_batch.shape) # presumably [BATCH_SIZE, 3, 224, 224] for ViT-H-14
# CLAPAudioProcessor returns 2-D audio: [BATCH_SIZE, duration * sample_rate] = [BATCH_SIZE, 308700] with ClapConfig's values
print("Audio batch shape:", audio_tensor_batch.shape)
# CSIP.__init__ freezes every parameter of the image encoder (open_clip_model.visual).
csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device)
from torchinfo import summary
import subprocess
# NOTE(review): torchinfo documents `dtypes` for use with input_size; with
# input_data given it is presumably ignored -- confirm before relying on it.
summary(model=csip_model,
        input_data=((image_tensor_batch.to(device)), (audio_tensor_batch.to(device))),
        # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),
        dtypes=[torch.long],
        col_names = ["input_size", "output_size", "num_params", "trainable", "params_percent"],
        col_width=20,
        row_settings=["var_names"],
        depth = 2,
        # verbose=2,
        # device=device
        )
# loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device))
# loss, logits, probs, logits.shape, probs.shape
def train_batch(model, images, audio, optimizer):
    """Perform one optimisation step on a single batch.

    The model must return ``(loss, logits, probs)``. Returns
    ``(loss_value, logits, probs)`` with ``loss_value`` as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    batch_loss, batch_logits, batch_probs = model(images, audio)
    batch_loss.backward()
    optimizer.step()
    loss_value = batch_loss.item()
    return loss_value, batch_logits, batch_probs
def evaluate_batch(model, images, audio):
    """Score one batch in eval mode with autograd disabled.

    Returns ``(loss_value, logits, probs)`` with ``loss_value`` as a Python float.
    """
    model.eval()
    with torch.no_grad():
        batch_loss, batch_logits, batch_probs = model(images, audio)
        return batch_loss.item(), batch_logits, batch_probs
def save_checkpoint(state, checkpoint_dir, epoch, max_checkpoints=2):
    """Persist ``state`` for ``epoch`` and prune older checkpoints.

    The file is named ``csip_best_epoch_<epoch + 1>.pt``. Afterwards, every
    file with that prefix is ranked by its numeric epoch suffix and the
    oldest ones are deleted so at most ``max_checkpoints`` remain.
    """
    torch.save(state, os.path.join(checkpoint_dir, f"csip_best_epoch_{epoch+1}.pt"))

    def suffix_epoch(fname):
        # "csip_best_epoch_12.pt" -> 12 (numeric, so 10 sorts after 9)
        return int(fname.split("_")[-1].split(".")[0])

    ranked = sorted(
        (fname for fname in os.listdir(checkpoint_dir) if fname.startswith("csip_best_epoch_")),
        key=suffix_epoch,
    )
    surplus = len(ranked) - max_checkpoints
    for stale in ranked[:max(surplus, 0)]:
        os.remove(os.path.join(checkpoint_dir, stale))
def load_checkpoint(checkpoint_dir, model, optimizer, scheduler):
    """Restore model/optimizer/scheduler state from the highest-epoch checkpoint.

    Looks for ``csip_best_epoch_<N>.pt`` files in ``checkpoint_dir`` and loads
    the one with the largest ``N``.

    Args:
        checkpoint_dir: directory written by ``save_checkpoint``.
        model: module whose state is restored in place.
        optimizer: optimizer whose state is restored in place.
        scheduler: LR scheduler or ``None``. Its state is restored only when
            both it and the saved ``scheduler_state`` are present (the
            training loop stores ``None`` when it runs without a scheduler).

    Returns:
        ``(start_epoch, best_loss)``: the saved ``epoch`` and ``best_loss``
        values, or ``(0, inf)`` when no checkpoint exists.
    """
    checkpoints = sorted(
        [f for f in os.listdir(checkpoint_dir) if f.startswith("csip_best_epoch_")],
        key=lambda x: int(x.split("_")[-1].split(".")[0])
    )
    if not checkpoints:
        print("No checkpoint found to resume from.")
        return 0, float("inf")
    best_ckpt = checkpoints[-1]
    path = os.path.join(checkpoint_dir, best_ckpt)
    # Load onto CPU first; load_state_dict then copies tensors onto the
    # model's / optimizer's own devices, so a checkpoint written on another
    # GPU (or host) still loads.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    scheduler_state = checkpoint.get('scheduler_state')
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    print(f"Resumed training from epoch {start_epoch+1} with best loss {best_loss:.4f}")
    return start_epoch, best_loss
def fig_to_tensor(fig):
    """Convert a Matplotlib figure to an image tensor suitable for TensorBoard.

    The figure is rendered to PNG in memory, decoded as RGB and converted with
    ``torchvision.transforms.functional.to_tensor``. The figure is always
    closed, even if rendering fails, so callers must not reuse it.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png')
            buf.seek(0)
            # convert() forces a full decode while the buffer is still open.
            image = Image.open(buf).convert("RGB")
    finally:
        # Release the figure on every path so figures don't accumulate in pyplot.
        plt.close(fig)
    return tv.transforms.functional.to_tensor(image)
def save_similarity_heatmaps(logits, epoch, loss, save_dir, writer):
    """Save raw-logit and softmax-probability heatmaps as PNGs and log them to TensorBoard.

    PNGs go to ``save_dir/logits`` and ``save_dir/probs``; the images are
    logged under ``Heatmap/RawLogits`` and ``Heatmap/SoftmaxProbs`` at step
    ``epoch + 1``.

    Args:
        logits: 2-D similarity matrix; rows are images, columns are audios.
        epoch: 0-based epoch index (titles/filenames use ``epoch + 1``).
        loss: value shown in titles and filenames.
        save_dir: root output directory.
        writer: TensorBoard ``SummaryWriter``.
    """
    os.makedirs(os.path.join(save_dir, 'logits'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'probs'), exist_ok=True)
    # Detach (and upcast) once for both plots: .numpy() fails on tensors that
    # still require grad and on dtypes NumPy lacks. Previously only the
    # raw-logits path detached.
    logits = logits.detach().float()
    # --- Raw logits heatmap ---
    logits_np = logits.cpu().numpy()
    fig_logits = plt.figure(figsize=(8, 6))
    sns.heatmap(logits_np, square=True, cmap="Blues", cbar=True, annot=False)
    plt.title(f"Raw Logits Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
    plt.xlabel("Audio Index")
    plt.ylabel("Image Index")
    raw_path = os.path.join(save_dir, 'logits', f"raw_logits_epoch_{epoch+1}_loss_{loss:.4f}.png")
    fig_logits.savefig(raw_path)
    writer.add_image("Heatmap/RawLogits", fig_to_tensor(fig_logits), global_step=epoch+1)
    # --- Softmax probs heatmap (row-wise: each image's distribution over audios) ---
    probs_np = logits.softmax(dim=1).cpu().numpy()
    fig_probs = plt.figure(figsize=(8, 6))
    sns.heatmap(probs_np, square=True, cmap="Blues", cbar=True, annot=False)
    plt.title(f"Softmax Probabilities Heatmap — Epoch {epoch+1}, Loss {loss:.4f}")
    plt.xlabel("Audio Index")
    plt.ylabel("Image Index")
    prob_path = os.path.join(save_dir, "probs", f"probs_epoch_{epoch+1}_loss_{loss:.4f}.png")
    fig_probs.savefig(prob_path)
    writer.add_image("Heatmap/SoftmaxProbs", fig_to_tensor(fig_probs), global_step=epoch+1)
def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


def train_model(model, train_loader, test_loader,
                optimizer, scheduler, device, log_dir,
                checkpoint_dir, resume=False, epochs=10):
    """Train ``model`` with per-epoch evaluation, logging and best-model checkpointing.

    Each epoch runs ``train_batch`` over ``train_loader`` and
    ``evaluate_batch`` over ``test_loader``. It then logs the mean losses and
    the learning rate to TensorBoard and to ``log_dir/training_log.csv``.
    When the mean test loss improves, it saves similarity heatmaps (from the
    last test batch) and a checkpoint via ``save_checkpoint``.

    Args:
        model: model returning ``(loss, logits, probs)`` for (images, audios).
        train_loader / test_loader: yield dicts with ``image_tensor`` and
            ``audio_tensor``.
        optimizer: optimizer used by ``train_batch``.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        device: device the batch tensors are moved to.
        log_dir: TensorBoard and CSV output directory.
        checkpoint_dir: checkpoint and heatmap output directory.
        resume: restore state with ``load_checkpoint`` and append to an existing CSV.
        epochs: total number of epochs; the loop runs from the resumed epoch to this.
    """
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    csv_path = os.path.join(log_dir, "training_log.csv")
    writer = SummaryWriter(log_dir=log_dir)
    try:
        start_epoch = 0
        best_loss = float("inf")
        best_epoch = -1
        if resume:
            start_epoch, best_loss = load_checkpoint(checkpoint_dir, model, optimizer, scheduler)
            if best_loss < float("inf"):
                # A best checkpoint was restored; it stores the 0-based epoch,
                # while the CSV uses 1-based numbering.
                best_epoch = start_epoch + 1
        # If resuming, don't overwrite the CSV
        if not (resume and os.path.exists(csv_path)):
            with open(csv_path, mode='w', newline='') as f:
                writer_csv = csv.writer(f)
                writer_csv.writerow(["Epoch", "Best Epoch", "Train Loss", "Test Loss", "Best Loss", "Learning Rate"])
        for epoch in trange(start_epoch, epochs, colour='yellow', dynamic_ncols=True):
            train_losses = []
            test_losses = []
            train_loop = tqdm(train_loader, desc=f"[TrainEp {epoch+1}]", colour='blue', dynamic_ncols=True)
            for batch in train_loop:
                images = batch['image_tensor'].to(device)
                audios = batch['audio_tensor'].to(device)
                loss, logits, probs = train_batch(model, images, audios, optimizer)
                train_losses.append(loss)
                train_loop.set_postfix(trainLoss=loss)
            # Heatmaps use the logits of the last *test* batch.
            test_logits = None
            test_loop = tqdm(test_loader, desc=f"[TestEp {epoch+1}]", colour='red', dynamic_ncols=True)
            for batch in test_loop:
                images = batch['image_tensor'].to(device)
                audios = batch['audio_tensor'].to(device)
                loss, test_logits, probs = evaluate_batch(model, images, audios)
                test_losses.append(loss)
                test_loop.set_postfix(testLoss=loss)
            # NaN instead of ZeroDivisionError when a loader yields no batches.
            # NaN never compares below best_loss, so an empty test epoch is
            # never saved as "best" (and test_logits is never used as None).
            avg_train_loss = _mean_or_nan(train_losses)
            avg_test_loss = _mean_or_nan(test_losses)
            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
            writer.add_scalar("Loss/Test", avg_test_loss, epoch + 1)
            writer.add_scalar("Learning Rate", current_lr, epoch + 1)
            # Adjacent f-strings: a backslash inside the literal would copy the
            # next line's leading whitespace into the message.
            print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | "
                  f"Test Loss: {avg_test_loss:.4f} | LR: {current_lr:.2e}")
            if avg_test_loss < best_loss:
                save_similarity_heatmaps(test_logits, epoch, avg_test_loss, checkpoint_dir, writer)
                best_loss = avg_test_loss
                best_epoch = epoch + 1
                save_checkpoint({
                    'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state': optimizer.state_dict(),
                    'best_loss': best_loss,
                    'scheduler_state': scheduler.state_dict() if scheduler else None
                }, checkpoint_dir, epoch)
                print(f">>> Saved new best model at epoch {epoch+1}")
            # The checkpoint above is written before stepping, so it holds this epoch's scheduler state.
            if scheduler is not None:
                scheduler.step()
            with open(csv_path, mode='a', newline='') as f:
                writer_csv = csv.writer(f)
                writer_csv.writerow([epoch + 1, best_epoch, avg_train_loss, avg_test_loss, best_loss, current_lr])
    finally:
        # Flush/close TensorBoard event files even if training is interrupted.
        writer.close()
# ---- Training run configuration -------------------------------------------
# NOTE: this rebinds `model_name`, which earlier held the dict of CLAP weight
# filenames (that dict is only used before this point).
model_name = "csip_model_openClip_CLAP"
learning_rate = 1e-4
epochs = 100
# Every CSIP parameter is handed to AdamW, including the frozen image tower.
# Presumably harmless (parameters that never receive a gradient are skipped),
# and changing the list would break loading of previously saved optimizer state.
optimizer = torch.optim.AdamW(csip_model.parameters(), lr=learning_rate)
# Cosine decay over `epochs` scheduler steps; train_model steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-10)
# subprocess.run([
# "rm",
# "-rf",
# f"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/{model_name}",
# ])
# resume=True: continue from the best checkpoint in checkpoint_dir, if any.
train_model(
    model=csip_model,
    train_loader=train_dataloader,
    test_loader=test_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    log_dir=f"{model_name}/runs/csip",
    checkpoint_dir=f"{model_name}/checkpoints/csip",
    resume=True,
    epochs=epochs
)
# tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006
# 127.0.0.1:40697
# tensorboard --logdir=/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/runs --port=6006 --host=0.0.0.0