# kmaes's picture
# Upload 27 files
# b148e11 verified
import os
# print("CUDA_VISIBLE_DEVICES:", os.environ["CUDA_VISIBLE_DEVICES"])
# import torch
# print("CUDA device count:", torch.cuda.device_count())
# print("CUDA current device:", torch.cuda.current_device())
# print("CUDA device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
# os.environ['CUDA_VISIBLE_DEVICES']="2,3"
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.backends.mps import is_available as mps_available
import torch.nn as nn
import torch.optim as optim
import yaml
import json
import pickle
import os
import random
import deepspeed
from tqdm import tqdm
import torch
from torch import Tensor, argmax
from evaluate import load as load_metric
import sys
import argparse
import jsonlines
from data_loader import Text2MusicDataset
from transformer_model import Transformer
from torch.utils.data import DataLoader
# Parse command line arguments
# parser = argparse.ArgumentParser()
# parser.add_argument("--config", type=str, default=os.path.normpath("configs/config.yaml"),
# help="Path to the config file")
# parser = deepspeed.add_config_arguments(parser)
# args = parser.parse_args()
config_file = "../configs/config.yaml"

# Read the YAML configuration (swap in args.config if CLI parsing is re-enabled).
with open(config_file, 'r') as config_stream:
    configs = yaml.safe_load(config_stream)

# Training hyper-parameters for the text2midi model.
_train_cfg = configs['training']['text2midi_model']
batch_size = _train_cfg['batch_size']
learning_rate = _train_cfg['learning_rate']
epochs = _train_cfg['epochs']

# Folder that holds generated artifacts such as the vocabulary pickle.
artifact_folder = configs['artifact_folder']

# Load the tokenizer vocabulary from a pickle file in the artifact folder.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted artifacts.
tokenizer_filepath = os.path.join(artifact_folder, "vocab.pkl")
with open(tokenizer_filepath, "rb") as vocab_stream:
    tokenizer = pickle.load(vocab_stream)

# Vocabulary size is the tokenizer length plus one extra slot
# (presumably a reserved index such as padding — TODO confirm).
vocab_size = len(tokenizer) + 1
print("Vocab size: ", vocab_size)

# Read every record from the JSON-lines caption dataset into a list.
caption_dataset_path = configs['raw_data']['caption_dataset_path']
with jsonlines.open(caption_dataset_path) as caption_reader:
    captions = [record for record in caption_reader]
def collate_fn(batch):
    """
    Collate function for the DataLoader.

    Every item in ``batch`` is indexed as ``item[0]`` (input ids), ``item[1]``
    (attention mask) and ``item[2]`` (labels). Each of those tensors has its
    leading dimension squeezed, and is then right-padded with zeros to the
    longest sequence in the batch.

    :param batch: Sequence of 3-element items of tensors
    :return: Tuple ``(input_ids, attention_mask, labels)`` of padded tensors
             stacked along a new leading batch dimension (``batch_first=True``)
    """
    def _pad_field(field_index):
        # Drop each sample's leading dim, then zero-pad everything to a common length.
        sequences = [sample[field_index].squeeze(0) for sample in batch]
        return nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)

    # Fields are produced in order: input ids, attention mask, labels.
    return tuple(_pad_field(field_index) for field_index in range(3))
# Build the training split and wrap it in a shuffling DataLoader that pads
# each batch with collate_fn.
dataset = Text2MusicDataset(configs, captions, mode="train", shuffle=True)
data_length = len(dataset)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
)
# Create the encoder-decoder model.
# All model hyper-parameters come from configs['model']['text2midi_model'].
_model_cfg = configs['model']['text2midi_model']
d_model = _model_cfg['decoder_d_model']  # Model dimension (same as FLAN-T5 encoder output dimension)
nhead = _model_cfg['decoder_num_heads']  # Number of heads in the multi-head attention layers
num_layers = _model_cfg['decoder_num_layers']  # Number of decoder layers
max_len = _model_cfg['decoder_max_sequence_length']  # Maximum length of the input sequence
use_moe = _model_cfg['use_moe']  # Use mixture of experts
num_experts = _model_cfg['num_experts']  # Number of experts in the mixture of experts
dim_feedforward = _model_cfg['decoder_intermediate_size']  # Dimension of the feed-forward network
use_deepspeed = _model_cfg['use_deepspeed']  # Train through a DeepSpeed engine

if use_deepspeed:
    ds_config = configs['deepspeed_config']['deepspeed_config_path']
    from deepspeed.accelerator import get_accelerator
    # Default to -1 when LOCAL_RANK is not exported, so the rank check below is
    # meaningful instead of the lookup raising KeyError.
    local_rank = int(os.environ.get('LOCAL_RANK', -1))
    if local_rank > -1 and get_accelerator().is_available():
        device = torch.device(get_accelerator().device_name(), local_rank)
    else:
        device = torch.device("cpu")
    deepspeed.init_distributed(dist_backend='nccl')
    # Disable the memory-efficient and flash scaled-dot-product attention kernels.
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
else:
    # NOTE(review): GPU index 3 is hard-coded; adjust for the local machine.
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

print_every = 10
model = Transformer(vocab_size, d_model, nhead, max_len, num_layers, dim_feedforward, use_moe, num_experts, device=device)

# Print number of parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_params}")
# Print number of trainable parameters
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_trainable_params}")

if not use_deepspeed:
    # Use the learning rate from the training config instead of a hard-coded value.
    # (Under DeepSpeed the optimizer presumably comes from the DeepSpeed config.)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# NOTE(review): collate_fn pads labels with 0 and this loss does not ignore them;
# consider ignore_index=0 if 0 is the padding id — confirm against the tokenizer.
criterion = nn.CrossEntropyLoss()
torch.cuda.empty_cache()
def train_model(model, dataloader, criterion, num_epochs, optimizer=None, data_length=1000):
    """
    Train ``model`` on ``dataloader`` for ``num_epochs`` epochs.

    Reads the module-level settings ``use_deepspeed``, ``ds_config``, ``device``,
    ``use_moe`` and ``print_every``.

    When ``use_deepspeed`` is set, the model is wrapped with ``deepspeed.initialize``
    using only its trainable parameters, and ``optimizer`` may be None. Otherwise
    the model is moved to ``device`` and ``optimizer`` must be a torch optimizer.

    Each batch is ``(encoder_input, attention_mask, tgt)``. The decoder receives
    ``tgt[:, :-1]`` and is scored against ``tgt[:, 1:]``. With ``use_moe`` the
    model returns ``(outputs, aux_loss)``, and the auxiliary loss is added to the
    criterion loss.

    :param model: The model to train
    :param dataloader: Sized iterable of collated batches
    :param criterion: Loss applied to flattened logits and flattened targets
    :param num_epochs: Number of passes over ``dataloader``
    :param optimizer: Optimizer for the non-DeepSpeed path (optional under DeepSpeed)
    :param data_length: Kept for backward compatibility. The progress-bar total is
        now ``len(dataloader)``, the actual number of steps per epoch.
    :return: The trained model (the DeepSpeed engine when ``use_deepspeed`` is set)
    """
    if use_deepspeed:
        # Only trainable parameters need optimizer state.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        model, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      optimizer=optimizer,
                                                      model_parameters=trainable_params,
                                                      config=ds_config)
    else:
        model = model.to(device)
    model.train()

    steps_per_epoch = len(dataloader)
    for epoch in range(num_epochs):
        total_loss = 0.0
        with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
            for step, batch in enumerate(dataloader):
                if use_deepspeed:
                    model.zero_grad()
                else:
                    optimizer.zero_grad()

                encoder_input, attention_mask, tgt = batch
                encoder_input = encoder_input.to(device)
                attention_mask = attention_mask.to(device)
                tgt = tgt.to(device)
                # Shift targets by one: the decoder sees tokens [0, T-1) and predicts [1, T).
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]

                if use_moe:
                    outputs, aux_loss = model(encoder_input, attention_mask, tgt_input)
                else:
                    outputs = model(encoder_input, attention_mask, tgt_input)
                    aux_loss = 0
                # reshape (not view) tolerates non-contiguous logits.
                loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1))
                loss = loss + aux_loss

                if use_deepspeed:
                    model.backward(loss)
                    model.step()
                else:
                    loss.backward()
                    optimizer.step()

                loss_value = loss.item()
                total_loss += loss_value
                if step % print_every == 0:
                    pbar.set_postfix({"Loss": loss_value})
                pbar.update(1)

            # Guard against an empty dataloader; no extra pbar.update past the total.
            avg_loss = total_loss / max(steps_per_epoch, 1)
            pbar.set_postfix({"Loss": avg_loss})
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss}")
    return model
# Train the model. Pass the real dataset size on both paths; otherwise
# train_model falls back to its default data_length of 1000.
# (`optimizer` only exists when DeepSpeed is off, and the conditional
# expression never evaluates it otherwise.)
train_model(model, dataloader, criterion, num_epochs=epochs,
            optimizer=None if use_deepspeed else optimizer,
            data_length=data_length)

# Save the trained model. Under DeepSpeed every launched process runs this
# script, so only rank 0 writes, avoiding concurrent writes to one file.
# NOTE(review): with ZeRO stage 3 the raw module's state_dict may not hold full
# weights; use the engine's checkpoint APIs in that case — confirm the config.
_is_main_process = (not use_deepspeed
                    or not torch.distributed.is_initialized()
                    or torch.distributed.get_rank() == 0)
if _is_main_process:
    torch.save(model.state_dict(), "transformer_decoder_remi_plus.pth")
    print("Model saved as transformer_decoder_remi_plus.pth")