import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from .model import ResNet
from .data import BeatTrackingDataset
from ..baseline1.utils import MultiViewSpectrogram
from ..data.load import ds


def train(target_type: str, output_dir: str):
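    """Train one model for a single rhythmic target ('beats' or 'downbeats').

    Spectrogram features are computed on the fly on DEVICE; TensorBoard logs
    and checkpoints are written under output_dir.
    """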
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE = 128
    EPOCHS = 3
    LR = 0.001
    NUM_WORKERS = 4
    CONTEXT_FRAMES = 50  # frames of context on each side of the centre frame
    PATIENCE = 5  # epochs without val-loss improvement before early stopping

    print(f"--- Training Model for target: {target_type} ---")
    print(f"Output directory: {output_dir}")

    os.makedirs(output_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))

    train_dataset = BeatTrackingDataset(
        ds["train"], target_type=target_type, context_frames=CONTEXT_FRAMES
    )
    val_dataset = BeatTrackingDataset(
        ds["test"], target_type=target_type, context_frames=CONTEXT_FRAMES
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )

    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
    model = ResNet(dropout_rate=0.5).to(DEVICE)
    preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)

    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.BCELoss()
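    # nn.BCELoss assumes the model already emits probabilities in [0, 1]
    # (i.e. ResNet ends in a sigmoid); if it returned raw logits,
    # nn.BCEWithLogitsLoss would be the numerically safer drop-in.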

    best_val_loss = float("inf")
    patience_counter = 0
    global_step = 0

    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        for waveform, y in tqdm(
            train_loader,
            desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
            leave=False,
        ):
            waveform, y = waveform.to(DEVICE), y.to(DEVICE)

            # Feature extraction needs no gradients.
            with torch.no_grad():
                spec = preprocessor(waveform)

                # Per-sample, per-channel standardisation over freq and time.
                mean = spec.mean(dim=(2, 3), keepdim=True)
                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
                spec = (spec - mean) / std

                T_curr = spec.shape[-1]
                target_T = CONTEXT_FRAMES * 2 + 1
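
                # The model expects a fixed window of 2 * CONTEXT_FRAMES + 1
                # frames: centre-crop longer inputs, right-pad shorter ones.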
                if T_curr > target_T:
                    start = (T_curr - target_T) // 2
                    x = spec[:, :, :, start : start + target_T]
                elif T_curr < target_T:
                    pad = target_T - T_curr
                    x = torch.nn.functional.pad(spec, (0, pad))
                else:
                    x = spec

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            global_step += 1
            writer.add_scalar("train/batch_loss", loss.item(), global_step)

        avg_train_loss = total_train_loss / len(train_loader)
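
        # Validation pass: identical preprocessing, everything under no_grad.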
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for waveform, y in tqdm(
                val_loader,
                desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
                leave=False,
            ):
                waveform, y = waveform.to(DEVICE), y.to(DEVICE)

                spec = preprocessor(waveform)
                mean = spec.mean(dim=(2, 3), keepdim=True)
                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
                spec = (spec - mean) / std

                T_curr = spec.shape[-1]
                target_T = CONTEXT_FRAMES * 2 + 1

                if T_curr > target_T:
                    start = (T_curr - target_T) // 2
                    x = spec[:, :, :, start : start + target_T]
                else:
                    pad = target_T - T_curr
                    x = torch.nn.functional.pad(spec, (0, pad))

                output = model(x)
                loss = criterion(output, y)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)

        writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
        writer.add_scalar("val/loss", avg_val_loss, epoch)
        writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
        scheduler.step()

        print(
            f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
            f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
        )
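
        # Keep the best checkpoint by validation loss; stop early once
        # PATIENCE epochs pass without improvement.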
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            model.save_pretrained(output_dir)
            print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1
            print(f" -> No improvement (patience: {patience_counter}/{PATIENCE})")

        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

    writer.close()
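
    # Note: the "final" checkpoint below holds the weights at loop exit,
    # which may differ from the best-validation checkpoint saved above.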
    final_dir = os.path.join(output_dir, "final")
    model.save_pretrained(final_dir)
    print(f"Saved final model to {final_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target",
        type=str,
        choices=["beats", "downbeats"],
        default=None,
        help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="outputs/baseline2",
        help="Directory to save model and logs",
    )
    args = parser.parse_args()
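
    # One model per target; each gets its own subdirectory under --output-dir.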
    targets = [args.target] if args.target else ["beats", "downbeats"]
    for target in targets:
        output_dir = os.path.join(args.output_dir, target)
        train(target, output_dir)