| import os |
| import argparse |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| import h5py |
| import librosa |
| import pretty_midi |
| import soundfile as sf |
| import torchaudio |
| from tqdm import tqdm |
| from sklearn.metrics import f1_score, precision_score, recall_score |
| from transformers import WavLMModel, Wav2Vec2Model |
| import math |
| import logging |
|
|
| |
logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s')

# Best-effort: ask torchaudio to use the soundfile backend. Depending on the
# installed torchaudio version this call may be deprecated or unavailable.
# Failing here must not stop the script: audio files are read with soundfile
# directly, and torchaudio is only used for resampling / mel features.
# So the error is logged and otherwise ignored.
try:
    torchaudio.set_audio_backend("soundfile")
except Exception as e:
    logging.debug(f"torchaudio.set_audio_backend('soundfile') not applied: {e}")
|
|
| |
| |
| |
|
|
def compute_onset_labels(frame_labels, threshold=0.5):
    """
    Compute onset labels from frame labels (from drum_train_sota.py).
    Onset = frame is active AND previous frame was inactive.

    Args:
        frame_labels: tensor whose last two dimensions are (time, classes),
            e.g. (batch, time, classes) or (time, classes). The first frame is
            treated as preceded by an inactive frame.
        threshold: values strictly greater than this count as active.

    Returns:
        Float tensor with the same shape as ``frame_labels``.
        Onset frames are 1.0; all other frames are 0.0.
    """
    active = (frame_labels > threshold).float()
    # Shift activity one step forward along the time axis (dim -2), filling
    # the first frame with 0, so prev_active[..., t, :] == active[..., t-1, :].
    # Slicing and padding both address dim -2, so 2-D and 3-D inputs behave alike.
    prev_active = F.pad(active[..., :-1, :], (0, 0, 1, 0), value=0)
    return active * (1 - prev_active)
|
|
|
|
def compute_mel_spectrogram(waveform, sr=16000, n_mels=64, hop_length=320, n_fft=1024):
    """Compute Mel Spectrogram matching CNNSA training params.

    Args:
        waveform: mono audio as a torch tensor (any device; grad allowed) or a
            numpy array. Singleton dimensions are squeezed away.
        sr: sample rate of ``waveform`` in Hz.
        n_mels: number of mel bands.
        hop_length: hop size in samples.
        n_fft: FFT window size in samples.

    Returns:
        float32 tensor of shape (n_mels, n_frames), in dB relative to the
        clip's peak power (``ref=np.max``).

    NOTE(review): callers may override ``hop_length`` (e.g. 256 instead of the
    default 320); confirm which hop the CNNSA checkpoint was trained with.
    """
    if isinstance(waveform, torch.Tensor):
        # detach().cpu() makes this safe for CUDA tensors and tensors requiring grad.
        waveform = waveform.detach().cpu().numpy()

    if waveform.ndim > 1:
        waveform = waveform.squeeze()

    mel = librosa.feature.melspectrogram(
        y=waveform,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    mel_db = librosa.power_to_db(mel, ref=np.max)
    return torch.tensor(mel_db, dtype=torch.float32)
|
|
|
|
def compute_hcqt(waveform, sr=22050, hop_length=512, harmonics=(1, 2, 3)):
    """Compute HCQT (from bass_train_sota.py)

    One constant-Q transform is taken per harmonic multiplier ``h``. Each
    starts at ``h * E1`` and spans 6 octaves at 12 bins per octave (72 bins).
    The magnitudes are log-compressed.

    Args:
        waveform: mono audio as a torch tensor or numpy array. Singleton
            dimensions are squeezed in both cases.
        sr: sample rate of ``waveform`` in Hz.
        hop_length: CQT hop size in samples.
        harmonics: harmonic multipliers applied to the base frequency. A tuple
            is used as the default to avoid a shared mutable default.

    Returns:
        float32 tensor of shape (len(harmonics), n_frames, 72).
    """
    if isinstance(waveform, torch.Tensor):
        y = waveform.detach().squeeze().cpu().numpy()
    else:
        # Squeeze numpy input too, so a (1, N) array is handled like a (1, N) tensor.
        y = np.squeeze(np.asarray(waveform))

    fmin = librosa.note_to_hz("E1")
    bins_per_octave = 12
    n_octaves = 6
    n_bins = n_octaves * bins_per_octave

    hcqt = []
    for h in harmonics:
        cqt = librosa.cqt(
            y=y,
            sr=sr,
            hop_length=hop_length,
            fmin=fmin * h,
            n_bins=n_bins,
            bins_per_octave=bins_per_octave,
        )
        hcqt.append(np.abs(cqt))

    # (harmonics, bins, frames) -> log magnitude -> (harmonics, frames, bins)
    hcqt = np.log(np.stack(hcqt) + 1e-9)
    return torch.from_numpy(hcqt).float().permute(0, 2, 1)
|
|
|
|
| |
| |
| |
|
|
def _binary_prf(pred, target):
    """Return (precision, recall, f1) for two same-shape binary tensors, 0.0 where undefined."""
    pred = pred.bool()
    target = target.bool()
    tp = (pred & target).sum().item()
    fp = (pred & ~target).sum().item()
    fn = (~pred & target).sum().item()
    # Matches sklearn's zero_division=0: an undefined ratio is reported as 0.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    return precision, recall, f1


def calculate_metrics(pred_logits, target_labels, threshold=0.5):
    """
    Calculate Frame F1, Onset F1, Precision, Recall.

    Args:
        pred_logits: raw model outputs; binarized as ``sigmoid(x) > threshold``.
        target_labels: ground-truth frame labels with the same shape; values
            above 0.5 count as active.
        threshold: probability threshold applied to the predictions.

    Returns:
        dict with frame_/onset_ f1, precision and recall as floats. All counts are
        pooled over every element of the inputs. Onsets for both predictions and
        targets are derived from the binarized frames via compute_onset_labels.
    """
    preds = (torch.sigmoid(pred_logits) > threshold).float()
    # Binarize targets explicitly. For 0/1 labels this is a no-op; for soft
    # labels it avoids a "continuous target" failure.
    targets = (target_labels > 0.5).float()

    frame_precision, frame_recall, frame_f1 = _binary_prf(preds, targets)
    onset_precision, onset_recall, onset_f1 = _binary_prf(
        compute_onset_labels(preds), compute_onset_labels(targets)
    )

    return {
        'frame_f1': frame_f1,
        'frame_precision': frame_precision,
        'frame_recall': frame_recall,
        'onset_f1': onset_f1,
        'onset_precision': onset_precision,
        'onset_recall': onset_recall,
    }
|
|
|
|
| |
| |
| |
|
|
class DrumEvalDataset(Dataset):
    """Drum evaluation items read from one HDF5 file.

    The file must contain ``audio`` and ``labels`` datasets indexed along their
    first axis. Each item is a dict holding:
      * the raw waveform for the SOTA model (``sota_input``);
      * a 64-band mel spectrogram for the comparison model (``comp_input``);
      * the frame labels (``labels``).
    The file is reopened for every item.
    """

    def __init__(self, h5_path):
        self.h5_path = h5_path
        with h5py.File(h5_path, "r") as h5:
            n_items = h5["audio"].shape[0]
        self.length = n_items
        logging.info(f"Drum dataset: {self.length} samples")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with h5py.File(self.h5_path, "r") as h5:
            waveform = torch.from_numpy(h5["audio"][idx]).float()
            frame_labels = torch.from_numpy(h5["labels"][idx]).float()

        mel = compute_mel_spectrogram(waveform, sr=16000, n_mels=64, hop_length=256)
        return {"sota_input": waveform, "comp_input": mel, "labels": frame_labels}
|
|
|
|
class BassEvalDataset(Dataset):
    """Paired bass audio (``*.flac``) and MIDI (``.mid``/``.midi``) files for evaluation.

    Each item is a dict with:
      * ``sota_input_wav``: mono 16 kHz waveform, shape (samples,)
      * ``sota_input_hcqt``: HCQT of the 22.05 kHz waveform, (harmonics, n_frames, 72)
      * ``comp_input_mel``: log-mel spectrogram of the 22.05 kHz waveform, (mel_frames, 88)
      * ``labels_full``: piano roll over MIDI 21..108, shape (n_frames, 88)
      * ``labels_sota``: the 40-pitch slice MIDI 28..67 of ``labels_full``
    """

    # Feature grid shared by the HCQT, the mel spectrogram and the MIDI labels.
    HCQT_SR = 22050
    HOP_LENGTH = 512

    def __init__(self, audio_dir, midi_dir):
        import glob
        self.pairs = []
        for af in sorted(glob.glob(os.path.join(audio_dir, "*.flac"))):
            base = os.path.splitext(os.path.basename(af))[0]
            if base.startswith('._'):  # presumably macOS "._" sidecar files
                continue
            mf = os.path.join(midi_dir, base + ".mid")
            if not os.path.exists(mf):
                mf = os.path.join(midi_dir, base + ".midi")
            if os.path.exists(mf):
                self.pairs.append((af, mf))

        logging.info(f"Bass dataset: {len(self.pairs)} pairs")

        # Depends only on constants, so build it once instead of once per item.
        self._mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.HCQT_SR,
            n_fft=2048,
            hop_length=self.HOP_LENGTH,
            n_mels=88,
            f_min=27.5,
            f_max=1000.0,
            normalized=True,
        )

    def __len__(self):
        return len(self.pairs)

    def _load_audio(self, idx):
        """Read the audio for pair ``idx``, falling through to later pairs if unreadable.

        Returns ``(waveform, sample_rate, midi_path)``. The midi path belongs to the
        pair that was actually read. Raises IndexError for an out-of-range ``idx``,
        which keeps the sequence-iteration protocol working. Raises RuntimeError when
        no pair can be read, instead of recursing without bound.
        """
        n = len(self.pairs)
        if not -n <= idx < n:
            raise IndexError(f"BassEvalDataset index {idx} out of range ({n} pairs)")
        for offset in range(n):
            audio_path, midi_path = self.pairs[(idx + offset) % n]
            try:
                audio_data, sr = sf.read(audio_path)
                waveform = torch.from_numpy(audio_data).float()
            except Exception as e:
                logging.error(f"Failed to read {audio_path}: {e}")
                continue
            return waveform, sr, midi_path
        raise RuntimeError(f"No readable audio file in BassEvalDataset (started at index {idx})")

    def _midi_to_labels(self, midi_path, n_frames):
        """Rasterize all MIDI notes into an (n_frames, 88) float32 piano roll on the HCQT frame grid."""
        # Frames are HOP_LENGTH samples at HCQT_SR, NOT at the source file's
        # sample rate: the features are computed on the resampled 22.05 kHz audio.
        fps = self.HCQT_SR / self.HOP_LENGTH
        pm = pretty_midi.PrettyMIDI(midi_path)
        labels = np.zeros((n_frames, 88), dtype=np.float32)
        for inst in pm.instruments:
            for note in inst.notes:
                start = int(note.start * fps)
                end = min(int(note.end * fps), n_frames)
                pitch = note.pitch - 21  # MIDI 21 (A0) -> column 0
                if 0 <= pitch < 88 and start < n_frames:
                    labels[start:end, pitch] = 1.0
        return torch.from_numpy(labels)

    def __getitem__(self, idx):
        waveform, sr, midi_path = self._load_audio(idx)

        # soundfile returns (samples,) or (samples, channels); use (channels, samples).
        waveform = waveform.unsqueeze(0) if waveform.ndim == 1 else waveform.t()

        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, sr, 16000)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        wav_22k = torchaudio.functional.resample(waveform, 16000, self.HCQT_SR)
        hcqt = compute_hcqt(wav_22k, sr=self.HCQT_SR, hop_length=self.HOP_LENGTH)

        mel_spec = self._mel_transform(wav_22k).squeeze(0)
        mel_spec = torch.log(mel_spec + 1e-9).transpose(0, 1)

        # hcqt is (harmonics, frames, bins): labels follow the HCQT frame count.
        labels_full = self._midi_to_labels(midi_path, hcqt.shape[1])
        # Columns 7..46 == MIDI 28..67, the pitch range of the SOTA model's 40 outputs.
        labels_sota = labels_full[:, 7:47]

        return {
            "sota_input_wav": waveform.squeeze(),
            "sota_input_hcqt": hcqt,
            "comp_input_mel": mel_spec,
            "labels_full": labels_full,
            "labels_sota": labels_sota,
        }
|
|
|
|
| |
| |
| |
|
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Even columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i), with
    w_i = 10000^(-2i / d_model).

    Args:
        d_model: embedding size; odd sizes are supported.
        max_len: longest sequence that can be encoded.

    ``forward(x)`` expects x of shape (seq_len, batch, d_model). It returns
    x plus the encodings for positions 0..seq_len-1, broadcast over the batch.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :].unsqueeze(1)
|
|
|
|
class CNNSA(nn.Module):
    """CNN front-end + Transformer encoder drum classifier (comparison model).

    Input: spectrogram of shape (batch, freq, time) or (batch, 1, freq, time).
    Output: logits of shape (batch, time // 4, num_classes).

    Args:
        input_freq_bins: number of frequency bins in the input; must be >= 16
            because frequency is pooled by 2 four times.
        num_classes: number of output classes per frame.
        d_model, nhead, num_layers: Transformer encoder configuration.
    """

    def __init__(self, input_freq_bins=64, num_classes=9, d_model=512, nhead=8, num_layers=3):
        super().__init__()
        if input_freq_bins < 16:
            raise ValueError(f"CNNSA needs input_freq_bins >= 16 (got {input_freq_bins})")
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, d_model, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(d_model)
        self.pool_sq = nn.MaxPool2d(2, 2)       # halves frequency and time
        self.pool_freq = nn.MaxPool2d((2, 1))   # halves frequency only

        # Four pooling stages halve frequency each time (floor), leaving
        # input_freq_bins // 16 bins; equals d_model * 4 for the default 64 bins.
        self.cnn_flatten_dim = d_model * (input_freq_bins // 16)

        self.projection = nn.Linear(self.cnn_flatten_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=1024, dropout=0.2, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.fc1 = nn.Linear(d_model, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        if x.dim() == 3:
            x = x.unsqueeze(1)  # add the channel axis

        x = self.pool_sq(F.relu(self.bn1(self.conv1(x))))
        x = self.pool_sq(F.relu(self.bn2(self.conv2(x))))
        x = self.pool_freq(F.relu(self.bn3(self.conv3(x))))
        x = self.pool_freq(F.relu(self.bn4(self.conv4(x))))

        # (B, C, F', T') -> (B, T', C * F'): one token per remaining time step.
        b, c, f, t = x.shape
        x = x.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
        x = self.projection(x)
        # PositionalEncoding expects (seq, batch, dim).
        x = self.pos_encoder(x.transpose(0, 1)).transpose(0, 1)
        x = self.transformer_encoder(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
|
|
|
|
class DrumSOTAModel(nn.Module):
    """WavLM-base encoder with separate frame-activity and onset classification heads.

    Args:
        num_classes: number of drum classes predicted per encoder frame.
        unfreeze_layers: accepted for signature compatibility; not used by this
            class (NOTE(review): presumably consumed by the training script — confirm).

    ``forward(audio)`` returns ``(frame_logits, onset_logits)``. Each head is
    applied to WavLM's ``last_hidden_state``, giving num_classes logits per
    encoder frame.
    """

    def __init__(self, num_classes=9, unfreeze_layers=4):
        super().__init__()
        try:
            self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base", use_safetensors=True)
        except Exception as e:
            # Narrowed from a bare except so Ctrl-C / SystemExit still propagate.
            logging.warning(f"Loading WavLM with safetensors failed ({e}); retrying with default weight format")
            self.wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base")

        hidden = self.wavlm.config.hidden_size
        self.frame_head = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.LayerNorm(hidden // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden // 2, num_classes),
        )
        self.onset_head = nn.Sequential(
            nn.Linear(hidden, hidden // 4),
            nn.LayerNorm(hidden // 4),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden // 4, num_classes),
        )

    def forward(self, audio):
        out = self.wavlm(audio).last_hidden_state
        return self.frame_head(out), self.onset_head(out)
|
|
|
|
| |
| |
|
|
class ConformerBlock(nn.Module):
    """One Conformer layer on (batch, time, d_model) tensors.

    Sub-layers, in order:
      1. half-step FFN;
      2. pre-norm multi-head self-attention;
      3. convolution module;
      4. half-step FFN;
      5. final LayerNorm.
    Each of 1-4 has a residual connection.
    """

    def __init__(self, d_model=512, nhead=8, conv_kernel=31, dropout=0.1):
        super().__init__()

        def feed_forward():
            # Pre-norm position-wise FFN with 4x expansion and SiLU (Swish).
            return nn.Sequential(
                nn.LayerNorm(d_model),
                nn.Linear(d_model, 4 * d_model),
                nn.SiLU(),
                nn.Dropout(dropout),
                nn.Linear(4 * d_model, d_model),
                nn.Dropout(dropout),
            )

        self.ffn1 = feed_forward()

        # Self-attention sub-module.
        self.norm_attn = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.dropout_attn = nn.Dropout(dropout)

        # Convolution sub-module (operates channels-first).
        self.norm_conv = nn.LayerNorm(d_model)
        self.pointwise_conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            d_model, d_model, conv_kernel, padding=conv_kernel // 2, groups=d_model
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout_conv = nn.Dropout(dropout)

        self.ffn2 = feed_forward()
        self.norm_final = nn.LayerNorm(d_model)

    def forward(self, x):
        # 1. Half-step feed-forward.
        x = x + 0.5 * self.ffn1(x)

        # 2. Self-attention on the normalized input, added back to the un-normalized stream.
        q = self.norm_attn(x)
        attn_out, _ = self.self_attn(q, q, q)
        x = x + self.dropout_attn(attn_out)

        # 3. Convolution module: (B, T, D) -> (B, D, T) -> ... -> (B, T, D).
        h = self.norm_conv(x).transpose(1, 2)
        h = F.glu(self.pointwise_conv1(h), dim=1)
        h = self.depthwise_conv(h)
        h = self.activation(self.batch_norm(h))
        h = self.pointwise_conv2(h)
        x = x + self.dropout_conv(h).transpose(1, 2)

        # 4./5. Second half-step feed-forward, then final normalization.
        x = x + 0.5 * self.ffn2(x)
        return self.norm_final(x)
|
|
|
|
class Conformer(nn.Module):
    """A stack of ``num_layers`` ConformerBlocks, all built with the same settings and applied in order."""

    def __init__(self, d_model=512, nhead=8, conv_kernel=31, num_layers=2):
        super().__init__()
        blocks = [ConformerBlock(d_model, nhead, conv_kernel) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out
|
|
|
|
class SimpleHarmonicAttention(nn.Module):
    """Self-attention across the harmonic channels of an HCQT, independently for each time frame.

    Args:
        n_bins: frequency bins per harmonic. Used as the attention embedding size,
            so it must be divisible by the 4 attention heads.
        n_harmonics: accepted for interface compatibility; not used.
    """

    def __init__(self, n_bins=72, n_harmonics=3):
        super().__init__()
        self.attention = nn.MultiheadAttention(n_bins, 4, batch_first=True, dropout=0.1)

    def forward(self, hcqt):
        """Attend over ``hcqt`` of shape (batch, harmonics, time, bins); returns the same shape."""
        # Descriptive names instead of ``F`` avoid shadowing torch.nn.functional.
        n_batch, n_harm, n_time, n_bins = hcqt.shape
        # Every (batch, time) frame becomes a sequence of n_harm tokens of size n_bins.
        tokens = hcqt.permute(0, 2, 1, 3).reshape(n_batch * n_time, n_harm, n_bins)
        tokens, _ = self.attention(tokens, tokens, tokens)
        return tokens.reshape(n_batch, n_time, n_harm, n_bins).permute(0, 2, 1, 3)
|
|
|
|
class SpectralCNN(nn.Module):
    """Three Conv-BN-ReLU stages with frequency-only pooling, then averaging over frequency.

    Input: (batch, in_channels, time, freq). Output: (batch, time, hidden_dim).
    """

    def __init__(self, in_channels=3, hidden_dim=512):
        super().__init__()
        widths = [in_channels, 64, 128, hidden_dim]
        stages = []
        for stage_idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stages += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU()]
            # Halve the frequency axis after every stage except the last.
            if stage_idx < len(widths) - 2:
                stages.append(nn.MaxPool2d((1, 2)))
        self.conv = nn.Sequential(*stages)
        self.pool = nn.AdaptiveAvgPool2d((None, 1))

    def forward(self, x):
        feats = self.pool(self.conv(x))   # (B, hidden, T, 1)
        feats = feats.squeeze(-1)         # (B, hidden, T)
        return feats.transpose(1, 2)      # (B, T, hidden)
|
|
|
|
class BassSOTAModel(nn.Module):
    """Bass transcription model: frozen pretrained audio encoder + optional HCQT branch + Conformer.

    Args:
        use_harmonic_branch: if True, ``forward`` requires an HCQT tensor.
            The HCQT is attended across harmonics, encoded by SpectralCNN and
            concatenated with the projected audio features.
        hidden_dim: width of the projected / fused features and the Conformer.

    ``forward(waveform, hcqt=None)`` returns ``(onset_logits, frame_logits)``,
    each with N_MIDI_BINS (40) outputs per frame. The frame head is conditioned
    on the onset logits.
    """

    def __init__(self, use_harmonic_branch=True, hidden_dim=768):
        super().__init__()
        self.use_harmonic_branch = use_harmonic_branch
        # NOTE(review): a WavLM checkpoint is instantiated through the Wav2Vec2Model
        # class. Kept as-is because evaluated checkpoints presumably store
        # ``audio_encoder.*`` weights for this architecture — confirm against training code.
        self.audio_encoder = Wav2Vec2Model.from_pretrained("microsoft/wavlm-base-plus", use_safetensors=True)
        for p in self.audio_encoder.parameters():
            p.requires_grad = False
        self.audio_proj = nn.Sequential(nn.Linear(768, hidden_dim), nn.LayerNorm(hidden_dim), nn.Dropout(0.1))

        N_BINS = 72
        HARMONICS = [1, 2, 3]
        N_MIDI_BINS = 40

        if use_harmonic_branch:
            self.harmonic_attn = SimpleHarmonicAttention(N_BINS, len(HARMONICS))
            self.spec_cnn = SpectralCNN(len(HARMONICS), hidden_dim)

        fusion_dim = hidden_dim * (2 if use_harmonic_branch else 1)
        self.fusion = nn.Sequential(nn.Linear(fusion_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(0.1))
        self.conformer = Conformer(hidden_dim, num_layers=2)
        self.onset_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2), nn.LayerNorm(hidden_dim // 2), nn.GELU(),
            nn.Linear(hidden_dim // 2, N_MIDI_BINS),
        )
        self.frame_head = nn.Sequential(
            nn.Linear(hidden_dim + N_MIDI_BINS, hidden_dim // 2), nn.LayerNorm(hidden_dim // 2), nn.GELU(),
            nn.Linear(hidden_dim // 2, N_MIDI_BINS),
        )

    @staticmethod
    def _resample_time(feats, n_frames):
        """Linearly resample ``feats`` (batch, time, dim) along time to ``n_frames``."""
        if feats.shape[1] == n_frames:
            return feats
        return F.interpolate(
            feats.transpose(1, 2), size=n_frames, mode='linear', align_corners=False
        ).transpose(1, 2)

    def forward(self, waveform, hcqt=None):
        if self.use_harmonic_branch and hcqt is None:
            # The fusion layer expects concatenated audio + spectral features
            # (2 * hidden_dim); without hcqt this would fail with an opaque matmul error.
            raise ValueError("BassSOTAModel(use_harmonic_branch=True) requires an hcqt tensor")

        # The encoder is frozen, so don't build an autograd graph for it.
        with torch.no_grad():
            audio = self.audio_encoder(waveform).last_hidden_state
        audio = self.audio_proj(audio)

        if self.use_harmonic_branch:
            # Both streams are aligned to the HCQT frame count (hcqt: B, H, T, F).
            n_frames = hcqt.shape[2]
            spec = self.spec_cnn(self.harmonic_attn(hcqt))
            audio = self._resample_time(audio, n_frames)
            spec = self._resample_time(spec, n_frames)
            x = torch.cat([audio, spec], dim=-1)
        else:
            x = audio

        x = self.conformer(self.fusion(x))
        onset = self.onset_head(x)
        frame = self.frame_head(torch.cat([x, onset], dim=-1))
        return onset, frame
|
|
|
|
class BassCompModel(nn.Module):
    """CRNN comparison model for bass: two conv stages, a 2-layer BiLSTM and a linear head.

    Input: (batch, time, input_features) spectrogram.
    Output: logits of shape (batch, time, num_classes).
    """

    def __init__(self, input_features=88, hidden_size=256, num_classes=88):
        super().__init__()
        # Each stage pools only the feature axis by 2; time resolution is preserved.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, (3, 3), padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),
            nn.Conv2d(16, 32, (3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),
        )
        lstm_in = 32 * (input_features // 4)  # channels x pooled feature bins
        self.lstm = nn.LSTM(lstm_in, hidden_size, 2, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x):
        feats = self.cnn(x.unsqueeze(1))                  # (B, 32, T, F // 4)
        n_batch, _, n_time, _ = feats.shape
        seq = feats.transpose(1, 2).reshape(n_batch, n_time, -1)
        seq, _ = self.lstm(seq)
        return self.fc(seq)
|
|
|
|
| |
| |
| |
|
|
def _strip_module_prefix(state_dict):
    """Drop the leading ``module.`` that (Distributed)DataParallel adds to every key."""
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


def load_model_safe(weights_path, device, task):
    """Robustly load a model.

    Builds the architecture that matches the checkpoint's keys for ``task`` and
    loads the weights into it.

    Args:
        weights_path: path to a ``torch.save`` checkpoint. Accepted formats:
            a raw state_dict; a dict wrapping one under ``'model'`` or
            ``'model_state_dict'``; or a whole ``nn.Module``.
        device: device the model is created on.
        task: ``"bass"`` or ``"drum"``.

    Returns:
        ``(model, model_type)``, where model_type is ``"SOTA"``, ``"CRNN"`` or ``"CNNSA"``.
        Returns ``(None, None)`` when the file is missing, unreadable, empty or in an
        unsupported format. Returns ``(None, "Unknown")`` when no architecture matches.

    NOTE: depending on the PyTorch version's ``weights_only`` default, torch.load may
    unpickle arbitrary objects — only evaluate trusted checkpoints.
    """
    if not weights_path or not os.path.exists(weights_path):
        logging.warning(f"Weights file not found: {weights_path}")
        return None, None

    logging.info(f"Loading weights from {weights_path}...")
    try:
        ckpt = torch.load(weights_path, map_location='cpu')
    except Exception as e:
        logging.error(f"Failed to load checkpoint: {e}")
        return None, None

    state_dict = ckpt
    if isinstance(ckpt, dict):
        if 'model' in ckpt:
            state_dict = ckpt['model']
        elif 'model_state_dict' in ckpt:
            state_dict = ckpt['model_state_dict']
    if isinstance(state_dict, nn.Module):
        state_dict = state_dict.state_dict()
    if not isinstance(state_dict, dict):
        logging.error(f"Unsupported checkpoint format: {type(state_dict).__name__}")
        return None, None

    # Strip DataParallel prefixes BEFORE architecture detection, otherwise keys
    # like "module.wavlm..." are not recognised. Only a leading prefix is removed.
    if any(k.startswith("module.") for k in state_dict):
        logging.info("Removed 'module.' prefix from checkpoint keys")
        state_dict = _strip_module_prefix(state_dict)

    keys = list(state_dict)
    if not keys:
        logging.error("Checkpoint is empty.")
        return None, None

    model = None
    model_type = "Unknown"

    if task == "bass":
        if any(k.startswith(("audio_encoder", "conformer")) for k in keys):
            logging.info("➡ Detected: BassSOTAModel")
            model = BassSOTAModel().to(device)
            model_type = "SOTA"
        elif any(k.startswith(("cnn", "lstm")) for k in keys):
            logging.info("➡ Detected: BassCompModel (CRNN)")
            model = BassCompModel().to(device)
            model_type = "CRNN"
    elif task == "drum":
        if any(k.startswith("wavlm") for k in keys):
            logging.info("➡ Detected: DrumSOTAModel")
            model = DrumSOTAModel().to(device)
            model_type = "SOTA"
        else:
            logging.info("➡ Detected: CNNSA")
            model = CNNSA().to(device)
            model_type = "CNNSA"

    if model is None:
        logging.error(f"Could not match checkpoint keys to a known {task} architecture")
        return model, model_type

    try:
        model.load_state_dict(state_dict, strict=True)
        logging.info("✓ Loaded successfully")
    except RuntimeError:
        # Tolerate missing/unexpected keys but report how many, so a mostly
        # uninitialised model is not mistaken for a good load.
        result = model.load_state_dict(state_dict, strict=False)
        logging.warning(
            f"⚠ Loaded with strict=False ({len(result.missing_keys)} missing, "
            f"{len(result.unexpected_keys)} unexpected keys)"
        )

    return model, model_type
|
|
|
|
| |
| |
| |
|
|
def _build_eval_dataset(args):
    """Return the dataset for ``args.task``, or None (after logging) if it cannot be built."""
    if args.task == "drum":
        return DrumEvalDataset(args.data_path)
    if args.task == "bass":
        if not args.midi_path:
            logging.error("--midi_path required for bass evaluation")
            return None
        return BassEvalDataset(args.data_path, args.midi_path)
    logging.error(f"Unsupported task: {args.task}")
    return None


def _align_time(pred, n_frames):
    """Linearly resample ``pred`` (batch, time, classes) along time to ``n_frames``."""
    if pred.shape[1] == n_frames:
        return pred
    return F.interpolate(
        pred.transpose(1, 2), size=n_frames, mode='linear', align_corners=False
    ).transpose(1, 2)


def _frame_logits_and_target(model, mtype, batch, task):
    """Run ``model`` on a device-resident ``batch``; return ``(frame_logits, target_labels)``.

    Only frame logits are scored. The SOTA models' onset-head outputs are
    discarded, because calculate_metrics derives onsets from the frame predictions.
    """
    if task == "drum":
        if mtype == "SOTA":
            frame_logits, _onset_logits = model(batch['sota_input'])
        else:
            frame_logits = model(batch['comp_input'])
        return frame_logits, batch['labels']

    if mtype == "SOTA":
        _onset_logits, frame_logits = model(batch['sota_input_wav'], batch['sota_input_hcqt'])
        return frame_logits, batch['labels_sota']
    if mtype == "CRNN":
        return model(batch['comp_input_mel']), batch['labels_full']
    raise ValueError(f"Unsupported bass model type: {mtype}")


def _print_results(task, models, results):
    """Print one row per model; each metric is the mean of its per-batch values."""
    print(f"\n{'='*80}")
    print(f"EVALUATION RESULTS - {task.upper()}")
    print(f"{'='*80}")
    print(f"{'Model':<15} | {'Type':<8} | {'Frame F1':<10} | {'Frame P':<10} | {'Frame R':<10} | {'Onset F1':<10}")
    print("-" * 80)

    for name, metrics in results.items():
        mtype = models[name][1]
        print(f"{name:<15} | {mtype:<8} | "
              f"{np.mean(metrics['frame_f1']):<10.4f} | "
              f"{np.mean(metrics['frame_precision']):<10.4f} | "
              f"{np.mean(metrics['frame_recall']):<10.4f} | "
              f"{np.mean(metrics['onset_f1']):<10.4f}")

    print(f"{'='*80}\n")


def evaluate(args):
    """Evaluate the SOTA and/or comparison checkpoints named in ``args`` and print a results table.

    ``args`` needs: task ("drum"/"bass"), data_path, midi_path (bass only),
    sota_weights and comp_weights (either may be empty).
    Metrics are computed per batch and averaged over batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Task: {args.task} | Device: {device}")

    models = {}
    for name, weights in (('SOTA', args.sota_weights), ('Comparison', args.comp_weights)):
        if not weights:
            continue
        model, mtype = load_model_safe(weights, device, args.task)
        if model is not None:
            models[name] = (model, mtype)

    if not models:
        logging.error("No models loaded. Exiting.")
        return

    dataset = _build_eval_dataset(args)
    if dataset is None:
        return
    if len(dataset) == 0:
        logging.error("Evaluation dataset is empty. Exiting.")
        return

    # Bass items are whole recordings of differing lengths, which the default
    # collate function cannot stack into one batch.
    batch_size = 4 if args.task == "drum" else 1
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    metric_names = ('frame_f1', 'frame_precision', 'frame_recall',
                    'onset_f1', 'onset_precision', 'onset_recall')
    results = {name: {k: [] for k in metric_names} for name in models}

    for model, _ in models.values():
        model.eval()

    logging.info("Starting evaluation...")
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            # Move the batch once, then share it between all models.
            batch = {k: v.to(device) for k, v in batch.items()}
            for name, (model, mtype) in models.items():
                frame_logits, target = _frame_logits_and_target(model, mtype, batch, args.task)
                # Model frame rates differ from the label frame rate; align before scoring.
                frame_logits = _align_time(frame_logits, target.shape[1])
                for k, v in calculate_metrics(frame_logits, target).items():
                    results[name][k].append(v)

    _print_results(args.task, models, results)
|
|
|
|
def main(argv=None):
    """Parse command-line arguments (``argv`` defaults to sys.argv[1:]) and run evaluate()."""
    parser = argparse.ArgumentParser(description="Evaluate SOTA vs Comparison models")
    parser.add_argument("--task", required=True, choices=["drum", "bass"])
    parser.add_argument("--data_path", required=True, help="Path to audio dir (bass) or H5 file (drum)")
    parser.add_argument("--midi_path", help="MIDI directory (bass only)")
    # Either weights file may be omitted (evaluate() skips missing ones), but not both.
    parser.add_argument("--sota_weights", help="SOTA model weights")
    parser.add_argument("--comp_weights", help="Comparison model weights")

    args = parser.parse_args(argv)
    if not (args.sota_weights or args.comp_weights):
        parser.error("at least one of --sota_weights / --comp_weights is required")
    if args.task == "bass" and not args.midi_path:
        parser.error("--midi_path is required when --task is bass")
    evaluate(args)


if __name__ == "__main__":
    main()