import itertools
import json
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from tools import create_key
from model.timbre_encoder_pretrain import get_timbre_encoder


class ProjectionLayer(nn.Module):
    """Linear projection followed by a GELU/linear residual block with dropout and layer norm."""

    def __init__(self, input_dim, output_dim, dropout):
        super(ProjectionLayer, self).__init__()
        self.projection = nn.Linear(input_dim, output_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected  # residual connection around the GELU/linear block
        x = self.layer_norm(x)
        return x


class ProjectionHead(nn.Module):
    """Stack of `ProjectionLayer` modules."""

    def __init__(self, embedding_dim, projection_dim, dropout, num_layers=2):
        super(ProjectionHead, self).__init__()
        self.layers = nn.ModuleList(
            [ProjectionLayer(embedding_dim if i == 0 else projection_dim, projection_dim, dropout)
             for i in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class multi_modal_model(nn.Module):
    """Multi-modal (spectrogram/text) model for contrastive learning."""

    def __init__(
            self,
            timbre_encoder,
            text_encoder,
            spectrogram_feature_dim,
            text_feature_dim,
            multi_modal_emb_dim,
            temperature,
            dropout,
            num_projection_layers=1,
            freeze_spectrogram_encoder=True,
            freeze_text_encoder=True,
    ):
        super().__init__()
        self.timbre_encoder = timbre_encoder
        self.text_encoder = text_encoder
        self.multi_modal_emb_dim = multi_modal_emb_dim
        self.text_projection = ProjectionHead(embedding_dim=text_feature_dim,
                                              projection_dim=self.multi_modal_emb_dim,
                                              dropout=dropout,
                                              num_layers=num_projection_layers)
        self.spectrogram_projection = ProjectionHead(embedding_dim=spectrogram_feature_dim,
                                                     projection_dim=self.multi_modal_emb_dim,
                                                     dropout=dropout,
                                                     num_layers=num_projection_layers)
        self.temperature = temperature

        # Optionally freeze the timbre (spectrogram) encoder
        for param in self.timbre_encoder.parameters():
            param.requires_grad = not freeze_spectrogram_encoder

        # Optionally freeze the text encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = not freeze_text_encoder

    def forward(self, spectrogram_batch, tokenized_text_batch):
        # Encode each modality
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        text_features = self.text_encoder.get_text_features(**tokenized_text_batch)

        # Project both feature sets into the shared embedding space (same dimension)
        spectrogram_embeddings = self.spectrogram_projection(spectrogram_features)
        text_embeddings = self.text_projection(text_features)

        # Symmetric contrastive loss with soft targets built from the
        # within-modality similarity matrices
        logits = (text_embeddings @ spectrogram_embeddings.T) / self.temperature
        images_similarity = spectrogram_embeddings @ spectrogram_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        contrastive_loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size,)
        contrastive_loss = contrastive_loss.mean()
        return contrastive_loss

    def get_text_features(self, input_ids, attention_mask):
        text_features = self.text_encoder.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(text_features)

    def get_timbre_features(self, spectrogram_batch):
        spectrogram_features, _, _, _, _ = self.timbre_encoder(spectrogram_batch)
        return self.spectrogram_projection(spectrogram_features)
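
# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the training code): the symmetric
# CLIP-style loss computed in `multi_modal_model.forward`, replayed on random
# embeddings. The batch size, embedding size, and temperature below are
# placeholder values chosen only for the demo, not trained hyperparameters.
# ---------------------------------------------------------------------------
def _contrastive_loss_demo(batch_size=4, emb_dim=256, temperature=1.0):
    """Recompute the loss from `multi_modal_model.forward` on random inputs."""
    text_emb = torch.randn(batch_size, emb_dim)
    spec_emb = torch.randn(batch_size, emb_dim)
    # Cross-modal logits: text rows vs. spectrogram columns
    logits = (text_emb @ spec_emb.T) / temperature
    # Soft targets: the average of the two within-modality similarity matrices,
    # so near-duplicate items in a batch are not forced to repel each other.
    targets = F.softmax((spec_emb @ spec_emb.T + text_emb @ text_emb.T) / 2 * temperature, dim=-1)
    texts_loss = (-targets * F.log_softmax(logits, dim=-1)).sum(1)
    images_loss = (-targets.T * F.log_softmax(logits.T, dim=-1)).sum(1)
    return ((texts_loss + images_loss) / 2.0).mean()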
def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()


def get_multi_modal_model(timbre_encoder, text_encoder, model_Config, load_pretrain=False, model_name=None,
                          device="cpu"):
    mmm = multi_modal_model(timbre_encoder, text_encoder, **model_Config)
    print(f"Model initialized, trainable parameters: {sum(p.numel() for p in mmm.parameters() if p.requires_grad)}")
    mmm.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_MMM.pth")
        checkpoint = torch.load(f'models/{model_name}_MMM.pth', map_location=device)
        mmm.load_state_dict(checkpoint['model_state_dict'])
        mmm.eval()
    return mmm


def train_epoch(text_tokenizer, model, train_loader, labels_mapping, optimizer, device):
    (data, attributes) = next(iter(train_loader))
    keys = [create_key(attribute) for attribute in attributes]
    # Redraw the batch until every sample has a unique key, so that no two
    # items in the batch share the same caption set (which would make the
    # contrastive targets ambiguous).
    while len(set(keys)) != len(keys):
        (data, attributes) = next(iter(train_loader))
        keys = [create_key(attribute) for attribute in attributes]
    data = data.to(device)

    # Pick one caption at random from the label set of each sample
    texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
    selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)

    loss = model(data, tokenized_text)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def valid_epoch(text_tokenizer, model, valid_loader, labels_mapping, device):
    (data, attributes) = next(iter(valid_loader))
    keys = [create_key(attribute) for attribute in attributes]
    while len(set(keys)) != len(keys):
        (data, attributes) = next(iter(valid_loader))
        keys = [create_key(attribute) for attribute in attributes]
    data = data.to(device)

    texts = [labels_mapping[create_key(attribute)] for attribute in attributes]
    selected_texts = [l[random.randint(0, len(l) - 1)] for l in texts]
    tokenized_text = text_tokenizer(selected_texts, padding=True, return_tensors="pt").to(device)

    # Validation does not need gradients
    with torch.no_grad():
        loss = model(data, tokenized_text)
    return loss.item()


def train_multi_modal_model(device, training_dataloader, labels_mapping, text_tokenizer, text_encoder,
                            timbre_encoder_Config, MMM_config, MMM_training_config, mmm_name, BATCH_SIZE,
                            max_iter=0, load_pretrain=True, timbre_encoder_name=None, init_loss=None,
                            save_steps=2000):
    def save_model_hyperparameter(model_name, MMM_config, MMM_training_config, BATCH_SIZE, model_size,
                                  current_iter, current_loss):
        model_hyperparameter = MMM_config
        model_hyperparameter.update(MMM_training_config)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_MMM.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    timbreEncoder = get_timbre_encoder(timbre_encoder_Config, load_pretrain=True, model_name=timbre_encoder_name,
                                       device=device)
    mmm = multi_modal_model(timbreEncoder, text_encoder, **MMM_config).to(device)
    print(f"spectrogram_encoder parameter: {sum(p.numel() for p in mmm.timbre_encoder.parameters())}")
    print(f"text_encoder parameter: {sum(p.numel() for p in mmm.text_encoder.parameters())}")
    print(f"spectrogram_projection parameter: {sum(p.numel() for p in mmm.spectrogram_projection.parameters())}")
mmm.spectrogram_projection.parameters())}") print(f"text_projection parameter: {sum(p.numel() for p in mmm.text_projection.parameters())}") total_parameters = sum(p.numel() for p in mmm.parameters()) trainable_parameters = sum(p.numel() for p in mmm.parameters() if p.requires_grad) print(f"Trainable/Total parameter: {trainable_parameters}/{total_parameters}") params = [ {"params": itertools.chain( mmm.spectrogram_projection.parameters(), mmm.text_projection.parameters(), ), "lr": MMM_training_config["head_lr"], "weight_decay": MMM_training_config["head_weight_decay"]}, ] if not MMM_config["freeze_text_encoder"]: params.append({"params": mmm.text_encoder.parameters(), "lr": MMM_training_config["text_encoder_lr"], "weight_decay": MMM_training_config["text_encoder_weight_decay"]}) if not MMM_config["freeze_spectrogram_encoder"]: params.append({"params": mmm.timbre_encoder.parameters(), "lr": MMM_training_config["spectrogram_encoder_lr"], "weight_decay": MMM_training_config["timbre_encoder_weight_decay"]}) optimizer = torch.optim.AdamW(params, weight_decay=0.) if load_pretrain: print(f"Loading weights from models/{mmm_name}_MMM.pt") checkpoint = torch.load(f'models/{mmm_name}_MMM.pth') mmm.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) else: print("Model initialized.") if max_iter == 0: print("Return model directly.") return mmm, optimizer if init_loss is None: previous_lowest_loss = valid_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, device) else: previous_lowest_loss = init_loss print(f"Initial total loss: {previous_lowest_loss}") train_loss_list = [] for i in range(max_iter): mmm.train() train_loss = train_epoch(text_tokenizer, mmm, training_dataloader, labels_mapping, optimizer, device) train_loss_list.append(train_loss) step = int( optimizer.state_dict()['state'][list(optimizer.state_dict()['state'].keys())[0]]['step'].cpu().numpy()) if (i + 1) % 100 == 0: print('%d step' % (step)) if (i + 1) % save_steps == 0: current_loss = np.mean(train_loss_list[-save_steps:]) print(f"train_total_loss: {current_loss}") if current_loss < previous_lowest_loss: previous_lowest_loss = current_loss torch.save({ 'model_state_dict': mmm.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, f'models/{mmm_name}_MMM.pth') save_model_hyperparameter(mmm_name, MMM_config, MMM_training_config, BATCH_SIZE, total_parameters, step, current_loss) return mmm, optimizer