| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision.models import swin_v2_t, Swin_V2_T_Weights |
| import copy |
|
|
|
|
class SheetMusicSwin(nn.Module):
    """Swin-V2-Tiny encoder that maps images to L2-normalised embeddings.

    The torchvision ``swin_v2_t`` classification head is replaced by
    ``nn.Identity`` so the backbone outputs the ``head.in_features``-wide
    vector that used to feed the head. That vector goes through a projection
    head (Linear -> LayerNorm -> GELU -> Linear).

    Args:
        out_channels: Size of the embedding returned by ``forward``.
        hidden_dim: Width of the projection head's hidden layer.
        pretrained: When true (default) the backbone is built with
            ``Swin_V2_T_Weights.DEFAULT``; when false it is built with
            ``weights=None`` (no weights requested from torchvision).
    """

    def __init__(self, out_channels=512, hidden_dim=512, pretrained=True):
        super().__init__()
        weights = Swin_V2_T_Weights.DEFAULT if pretrained else None
        self.backbone = swin_v2_t(weights=weights)

        # Read the head's input width before discarding the head itself.
        in_features = self.backbone.head.in_features
        self.backbone.head = nn.Identity()

        self.projection_head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_channels),
        )

    def forward(self, x):
        """Return embeddings of ``x`` with unit L2 norm along dim 1."""
        features = self.backbone(x)
        projected = self.projection_head(features)
        return F.normalize(projected, p=2, dim=1)
|
|
|
|
class SpectrogramSwin(nn.Module):
    """Swin-V2-Tiny encoder for single-channel inputs, emitting unit-norm embeddings.

    The convolution stored at ``backbone.features[0][0]`` is swapped for a
    one-input-channel copy whose kernel is the mean of the original kernel
    over its input-channel axis. The classification head is replaced by
    ``nn.Identity`` and followed by a projection head
    (Linear -> LayerNorm -> GELU -> Linear).

    Args:
        out_channels: Size of the embedding returned by ``forward``.
        hidden_dim: Width of the projection head's hidden layer.
        pretrained: When true (default) the backbone is built with
            ``Swin_V2_T_Weights.DEFAULT``; when false it is built with
            ``weights=None``.
    """

    # forward() zero-pads the two trailing dims up to a multiple of this.
    # NOTE(review): presumably chosen to match the backbone's overall
    # downsampling factor - confirm.
    _PAD_MULTIPLE = 32

    def __init__(self, out_channels=512, hidden_dim=512, pretrained=True):
        super().__init__()
        weights = Swin_V2_T_Weights.DEFAULT if pretrained else None
        self.backbone = swin_v2_t(weights=weights)

        self.backbone.features[0][0] = self._to_single_channel(
            self.backbone.features[0][0]
        )

        # Read the head's input width before discarding the head itself.
        in_features = self.backbone.head.in_features
        self.backbone.head = nn.Identity()

        self.projection_head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_channels),
        )

    @staticmethod
    def _to_single_channel(conv):
        """Return a 1-input-channel ``nn.Conv2d`` derived from ``conv``.

        The new kernel is ``conv.weight`` averaged over dim 1. A bias is
        created only when ``conv`` has one, so a bias-free source never ends
        up with a randomly initialised bias. Values are copied, not aliased.
        """
        has_bias = conv.bias is not None
        mono = nn.Conv2d(
            in_channels=1,
            out_channels=conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            padding_mode=conv.padding_mode,
            bias=has_bias,
        )
        with torch.no_grad():
            mono.weight.copy_(conv.weight.mean(dim=1, keepdim=True))
            if has_bias:
                mono.bias.copy_(conv.bias)
        return mono

    def forward(self, x):
        """Zero-pad ``x`` on dims 2 and 3, encode it, and L2-normalise along dim 1.

        ``x`` must have at least four dimensions because dims 2 and 3 are
        read as the sizes to pad.
        """
        multiple = self._PAD_MULTIPLE
        height, width = x.shape[2], x.shape[3]
        pad_h = (-height) % multiple
        pad_w = (-width) % multiple
        # F.pad order is (left, right, top, bottom) for the last two dims, so
        # zeros are appended at the end of each dim only.
        x = F.pad(x, (0, pad_w, 0, pad_h))

        features = self.backbone(x)
        projected = self.projection_head(features)
        return F.normalize(projected, p=2, dim=1)
|
|
|
|
class VisionAudioMoCo(nn.Module):
    """Cross-modal MoCo-style contrastive model over a vision/audio encoder pair.

    Each modality has a trainable query encoder and a gradient-free key
    encoder that starts as a deep copy of the query encoder and is updated as
    a moving average of it. Keys from past batches are kept in two
    fixed-size queues and serve as negatives.

    Args:
        vision_encoder: Module mapping ``images`` to ``(N, dim)`` embeddings.
        audio_encoder: Module mapping ``audio_inputs`` to ``(N, dim)``
            embeddings.
        dim: Embedding size; the queues are ``(dim, K)``.
        K: Number of keys kept per queue.
        m: Momentum coefficient for the key-encoder update.
        T: Temperature dividing the logits.
    """

    def __init__(self, vision_encoder, audio_encoder, dim=512, K=16384, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T

        self.encoder_q_vision = vision_encoder
        self.encoder_q_audio = audio_encoder

        # Key encoders start identical to the query encoders and are never
        # touched by the optimiser; only the momentum update changes them.
        self.encoder_k_vision = copy.deepcopy(vision_encoder)
        self.encoder_k_audio = copy.deepcopy(audio_encoder)
        self.encoder_k_vision.requires_grad_(False)
        self.encoder_k_audio.requires_grad_(False)

        # Queues start as random unit-norm columns.
        self.register_buffer(
            "queue_vision", F.normalize(torch.randn(dim, K), dim=0))
        self.register_buffer(
            "queue_audio", F.normalize(torch.randn(dim, K), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoders(self):
        """Set every key parameter to ``m * key + (1 - m) * query`` in place."""
        encoder_pairs = (
            (self.encoder_q_vision, self.encoder_k_vision),
            (self.encoder_q_audio, self.encoder_k_audio),
        )
        for enc_q, enc_k in encoder_pairs:
            for param_q, param_k in zip(enc_q.parameters(), enc_k.parameters()):
                param_k.mul_(self.m).add_(param_q.detach(), alpha=1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys_vision, keys_audio):
        """Write a batch of keys into both ring-buffer queues and advance the pointer.

        Keys are ``(batch, dim)``; the batch size is taken from
        ``keys_vision``. Writes wrap to the start of the queue when they run
        past its end. If the batch is larger than ``K`` only the newest ``K``
        keys are stored, which matches writing every key in order.
        """
        batch_size = keys_vision.shape[0]
        ptr = int(self.queue_ptr)

        if batch_size > self.K:
            overflow = batch_size - self.K
            keys_vision = keys_vision[overflow:]
            keys_audio = keys_audio[overflow:]
            ptr = (ptr + overflow) % self.K
            batch_size = self.K

        # Portion that fits before the end of the queue, then the wrapped rest.
        head = min(batch_size, self.K - ptr)
        self.queue_vision[:, ptr:ptr + head] = keys_vision[:head].T
        self.queue_audio[:, ptr:ptr + head] = keys_audio[:head].T

        tail = batch_size - head
        if tail:
            self.queue_vision[:, :tail] = keys_vision[head:].T
            self.queue_audio[:, :tail] = keys_audio[head:].T

        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def _contrastive_loss(self, queries, positive_keys, negative_queue):
        """Cross-entropy over ``[positive | queue]`` logits; the positive is class 0."""
        l_pos = torch.einsum('nc,nc->n', queries, positive_keys).unsqueeze(-1)
        # Snapshot the queue: it is overwritten in place after the loss is
        # built, which would otherwise invalidate the tensor saved for backward.
        l_neg = torch.einsum(
            'nc,ck->nk', queries, negative_queue.clone().detach())
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        labels = torch.zeros(
            logits.shape[0], dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, labels)

    def forward(self, images, audio_inputs):
        """Return the summed vision->audio and audio->vision contrastive losses.

        In training mode the key encoders are momentum-updated first and the
        new keys are enqueued afterwards. In eval mode both steps are
        skipped, so computing a validation loss does not change the key
        encoders or the queues.
        """
        if self.training:
            self._momentum_update_key_encoders()

        q_vision = self.encoder_q_vision(images)
        q_audio = self.encoder_q_audio(audio_inputs)
        with torch.no_grad():
            k_vision = self.encoder_k_vision(images)
            k_audio = self.encoder_k_audio(audio_inputs)

        # Each query is matched against the other modality's key and queue.
        loss_V2A = self._contrastive_loss(q_vision, k_audio, self.queue_audio)
        loss_A2V = self._contrastive_loss(q_audio, k_vision, self.queue_vision)

        if self.training:
            self._dequeue_and_enqueue(k_vision, k_audio)
        return loss_V2A + loss_A2V
|
|