# Source: ccss17 / note-duration experiment — "readme" upload, commit 86a3f14
# (the three lines above were web-page residue from the original file listing)
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, DatasetDict
from tqdm.auto import tqdm
import copy
# Hugging Face dataset repository holding the note-duration data.
REPO_ID = "ccss17/note-duration-dataset"
SEQUENCE_LENGTH = 128  # fixed chunk length fed to the GRU
PAD_TOKEN = 0  # value used to right-pad chunks shorter than SEQUENCE_LENGTH
BATCH_SIZE = 64
# Keyword arguments for ToyDurationPredictor (unpacked via **MODEL_CONFIG).
MODEL_CONFIG = {
    # NOTE(review): must equal the number of unique singer_ids in the
    # dataset (nn.Embedding size) — confirm against the loaded data.
    "num_singers": 18,
    "singer_embedding_dim": 32,
    "hidden_size": 256,
    "num_layers": 3,
    "dropout": 0.4,
}
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100
EARLY_STOPPING_PATIENCE = 5  # epochs without val-loss improvement before stopping
# NOTE(review): GPU_ID looks unused — train_model/evaluate_and_compare take
# gpu_id parameters and __main__ passes literals; confirm before removing.
GPU_ID = 0
class ToyDurationPredictor(nn.Module):
    """Bidirectional GRU that regresses one (normalized) duration per step.

    A learned per-singer embedding is broadcast along the time axis and
    concatenated with the scalar duration input at every timestep.
    """

    def __init__(
        self,
        hidden_size,
        num_layers,
        dropout,
        num_singers,
        singer_embedding_dim,
    ):
        super().__init__()
        self.singer_embedding = nn.Embedding(num_singers, singer_embedding_dim)
        # Each timestep sees [scalar duration | singer embedding].
        self.rnn = nn.GRU(
            input_size=1 + singer_embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            dropout=dropout,
            batch_first=True,
        )
        # Bidirectional GRU concatenates forward/backward states -> 2x width.
        self.fc = nn.Linear(hidden_size * 2, 1)

    def forward(self, x_seq, x_sid):
        """Map (batch, seq) duration indices + (batch,) singer ids to
        per-timestep predictions of shape (batch, seq)."""
        seq_len = x_seq.size(1)
        duration_feats = x_seq.float().unsqueeze(-1)
        singer_feats = self.singer_embedding(x_sid)[:, None, :].expand(
            -1, seq_len, -1
        )
        rnn_out, _ = self.rnn(
            torch.cat((duration_feats, singer_feats), dim=-1)
        )
        return self.fc(rnn_out).squeeze(-1)
class DurationDataset(Dataset):
    """Wraps a chunked dataset and yields ({input_seq, singer_id}, label).

    model_type "A" feeds the raw durations as the input sequence; any other
    value feeds the quantized durations. The label is always the raw
    duration sequence, clamped to [0, 1000] and z-normalized with the
    provided training-set statistics.
    """

    def __init__(self, processed_hf_dataset, model_type, mean, std):
        self.processed_dataset = processed_hf_dataset
        self.model_type = model_type
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.processed_dataset)

    def __getitem__(self, idx):
        record = self.processed_dataset[idx]
        raw = torch.tensor(record["durations"], dtype=torch.float32)
        # Clamp outliers, then z-normalize (epsilon guards against std == 0).
        label = (torch.clamp(raw, 0, 1000) - self.mean) / (self.std + 1e-8)
        key = "durations" if self.model_type == "A" else "quantized_durations"
        input_seq = torch.tensor(record[key], dtype=torch.long)
        singer_idx = torch.tensor(record["singer_idx"], dtype=torch.long)
        return {"input_seq": input_seq, "singer_id": singer_idx}, label
def _get_base_dataset_and_metadata():
    """Load the raw dataset, index singers, and build train/val/test splits.

    Returns a DatasetDict with 80/10/10 train/validation/test splits plus
    the singer_id -> contiguous-index map used for the embedding table.
    """
    raw = load_dataset(REPO_ID, split="train", trust_remote_code=True)
    # Deterministic singer indexing: sort ids, then enumerate.
    singer_id_map = {
        sid: idx for idx, sid in enumerate(sorted(raw.unique("singer_id")))
    }
    raw = raw.map(
        lambda ex: {"singer_idx": singer_id_map[ex["singer_id"]]}
    )
    # 80% train; the remaining 20% is halved into validation and test.
    primary = raw.train_test_split(test_size=0.2, seed=42)
    secondary = primary["test"].train_test_split(
        test_size=0.5, seed=42
    )
    split_dataset = DatasetDict(
        {
            "train": primary["train"],
            "validation": secondary["train"],
            "test": secondary["test"],
        }
    )
    return split_dataset, singer_id_map
def _chunk_and_normalize_data(split_dataset):
    """Chunk every split into fixed-length sequences and compute train stats.

    Returns (processed_train, processed_val, processed_test, mean, std),
    where mean/std are computed from the *training* durations only (after
    clamping to [0, 1000]), so validation/test statistics never leak in.
    """
    # Flatten every training song's duration list into one value stream.
    all_training_durations = [
        item
        for sublist in split_dataset["train"]["durations"]
        for item in sublist
    ]
    # Clamp to the same [0, 1000] range used when building labels.
    durations_tensor = torch.clamp(
        torch.tensor(all_training_durations, dtype=torch.float32), 0, 1000
    )
    mean, std = durations_tensor.mean(), durations_tensor.std()
    print(f"Calculated training set stats: Mean={mean:.2f}, Std={std:.2f}")

    def chunk_examples_batched(examples):
        # Split each song into SEQUENCE_LENGTH-sized windows, right-padding
        # the last window with PAD_TOKEN; the singer index is repeated once
        # per emitted chunk so every chunk stays self-contained.
        chunked_output = {
            "durations": [],
            "quantized_durations": [],
            "singer_idx": [],
        }
        for i in range(len(examples["durations"])):
            durs, q_durs, s_idx = (
                examples["durations"][i],
                examples["quantized_durations"][i],
                examples["singer_idx"][i],
            )
            for j in range(0, len(durs), SEQUENCE_LENGTH):
                d_chunk = durs[j : j + SEQUENCE_LENGTH]
                q_chunk = q_durs[j : j + SEQUENCE_LENGTH]
                if len(d_chunk) < SEQUENCE_LENGTH:
                    padding = [PAD_TOKEN] * (SEQUENCE_LENGTH - len(d_chunk))
                    d_chunk.extend(padding)
                    q_chunk.extend(padding)
                chunked_output["durations"].append(d_chunk)
                chunked_output["quantized_durations"].append(q_chunk)
                chunked_output["singer_idx"].append(s_idx)
        return chunked_output

    # batched=True hands the mapper whole batches of examples; the original
    # columns are dropped so only the chunked columns remain.
    processed_train = split_dataset["train"].map(
        chunk_examples_batched,
        batched=True,
        remove_columns=split_dataset["train"].column_names,
    )
    processed_val = split_dataset["validation"].map(
        chunk_examples_batched,
        batched=True,
        remove_columns=split_dataset["validation"].column_names,
    )
    processed_test = split_dataset["test"].map(
        chunk_examples_batched,
        batched=True,
        remove_columns=split_dataset["test"].column_names,
    )
    return processed_train, processed_val, processed_test, mean, std
def prepare_train_and_val_loaders(model_type):
    """Prepares and returns only the train and validation dataloaders."""
    print(f"--- Preparing train/val data for Model {model_type} ---")
    split_dataset, _ = _get_base_dataset_and_metadata()
    processed_train, processed_val, _, mean, std = _chunk_and_normalize_data(
        split_dataset
    )
    loaders = []
    # Only the training loader shuffles; validation keeps dataset order.
    for processed, shuffle in ((processed_train, True), (processed_val, False)):
        dataset = DurationDataset(processed, model_type, mean, std)
        loaders.append(
            DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=shuffle,
                num_workers=2,
            )
        )
    train_loader, val_loader = loaders
    return train_loader, val_loader
def prepare_test_loader(model_type="B"):
    """Prepare the test dataloader and the normalization statistics.

    Args:
        model_type: input representation for the loader — "A" yields raw
            durations, "B" yields quantized durations. Defaults to "B",
            preserving the previously hard-coded behavior.

    Returns:
        (test_loader, mean, std), with mean/std taken from the train split.

    NOTE(review): evaluate_and_compare feeds one shared loader to both
    models, so Model A (trained on raw durations) is evaluated on quantized
    input — confirm that is intentional; otherwise build a loader per model
    using this new parameter.
    """
    print("--- Preparing test data ---")
    split_dataset, _ = _get_base_dataset_and_metadata()
    _, _, processed_test, mean, std = _chunk_and_normalize_data(split_dataset)
    test_ds = DurationDataset(processed_test, model_type, mean, std)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, num_workers=2)
    return test_loader, mean, std
def train_model(model_type="B", gpu_id=0):
    """Train a ToyDurationPredictor with early stopping and save the best
    checkpoint.

    Args:
        model_type: "A" trains on raw durations, "B" on quantized ones.
        gpu_id: CUDA device index, or None to force CPU.

    Returns:
        Path of the saved checkpoint file (always written on return).
    """
    print(
        f"\n{'=' * 20} TRAINING MODEL {model_type} ON GPU {gpu_id} {'=' * 20}"
    )
    device = torch.device(
        f"cuda:{gpu_id}"
        if torch.cuda.is_available() and gpu_id is not None
        else "cpu"
    )
    print(f"Using device: {device}")
    train_loader, val_loader = prepare_train_and_val_loaders(model_type)
    config = MODEL_CONFIG
    model = ToyDurationPredictor(**config).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.MSELoss()
    best_val_loss = float("inf")
    patience_counter = 0
    best_model_state = None
    for epoch in range(NUM_EPOCHS):
        # --- training pass ---
        model.train()
        progress_bar = tqdm(
            train_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS}"
        )
        for batch in progress_bar:
            inputs, labels = batch
            input_seq = inputs["input_seq"].to(device)
            singer_id = inputs["singer_id"].to(device)
            labels = labels.to(device)
            predictions = model(input_seq, singer_id)
            loss = loss_fn(predictions, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")
        # --- validation pass ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                inputs, labels = batch
                input_seq = inputs["input_seq"].to(device)
                singer_id = inputs["singer_id"].to(device)
                labels = labels.to(device)
                predictions = model(input_seq, singer_id)
                val_loss += loss_fn(predictions, labels).item()
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch + 1} - Validation Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            # Deep-copy so later training steps don't mutate the snapshot.
            best_model_state = copy.deepcopy(model.state_dict())
            # FIX: was an f-string with no placeholders.
            print("New best model found! Saving state.")
        else:
            patience_counter += 1
            print(
                f"No improvement in validation loss for {patience_counter} epoch(s)."
            )
            if patience_counter >= EARLY_STOPPING_PATIENCE:
                print(
                    f"Stopping early after {patience_counter} epochs with no improvement."
                )
                break
    model_save_path = f"./model_{model_type}.pth"
    # FIX: previously the path was returned without any file being written
    # when no best state was recorded (e.g. zero epochs); fall back to the
    # final weights so the returned path always exists. Also use an explicit
    # None check rather than truthiness on a state dict.
    if best_model_state is None:
        best_model_state = model.state_dict()
    torch.save(best_model_state, model_save_path)
    print(f"Best model for '{model_type}' saved to {model_save_path}")
    return model_save_path
def evaluate_and_compare(model_A_path, model_B_path, gpu_id=0):
    """Evaluate two trained models on the shared test split and report MAE.

    Args:
        model_A_path: checkpoint path (str) or an already-built nn.Module.
        model_B_path: checkpoint path (str) or an already-built nn.Module.
        gpu_id: CUDA device index, or None to force CPU.

    Raises:
        NotImplementedError: if the two arguments are not both paths or
            both modules.
    """
    print(f"\n{'=' * 20} FINAL MODEL COMPARISON {'=' * 20}")
    device = torch.device(
        f"cuda:{gpu_id}"
        if torch.cuda.is_available() and gpu_id is not None
        else "cpu"
    )
    print(f"Using device for evaluation: {device}")
    if isinstance(model_A_path, str) and isinstance(model_B_path, str):
        config = MODEL_CONFIG
        model_A = ToyDurationPredictor(**config).to(device)
        model_A.load_state_dict(torch.load(model_A_path, map_location=device))
        model_B = ToyDurationPredictor(**config).to(device)
        model_B.load_state_dict(torch.load(model_B_path, map_location=device))
    elif isinstance(model_A_path, nn.Module) and isinstance(
        model_B_path, nn.Module
    ):
        model_A = model_A_path
        model_B = model_B_path
    else:
        raise NotImplementedError
    # FIX: eval() is now applied on both branches — previously modules passed
    # in directly were evaluated with dropout still active.
    model_A.eval()
    model_B.eval()
    test_loader, mean, std = prepare_test_loader()
    total_mae_A, total_mae_B, total_samples = 0, 0, 0
    print("Evaluating models on the test set...")
    for batch in tqdm(test_loader, desc="Final Evaluation"):
        inputs, labels_norm = batch
        input_seq, singer_id, labels_norm = (
            inputs["input_seq"].to(device),
            inputs["singer_id"].to(device),
            labels_norm.to(device),
        )
        with torch.no_grad():
            preds_A_norm = model_A(input_seq, singer_id)
            preds_B_norm = model_B(input_seq, singer_id)
        # De-normalize predictions and labels back to tick units.
        preds_A = (preds_A_norm * std) + mean
        preds_B = (preds_B_norm * std) + mean
        original_labels = (labels_norm * std) + mean
        # NOTE(review): the MAE sums over every position, including PAD_TOKEN
        # padding at sequence ends — confirm whether padded steps should be
        # masked out of the metric.
        total_mae_A += nn.functional.l1_loss(
            preds_A, original_labels, reduction="sum"
        ).item()
        total_mae_B += nn.functional.l1_loss(
            preds_B, original_labels, reduction="sum"
        ).item()
        total_samples += original_labels.numel()
    avg_mae_A = total_mae_A / total_samples
    avg_mae_B = total_mae_B / total_samples
    print(f"\nFinal Test MAE for Model A (Control): {avg_mae_A:.4f} ticks")
    print(f"Final Test MAE for Model B (Your Method): {avg_mae_B:.4f} ticks")
if __name__ == "__main__":
    # Train the control model (A, raw durations) and the quantized-input
    # model (B), then compare both on the shared test split.
    # NOTE(review): gpu_id=1 assumes a second CUDA device exists — the device
    # string "cuda:1" is built whenever CUDA is available, so on a single-GPU
    # machine this would fail at .to(device); confirm the target hardware.
    model_a_path = train_model(model_type="A", gpu_id=0)
    model_b_path = train_model(model_type="B", gpu_id=1)
    evaluate_and_compare(model_a_path, model_b_path, gpu_id=0)