Spaces:
Sleeping
Sleeping
| import os | |
| # print("CUDA_VISIBLE_DEVICES:", os.environ["CUDA_VISIBLE_DEVICES"]) | |
| # import torch | |
| # print("CUDA device count:", torch.cuda.device_count()) | |
| # print("CUDA current device:", torch.cuda.current_device()) | |
| # print("CUDA device name:", torch.cuda.get_device_name(torch.cuda.current_device())) | |
| # os.environ['CUDA_VISIBLE_DEVICES']="2,3" | |
| from torch.cuda import is_available as cuda_available, is_bf16_supported | |
| from torch.backends.mps import is_available as mps_available | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import yaml | |
| import json | |
| import pickle | |
| import os | |
| import random | |
| import deepspeed | |
| from tqdm import tqdm | |
| import torch | |
| from torch import Tensor, argmax | |
| from evaluate import load as load_metric | |
| import sys | |
| import argparse | |
| import jsonlines | |
| from data_loader import Text2MusicDataset | |
| from transformer_model import Transformer | |
| from torch.utils.data import DataLoader | |
# The configuration is read from a fixed path relative to the working
# directory; command-line parsing of a --config option is currently disabled.
config_file = "../configs/config.yaml"
with open(config_file, 'r') as f:
    configs = yaml.safe_load(f)

# Training hyper-parameters of the text2midi model
train_cfg = configs['training']['text2midi_model']
batch_size = train_cfg['batch_size']
learning_rate = train_cfg['learning_rate']
epochs = train_cfg['epochs']

# The pickled tokenizer vocabulary lives in the artifact folder
artifact_folder = configs['artifact_folder']
tokenizer_filepath = os.path.join(artifact_folder, "vocab.pkl")
# NOTE(review): pickle.load can execute arbitrary code; only point this at
# vocabulary files produced by this project.
with open(tokenizer_filepath, "rb") as f:
    tokenizer = pickle.load(f)

# The vocabulary size is one larger than the number of tokenizer entries
vocab_size = len(tokenizer) + 1
print("Vocab size: ", vocab_size)

# Read every record of the JSON-lines caption dataset into memory
caption_dataset_path = configs['raw_data']['caption_dataset_path']
with jsonlines.open(caption_dataset_path) as reader:
    captions = list(reader)
def collate_fn(batch):
    """
    Collate function for the DataLoader

    Each item of the batch is indexed as (input_ids, attention_mask, labels).
    For every field the leading dimension of each tensor is squeezed away and
    the resulting sequences are right-padded with 0 to the longest sequence
    of that field in the batch.

    :param batch: Sequence of items, each holding three tensors
    :return: Tuple (input_ids, attention_mask, labels) of batch-first tensors
    """
    def _pad_field(field_index):
        # Collect one field from every sample and pad to a common length
        sequences = [sample[field_index].squeeze(0) for sample in batch]
        return nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)

    return _pad_field(0), _pad_field(1), _pad_field(2)
# Build the training dataset and the loader that batches it with collate_fn
dataset = Text2MusicDataset(configs, captions, mode="train", shuffle=True)
data_length = len(dataset)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)

# Hyper-parameters of the encoder-decoder model, all taken from the config
model_cfg = configs['model']['text2midi_model']
# Model dimension (same as FLAN-T5 encoder output dimension)
d_model = model_cfg['decoder_d_model']
# Number of heads in the multi-head attention layers
nhead = model_cfg['decoder_num_heads']
# Number of decoder layers
num_layers = model_cfg['decoder_num_layers']
# Maximum length of the input sequence
max_len = model_cfg['decoder_max_sequence_length']
# Whether to use a mixture of experts, and how many experts it has
use_moe = model_cfg['use_moe']
num_experts = model_cfg['num_experts']
# Dimension of the feed-forward network
dim_feedforward = model_cfg['decoder_intermediate_size']
# Whether training runs through DeepSpeed
use_deepspeed = model_cfg['use_deepspeed']
# Pick the device the model is trained on.
if use_deepspeed:
    ds_config = configs['deepspeed_config']['deepspeed_config_path']
    import deepspeed
    from deepspeed.accelerator import get_accelerator
    # LOCAL_RANK must be present in the environment; a missing variable
    # raises KeyError here.
    local_rank = int(os.environ['LOCAL_RANK'])
    if local_rank > -1 and get_accelerator().is_available():
        device = torch.device(get_accelerator().device_name(), local_rank)
    else:
        device = torch.device("cpu")
    deepspeed.init_distributed(dist_backend='nccl')
    # Turn off the memory-efficient and flash scaled-dot-product kernels
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
elif torch.cuda.is_available():
    # Keep using GPU 3 where it exists, but fall back to GPU 0 on hosts with
    # fewer devices instead of failing later with an invalid device ordinal.
    preferred_gpu = 3
    gpu_index = preferred_gpu if torch.cuda.device_count() > preferred_gpu else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print(f"Device: {device}")
# Refresh the progress-bar loss every `print_every` training steps
print_every = 10

model = Transformer(vocab_size, d_model, nhead, max_len, num_layers,
                    dim_feedforward, use_moe, num_experts, device=device)

# Report the total and the trainable parameter counts
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params}")
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_trainable_params}")

if not use_deepspeed:
    # Use the learning rate from the config file rather than a hard-coded
    # value. float() is needed because PyYAML loads values written like
    # `1e-4` (no decimal point) as strings.
    optimizer = optim.Adam(model.parameters(), lr=float(learning_rate))

# NOTE(review): collate_fn pads labels with 0 and this loss does not ignore
# index 0, so padded positions contribute to the loss - confirm that is
# intended.
criterion = nn.CrossEntropyLoss()
torch.cuda.empty_cache()
def train_model(model, dataloader, criterion, num_epochs, optimizer=None, data_length=1000):
    """
    Train the model with teacher forcing for a number of epochs

    Reads the module-level settings use_deepspeed, ds_config, device,
    use_moe and print_every. With DeepSpeed enabled the model is wrapped by
    deepspeed.initialize and the returned engine performs the backward pass
    and the optimizer step; otherwise the given optimizer is used directly.

    :param model: The model to train; called as
        model(encoder_input, attention_mask, tgt_input)
    :param dataloader: Iterable with a length that yields
        (encoder_input, attention_mask, tgt) batches
    :param criterion: Loss applied to the flattened outputs and shifted targets
    :param num_epochs: Number of passes over the dataloader
    :param optimizer: Optimizer to use; required when DeepSpeed is disabled
    :param data_length: Kept for backward compatibility and no longer used;
        the progress bar length is taken from len(dataloader)
    :raises ValueError: If DeepSpeed is disabled and no optimizer is given
    """
    if use_deepspeed:
        # Hand DeepSpeed only the parameters that actually require gradients
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        model, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      optimizer=optimizer,
                                                      model_parameters=trainable_params,
                                                      config=ds_config)
    else:
        if optimizer is None:
            raise ValueError("An optimizer is required when DeepSpeed is disabled")
        model = model.to(device)
    model.train()

    # The true number of steps per epoch, including a final partial batch
    num_batches = len(dataloader)

    for epoch in range(num_epochs):
        total_loss = 0.0
        with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
            for step, batch in enumerate(dataloader):
                if use_deepspeed:
                    model.zero_grad()
                else:
                    optimizer.zero_grad()

                # Move the batch to the training device
                encoder_input, attention_mask, tgt = batch
                encoder_input = encoder_input.to(device)
                attention_mask = attention_mask.to(device)
                tgt = tgt.to(device)

                # Teacher forcing: the decoder sees every token but the last
                # and is scored on every token but the first.
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]

                if use_moe:
                    # The mixture-of-experts model also returns an auxiliary loss
                    outputs, aux_loss = model(encoder_input, attention_mask, tgt_input)
                else:
                    outputs = model(encoder_input, attention_mask, tgt_input)
                    aux_loss = 0

                loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_output.reshape(-1))
                loss += aux_loss

                if use_deepspeed:
                    model.backward(loss)
                    model.step()
                else:
                    loss.backward()
                    optimizer.step()

                # Read the scalar loss once per step
                step_loss = loss.item()
                total_loss += step_loss
                if step % print_every == 0:
                    pbar.set_postfix({"Loss": step_loss})
                pbar.update(1)

            # Guard against an empty dataloader when averaging
            epoch_loss = total_loss / max(num_batches, 1)
            pbar.set_postfix({"Loss": epoch_loss})
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")
# Train the model. The optimizer only exists at module level when DeepSpeed
# is disabled, so it is passed in that case only.
train_kwargs = {"num_epochs": epochs, "data_length": data_length}
if not use_deepspeed:
    train_kwargs["optimizer"] = optimizer
train_model(model, dataloader, criterion, **train_kwargs)

# Save the trained model. In a distributed run every process reaches this
# point, so only global rank 0 writes the file, to avoid concurrent writes
# to the same path.
# NOTE(review): if the DeepSpeed config partitions parameters across ranks,
# model.state_dict() on one rank may not hold the full weights - confirm.
distributed_run = torch.distributed.is_available() and torch.distributed.is_initialized()
if not distributed_run or torch.distributed.get_rank() == 0:
    torch.save(model.state_dict(), "transformer_decoder_remi_plus.pth")
    print("Model saved as transformer_decoder_remi_plus.pth")