# models.py
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import swin_v2_t, Swin_V2_T_Weights


def _make_projection_head(in_features, out_channels, hidden=512):
    """Build the MLP that maps backbone features into the shared embedding space.

    Linear -> LayerNorm -> GELU -> Linear. Both Swin encoders use the same
    layout, so their state_dict keys (``projection_head.0`` ... ``.3``) match.
    """
    return nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.LayerNorm(hidden),
        nn.GELU(),
        nn.Linear(hidden, out_channels),
    )


class SheetMusicSwin(nn.Module):
    """Swin-V2-T image encoder returning L2-normalised embeddings.

    Intended for sheet-music images (per the class name); input is presumably
    (N, 3, H, W) — TODO confirm against the data pipeline.

    Args:
        out_channels: size of the output embedding.
        weights: torchvision weights for the backbone. ``None`` builds a
            randomly initialised backbone without downloading anything.
    """

    def __init__(self, out_channels=512, weights=Swin_V2_T_Weights.DEFAULT):
        super().__init__()
        self.backbone = swin_v2_t(weights=weights)
        in_features = self.backbone.head.in_features
        # Drop the ImageNet classifier so the backbone returns its feature vector.
        self.backbone.head = nn.Identity()
        self.projection_head = _make_projection_head(in_features, out_channels)

    def forward(self, x):
        """Embed a batch; each output row has unit L2 norm."""
        features = self.backbone(x)
        projected = self.projection_head(features)
        return F.normalize(projected, p=2, dim=1)


class SpectrogramSwin(nn.Module):
    """Swin-V2-T encoder adapted to single-channel input (spectrograms, per the name).

    The RGB patch-embedding conv is replaced by a 1-channel conv whose filters
    are the pretrained filters averaged over the input-channel axis.

    Args:
        out_channels: size of the output embedding.
        weights: torchvision weights for the backbone. ``None`` builds a
            randomly initialised backbone without downloading anything.
    """

    def __init__(self, out_channels=512, weights=Swin_V2_T_Weights.DEFAULT):
        super().__init__()
        self.backbone = swin_v2_t(weights=weights)

        # Modify first conv layer to accept 1-channel (grayscale) spectrograms.
        original_conv = self.backbone.features[0][0]
        new_conv = nn.Conv2d(
            in_channels=1,
            out_channels=original_conv.out_channels,
            kernel_size=original_conv.kernel_size,
            stride=original_conv.stride,
            padding=original_conv.padding,
            # Mirror the source layer so no stray random bias is introduced.
            bias=original_conv.bias is not None,
        )
        with torch.no_grad():
            new_conv.weight.copy_(original_conv.weight.mean(dim=1, keepdim=True))
            if original_conv.bias is not None:
                new_conv.bias.copy_(original_conv.bias)
        self.backbone.features[0][0] = new_conv

        in_features = self.backbone.head.in_features
        self.backbone.head = nn.Identity()
        self.projection_head = _make_projection_head(in_features, out_channels)

    def forward(self, x):
        """Embed a batch of spectrograms; each output row has unit L2 norm.

        Input is presumably (N, 1, H, W) — TODO confirm. H and W are
        zero-padded on the bottom/right up to multiples of 32 first.
        """
        # NOTE(review): 32 presumably = patch size 4 x window size 8 of swin_v2_t.
        height, width = x.shape[-2:]
        pad_h = -height % 32
        pad_w = -width % 32
        x = F.pad(x, (0, pad_w, 0, pad_h))
        features = self.backbone(x)
        projected = self.projection_head(features)
        return F.normalize(projected, p=2, dim=1)


class VisionAudioMoCo(nn.Module):
    """Cross-modal MoCo: image queries vs. audio keys, and audio queries vs. image keys.

    Each modality has a query encoder (trained by backprop) and a key encoder.
    The key encoder is a deep copy, updated only by an exponential moving
    average of its query encoder. Each modality also keeps a ring-buffer queue
    of K past keys; these serve as negatives for the *other* modality's queries.

    Args:
        vision_encoder: module mapping images to (N, dim) embeddings; assumed
            to return unit-norm rows (true for the encoders in this file).
        audio_encoder: module mapping audio inputs to (N, dim) embeddings.
        dim: embedding width; must match both encoders' output.
        K: queue length, i.e. negatives per modality.
        m: momentum of the key-encoder moving average.
        T: softmax temperature applied to the logits.
    """

    def __init__(self, vision_encoder, audio_encoder, dim=512, K=16384, m=0.999, T=0.07):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T

        self.encoder_q_vision = vision_encoder
        self.encoder_q_audio = audio_encoder
        self.encoder_k_vision = copy.deepcopy(vision_encoder)
        self.encoder_k_audio = copy.deepcopy(audio_encoder)
        # Key encoders are never optimised directly; they only follow the EMA.
        self.encoder_k_vision.requires_grad_(False)
        self.encoder_k_audio.requires_grad_(False)

        # Queues store keys column-wise, shape (dim, K), initialised to random unit vectors.
        self.register_buffer("queue_vision", F.normalize(torch.randn(dim, K), dim=0))
        self.register_buffer("queue_audio", F.normalize(torch.randn(dim, K), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoders(self):
        """Apply theta_k <- m * theta_k + (1 - m) * theta_q to both key encoders."""
        encoder_pairs = (
            (self.encoder_q_vision, self.encoder_k_vision),
            (self.encoder_q_audio, self.encoder_k_audio),
        )
        for encoder_q, encoder_k in encoder_pairs:
            for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()):
                # In place, so the key parameters keep their existing storage.
                param_k.mul_(self.m).add_(param_q, alpha=1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys_vision, keys_audio):
        """Write a batch of keys into both queues, overwriting the oldest entries.

        Args:
            keys_vision: (N, dim) vision keys; keys_audio must have the same N.

        Raises:
            ValueError: if N exceeds the queue length K.
        """
        batch_size = keys_vision.shape[0]
        if batch_size > self.K:
            raise ValueError(
                f"batch size {batch_size} exceeds MoCo queue length K={self.K}"
            )
        ptr = int(self.queue_ptr)
        fit = min(batch_size, self.K - ptr)  # keys written before the end of the buffer
        wrap = batch_size - fit  # keys that wrap around to the front
        self.queue_vision[:, ptr:ptr + fit] = keys_vision[:fit].T
        self.queue_audio[:, ptr:ptr + fit] = keys_audio[:fit].T
        if wrap:
            self.queue_vision[:, :wrap] = keys_vision[fit:].T
            self.queue_audio[:, :wrap] = keys_audio[fit:].T
        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def _contrastive_loss(self, queries, positive_keys, negative_queue):
        """InfoNCE loss: row i's positive is positive_keys[i]; negatives are queue columns."""
        l_pos = torch.einsum("nc,nc->n", queries, positive_keys).unsqueeze(-1)
        l_neg = torch.einsum("nc,ck->nk", queries, negative_queue)
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T
        # The positive logit is always column 0.
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        return F.cross_entropy(logits, labels)

    def forward(self, images, audio_inputs):
        """Return the summed vision->audio and audio->vision InfoNCE losses.

        In training mode this also runs the key-encoder momentum update and
        enqueues the new keys. In eval mode both steps are skipped, so a
        validation pass leaves the queues and key encoders unchanged.
        """
        if self.training:
            self._momentum_update_key_encoders()

        q_vision = self.encoder_q_vision(images)
        q_audio = self.encoder_q_audio(audio_inputs)
        with torch.no_grad():
            k_vision = self.encoder_k_vision(images)
            k_audio = self.encoder_k_audio(audio_inputs)

        # Snapshot the queues. _dequeue_and_enqueue modifies them in place
        # below, which would invalidate the tensors autograd saved for backward.
        queue_vision = self.queue_vision.clone().detach()
        queue_audio = self.queue_audio.clone().detach()

        loss_v2a = self._contrastive_loss(q_vision, k_audio, queue_audio)
        loss_a2v = self._contrastive_loss(q_audio, k_vision, queue_vision)

        if self.training:
            self._dequeue_and_enqueue(k_vision, k_audio)
        return loss_v2a + loss_a2v