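# Training script for the text2midi Transformer decoder: caption/REMI pairs, Hugging Face Accelerate, W&B logging.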
import os
import torch.nn as nn
import torch.optim as optim
import yaml
import math
import time
from transformers import get_scheduler
import wandb
import pickle
import numpy as np
import json
import jsonlines
from tqdm import tqdm
import torch
from accelerate import DistributedDataParallelKwargs, Accelerator
from accelerate.logging import get_logger
from data_loader_remi import Text2MusicDataset
from transformer_model import Transformer
from torch.utils.data import DataLoader
import logging

logger = get_logger(__name__)
# Load config file
config_file = "../configs/config.yaml"
with open(config_file, 'r') as f:
    configs = yaml.safe_load(f)

batch_size = configs['training']['text2midi_model']['batch_size']
learning_rate = configs['training']['text2midi_model']['learning_rate']
epochs = configs['training']['text2midi_model']['epochs']

artifact_folder = configs['artifact_folder']
tokenizer_filepath = os.path.join(artifact_folder, "vocab_remi.pkl")
with open(tokenizer_filepath, "rb") as f:
    tokenizer = pickle.load(f)
vocab_size = len(tokenizer)
caption_dataset_path = configs['raw_data']['caption_dataset_path']
# Load the caption dataset
with jsonlines.open(caption_dataset_path) as reader:
    captions = list(reader)
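# Pad each batch's caption token ids, attention masks, and REMI target sequences to a common
# length (this assumes index 0 is the padding id for both the text and REMI vocabularies).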
def collate_fn(batch):
    input_ids = [item[0].squeeze(0) for item in batch]
    input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_mask = [item[1].squeeze(0) for item in batch]
    attention_mask = nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = [item[2].squeeze(0) for item in batch]
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
    return input_ids, attention_mask, labels
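# Decoder architecture and training hyperparameters, all taken from the same YAML config.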
d_model = configs['model']['text2midi_model']['decoder_d_model']
nhead = configs['model']['text2midi_model']['decoder_num_heads']
num_layers = configs['model']['text2midi_model']['decoder_num_layers']
max_len = configs['model']['text2midi_model']['decoder_max_sequence_length']
use_moe = configs['model']['text2midi_model']['use_moe']
num_experts = configs['model']['text2midi_model']['num_experts']
dim_feedforward = configs['model']['text2midi_model']['decoder_intermediate_size']
gradient_accumulation_steps = configs['training']['text2midi_model']['gradient_accumulation_steps']
use_scheduler = configs['training']['text2midi_model']['use_scheduler']
checkpointing_steps = configs['training']['text2midi_model']['checkpointing_steps']
lr_scheduler_type = configs['training']['text2midi_model']['lr_scheduler_type']
num_warmup_steps = configs['training']['text2midi_model']['num_warmup_steps']
max_train_steps = configs['training']['text2midi_model']['max_train_steps']
with_tracking = configs['training']['text2midi_model']['with_tracking']
report_to = configs['training']['text2midi_model']['report_to']
output_dir = configs['training']['text2midi_model']['output_dir']
per_device_train_batch_size = configs['training']['text2midi_model']['per_device_train_batch_size']
save_every = configs['training']['text2midi_model']['save_every']
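# Forward experiment-tracking options to Accelerate only when tracking is enabled.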
accelerator_log_kwargs = {}
if with_tracking:
    accelerator_log_kwargs["log_with"] = report_to
    # Newer Accelerate releases use `project_dir` instead of `logging_dir`; remove this argument if it raises an error
    accelerator_log_kwargs["logging_dir"] = output_dir
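# fp16 mixed precision; find_unused_parameters=True lets DDP tolerate parameters that are not
# used on every step (e.g. inactive experts when use_moe is enabled).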
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision='fp16',
    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
    **accelerator_log_kwargs,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
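# Create the output directory (with a timestamped fallback) and initialise W&B on the main process only.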
if accelerator.is_main_process:
    if output_dir is None or output_dir == "":
        output_dir = "saved/" + str(int(time.time()))
        if not os.path.exists("saved"):
            os.makedirs("saved")
        os.makedirs(output_dir, exist_ok=True)
    elif output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
    os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
    accelerator.project_configuration.automatic_checkpoint_naming = False
    wandb.login()
    wandb.init(project="Text-2-Midi", settings=wandb.Settings(init_timeout=120))
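# Build the caption-to-REMI dataset on the main process first (so any preprocessing/caching runs
# once), then wrap it in a DataLoader that pads batches with collate_fn.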
accelerator.wait_for_everyone()
device = accelerator.device

with accelerator.main_process_first():
    dataset = Text2MusicDataset(configs, captions, remi_tokenizer=tokenizer, mode="train", shuffle=True)
dataloader = DataLoader(dataset, batch_size=per_device_train_batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn, drop_last=True)
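# Decoder Transformer over the REMI vocabulary, conditioned on the tokenized caption.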
model = Transformer(vocab_size, d_model, nhead, max_len, num_layers, dim_feedforward, use_moe, num_experts, device=device)
# Resume from a previously saved checkpoint (path and epoch number are specific to this run)
model.load_state_dict(torch.load('/root/output_test_new/epoch_68/pytorch_model.bin', map_location=device))

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(model)
print(f"Total number of trainable parameters: {total_params}")
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # note: hard-coded; the learning_rate read from the config is not used here
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(dataloader) / gradient_accumulation_steps)
print("num_update_steps_per_epoch", num_update_steps_per_epoch)
print("max_train_steps", max_train_steps)
if max_train_steps == 'None':
    # The config stores max_train_steps as the string 'None' when it should be derived from the epoch count
    max_train_steps = epochs * num_update_steps_per_epoch
    print("max_train_steps", max_train_steps)
    overrode_max_train_steps = True
    num_warmup_steps = 20000  # hard-coded warmup override for this case
elif isinstance(max_train_steps, str):
    max_train_steps = int(max_train_steps)
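# Learning-rate schedule of the configured type, with warmup, over the full number of optimizer updates.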
lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_train_steps,
)

model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
dataloader = accelerator.prepare(dataloader)
if overrode_max_train_steps:
    max_train_steps = epochs * num_update_steps_per_epoch
# Recalculate the number of training epochs from the final step budget
epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
# checkpointing_steps = checkpointing_steps if checkpointing_steps.isdigit() else None
total_batch_size = per_device_train_batch_size * accelerator.num_processes * gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f" Num examples = {len(dataset)}")
logger.info(f" Num Epochs = {epochs}")
logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_train_steps}")

criterion = nn.CrossEntropyLoss()
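# Training loop: next-token prediction with teacher forcing, gradient accumulation via
# accelerator.accumulate, per-step W&B logging, and step/epoch/best-loss checkpointing.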
def train_model_accelerate(model, dataloader, criterion, num_epochs, max_train_steps, optimizer=None, out_dir=None, checkpointing_steps='epoch', with_tracking=False, save_every=5, device='cpu'):
    progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 68  # continue from the epoch of the checkpoint loaded above
    model = model.to(device)
    model.train()
    best_loss = np.inf
    for epoch in range(starting_epoch, num_epochs):
        total_loss = 0
        for step, batch in enumerate(dataloader):
            with accelerator.accumulate(model):
                encoder_input, attention_mask, tgt = batch
                encoder_input = encoder_input.to(device)
                attention_mask = attention_mask.to(device)
                tgt = tgt.to(device)
                # Teacher forcing: predict token t+1 from tokens up to t
                tgt_input = tgt[:, :-1]
                tgt_output = tgt[:, 1:]
                if use_moe:
                    # Mixture-of-experts models also return a load-balancing auxiliary loss
                    outputs, aux_loss = model(encoder_input, attention_mask, tgt_input)
                else:
                    outputs = model(encoder_input, attention_mask, tgt_input)
                    aux_loss = 0
                loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_output.reshape(-1))
                loss += aux_loss
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # sync_gradients is True only on steps where the optimizer actually updates
            if accelerator.sync_gradients:
                progress_bar.set_postfix({"Loss": loss.item()})
                progress_bar.update(1)
                completed_steps += 1
                if accelerator.is_main_process:
                    result = {}
                    result["epoch"] = epoch + 1
                    result["step"] = completed_steps
                    result["train_loss"] = round(total_loss.item() / (gradient_accumulation_steps * completed_steps), 4)
                    wandb.log(result)
            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps}"
                    if out_dir is not None:
                        output_dir = os.path.join(out_dir, output_dir)
                    accelerator.save_state(output_dir)
            if completed_steps >= max_train_steps:
                break
        # End-of-epoch logging and checkpointing
        if accelerator.is_main_process:
            result = {}
            result["epoch"] = epoch + 1
            result["step"] = completed_steps
            result["train_loss"] = round(total_loss.item() / len(dataloader), 4)
            result_string = "Epoch: {}, Loss Train: {}\n".format(epoch, result["train_loss"])
            accelerator.print(result_string)
            with open("{}/summary.jsonl".format(out_dir), "a") as f:
                f.write(json.dumps(result) + "\n")
            logger.info(result)
        if accelerator.is_main_process:
            if total_loss < best_loss:
                best_loss = total_loss
                save_checkpoint = True
            else:
                save_checkpoint = False
        accelerator.wait_for_everyone()
        if accelerator.is_main_process and checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(out_dir, "best"))
            if (epoch + 1) % save_every == 0:
                logger.info("Saving checkpoint at epoch {}".format(epoch + 1))
                accelerator.save_state("{}/{}".format(out_dir, "epoch_" + str(epoch + 1)))
        if accelerator.is_main_process and checkpointing_steps == "epoch":
            accelerator.save_state("{}/{}".format(out_dir, "epoch_" + str(epoch + 1)))
train_model_accelerate(model, dataloader, criterion, num_epochs=epochs, max_train_steps=max_train_steps,
                       optimizer=optimizer, out_dir=output_dir, checkpointing_steps=checkpointing_steps,
                       with_tracking=with_tracking, save_every=save_every, device=device)

# torch.save(model.state_dict(), "transformer_decoder_remi_plus.pth")
# print("Model saved as transformer_decoder_remi_plus.pth")