Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import tqdm | |
| import json | |
| import numpy as np | |
| from torch.utils.data import DataLoader | |
| from typing import Sequence | |
| from src.models.predict import AudioPredictor | |
| from src.data.dataset import FullTFPatchesDataset, RandomPatchDataset | |
| from src.config.config import TrainConfig | |
config = TrainConfig()


class CNNTrainer:
    """Patch-based CNN trainer for audio spectrograms.

    Supports a single training run with periodic checkpointing and resume
    (``train_cnn``) and k-fold cross-validation (``train_k_fold_cnn``).
    Validation is done per clip by averaging logits over all sliding
    time-frequency patches of that clip.
    """

    def __init__(self, config: TrainConfig = config) -> None:
        # NOTE(review): the default is a module-level shared TrainConfig
        # instance; pass a fresh config to avoid cross-trainer state sharing.
        self.config = config

    def train_cnn(
        self,
        model: torch.nn.Module,
        X_train: Sequence[np.ndarray],
        y_train: Sequence[int],
        X_val: Sequence[np.ndarray],
        y_val: Sequence[int],
        fold_num: int,
    ) -> float:
        """Train ``model`` on spectrogram patches and return the best val accuracy.

        Args:
            model: CNN exposing ``features`` and ``classifier`` submodules
                (required by the per-group weight-decay optimizer below).
            X_train: training spectrograms, each shaped (time, freq) — assumed
                from the patching logic in ``_predict_val``; TODO confirm.
            y_train: integer class labels aligned with ``X_train``.
            X_val: validation spectrograms (same layout as ``X_train``).
            y_val: integer class labels aligned with ``X_val``.
            fold_num: fold index, used only in log messages.

        Returns:
            The best per-clip validation accuracy observed over all epochs.
        """
        device = self.config.device
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        model.to(device)

        train_loader, total_patches, patches_per_epoch, batches_per_epoch = (
            self._build_train_loader(X_train, y_train)
        )

        criterion = torch.nn.CrossEntropyLoss()
        # Weight decay is applied only to the classifier head; the feature
        # extractor is left unregularized.
        optimizer = torch.optim.AdamW(
            [
                {'params': model.features.parameters(), 'weight_decay': 0.0},
                {'params': model.classifier.parameters(), 'weight_decay': 0.001},
            ],
            lr=self.config.lr,
        )

        start_epoch = 0
        best_val_acc = 0.0
        training_history: dict = {'train_loss': [], 'train_acc': [], 'val_acc': [], 'epochs': []}
        if self.config.resume_from and os.path.exists(self.config.resume_from):
            start_epoch, best_val_acc, training_history = self._load_checkpoint(
                model, optimizer, device
            )

        for epoch in range(start_epoch, self.config.epochs):
            train_loss, train_acc = self._train_one_epoch(
                model, train_loader, criterion, optimizer, device,
                batches_per_epoch, epoch,
            )
            val_acc = self._validate(model, X_val, y_val, device, epoch)

            training_history['train_loss'].append(train_loss)
            training_history['train_acc'].append(train_acc)
            training_history['val_acc'].append(val_acc)
            training_history['epochs'].append(epoch + 1)

            is_best = val_acc > best_val_acc
            if is_best:
                best_val_acc = val_acc

            print(
                f"Fold {fold_num} | Epoch {epoch+1}/{self.config.epochs} | "
                f"Train loss: {train_loss:.4f}, Train acc: {train_acc:.4f} | "
                f"Val acc: {val_acc:.4f} (best: {best_val_acc:.4f})"
            )

            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'best_val_acc': best_val_acc,
                'history': training_history,
                'config': {
                    'batch_size': self.config.batch_size,
                    'lr': self.config.lr,
                    'total_patches': total_patches,
                    'patches_per_epoch': patches_per_epoch,
                },
            }
            if is_best:
                # Bug fix: the best checkpoint is now written to checkpoint_dir
                # every time validation improves. Previously a bare state_dict
                # was dumped to "best_model.pt" in the CWD, and the full best
                # checkpoint was only written when the epoch coincided with a
                # save_every_n_epoch boundary — so the true best could be lost.
                best_path = os.path.join(self.config.checkpoint_dir, "best_model.pt")
                torch.save(checkpoint, best_path)
            if (epoch + 1) % self.config.save_every_n_epoch == 0:
                checkpoint_path = os.path.join(
                    self.config.checkpoint_dir,
                    f"checkpoint_epoch_{epoch+1}.pt",
                )
                torch.save(checkpoint, checkpoint_path)
                latest_path = os.path.join(self.config.checkpoint_dir, "latest_checkpoint.pt")
                torch.save(checkpoint, latest_path)
                history_path = os.path.join(self.config.checkpoint_dir, "training_history.json")
                with open(history_path, 'w') as f:
                    json.dump(training_history, f, indent=2)

        # Persist the final (last-epoch) weights separately from checkpoints.
        final_model_dir = "models/saved"
        os.makedirs(final_model_dir, exist_ok=True)
        final_model_path = os.path.join(final_model_dir, "final_model.pt")
        torch.save({
            'model_state_dict': model.state_dict(),
            'best_val_acc': best_val_acc,
            'config': {
                'batch_size': self.config.batch_size,
                'lr': self.config.lr,
                'epochs': self.config.epochs,
            },
        }, final_model_path)
        print(f"\nTraining complete! Final model saved to {final_model_path}")
        return best_val_acc

    def _build_train_loader(
        self,
        X_train: Sequence[np.ndarray],
        y_train: Sequence[int],
    ) -> tuple[DataLoader, int, int, int]:
        """Build the patch dataset/loader and compute the per-epoch batch budget."""
        if self.config.use_all_patches:
            train_dataset = FullTFPatchesDataset(X_train, y_train)
            print(f"\n{'='*60}\nUsing ALL PATCHES method\n{'='*60}")
        else:
            train_dataset = RandomPatchDataset(X_train, y_train)
            # Bug fix: this branch previously printed "ALL PATCHES" as well.
            print(f"\n{'='*60}\nUsing RANDOM PATCH method\n{'='*60}")
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )
        total_patches = len(train_dataset)
        patches_per_epoch = int(total_patches * self.config.samples_per_epoch_fraction)
        # max(1, ...) guards against a zero-batch epoch (and the resulting
        # division by zero in the epoch statistics) when the dataset is tiny.
        batches_per_epoch = max(1, patches_per_epoch // self.config.batch_size)
        print(f"Total available patches: {total_patches:,}")
        print(f"Patches per epoch ({self.config.samples_per_epoch_fraction}): {patches_per_epoch:,}")
        print(f"Batches per epoch: {batches_per_epoch:,}\n{'='*60}\n")
        return train_loader, total_patches, patches_per_epoch, batches_per_epoch

    def _load_checkpoint(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: str,
    ) -> tuple[int, float, dict]:
        """Restore model/optimizer state from ``config.resume_from``.

        Returns ``(next_epoch, best_val_acc, training_history)``.
        """
        print(f"Resuming from checkpoint: {self.config.resume_from}")
        checkpoint = torch.load(self.config.resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_val_acc = checkpoint['best_val_acc']
        print(f"Resuming from epoch {checkpoint['epoch']}, best val acc: {best_val_acc:.4f}\n")
        return checkpoint['epoch'] + 1, best_val_acc, checkpoint['history']

    def _train_one_epoch(
        self,
        model: torch.nn.Module,
        loader: DataLoader,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: str,
        batches_per_epoch: int,
        epoch: int,
    ) -> tuple[float, float]:
        """Run at most ``batches_per_epoch`` batches; return (mean loss, accuracy)."""
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        for batch_idx, (xb, yb) in enumerate(
            tqdm.tqdm(loader, f"Epoch {epoch+1} Train", leave=False)
        ):
            if batch_idx >= batches_per_epoch:
                break
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            # Clip gradients to stabilize training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item() * xb.size(0)
            _, pred = out.max(1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
        if total == 0:
            # Empty loader: report zeros instead of raising ZeroDivisionError.
            return 0.0, 0.0
        return running_loss / total, correct / total

    def _validate(
        self,
        model: torch.nn.Module,
        X_val: Sequence[np.ndarray],
        y_val: Sequence[int],
        device: str,
        epoch: int,
    ) -> float:
        """Compute per-clip validation accuracy via patch-averaged logits."""
        model.eval()
        val_total = len(y_val)
        if val_total == 0:
            return 0.0
        val_correct = sum(
            1
            for i in tqdm.tqdm(range(val_total), desc=f"Epoch {epoch+1} Val", leave=False)
            if self._predict_val(model, X_val[i], device) == y_val[i]
        )
        return val_correct / val_total

    def train_k_fold_cnn(
        self,
        model_class: type,
        X: Sequence[np.ndarray],
        y: Sequence[int],
        n_folds: int = 5,
    ) -> tuple[list[float], float]:
        """Run ``n_folds``-fold cross-validation with a fresh model per fold.

        Args:
            model_class: zero-argument model constructor, instantiated per fold.
            X: all spectrograms.
            y: all integer labels.
            n_folds: number of CV folds (default 5, matching prior behavior).

        Returns:
            ``(fold_accuracies, mean_accuracy)``.
        """
        X_arr = np.array(X)
        y_arr = np.array(y)
        n_samples = len(y_arr)
        indices = np.arange(n_samples)
        # NOTE(review): unseeded shuffle — fold splits differ across runs.
        np.random.shuffle(indices)
        # Distribute the remainder: the first (n_samples % n_folds) folds get
        # one extra sample so every sample lands in exactly one fold.
        fold_sizes = (n_samples // n_folds) * np.ones(n_folds, dtype=int)
        fold_sizes[:n_samples % n_folds] += 1
        current = 0
        fold_accuracies: list[float] = []
        for fold_num, fold_size in enumerate(fold_sizes, 1):
            start, stop = current, current + fold_size
            val_idx = indices[start:stop]
            train_idx = np.concatenate([indices[:start], indices[stop:]])
            current = stop
            X_train, y_train = X_arr[train_idx].tolist(), y_arr[train_idx]
            X_val, y_val = X_arr[val_idx].tolist(), y_arr[val_idx]
            print(f"\n{'='*80}\nFOLD {fold_num}/{n_folds} | Train: {len(X_train)}, Val: {len(X_val)}\n{'='*80}\n")
            model = model_class()
            best_acc = self.train_cnn(
                model=model,
                X_train=X_train, y_train=y_train,
                X_val=X_val, y_val=y_val,
                fold_num=fold_num,
            )
            fold_accuracies.append(best_acc)
            print(f"\nFold {fold_num} Best Accuracy: {best_acc:.4f}\n")
        mean_acc = float(np.mean(fold_accuracies))
        std_acc = float(np.std(fold_accuracies))
        print(f"\n{'='*80}\nFINAL {n_folds}-FOLD CV RESULTS\nFold Accuracies: {fold_accuracies}\nMean: {mean_acc:.4f} ± {std_acc:.4f}\n{'='*80}\n")
        # Filename keeps the historical "5fold_cv_results.json" for the default.
        results_path = os.path.join(self.config.checkpoint_dir, f"{n_folds}fold_cv_results.json")
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        with open(results_path, 'w') as f:
            json.dump({'fold_accuracies': fold_accuracies, 'mean_accuracy': mean_acc, 'std_accuracy': std_acc}, f, indent=2)
        return fold_accuracies, mean_acc

    def _predict_val(self, model: torch.nn.Module, spec: np.ndarray, device: str) -> int:
        """Predict a clip's label by averaging logits over all sliding patches.

        ``spec`` is assumed to be (time, freq) — patches slide along axis 0;
        TODO confirm against the dataset classes.
        """
        # Local import keeps module load free of this dependency.
        from src.config.config import DatasetConfig
        cfg = DatasetConfig()
        n_frames = spec.shape[0]
        if n_frames < cfg.cnn_input_length:
            # Zero-pad short clips along time so at least one patch exists.
            spec = np.pad(spec, ((0, cfg.cnn_input_length - n_frames), (0, 0)), mode="constant")
            n_frames = cfg.cnn_input_length
        patches = np.stack([
            spec[s:s + cfg.cnn_input_length]
            for s in range(0, n_frames - cfg.cnn_input_length + 1)
        ])[:, np.newaxis]  # add channel dim -> (n_patches, 1, T, F)
        patches_t = torch.tensor(patches, dtype=torch.float32).to(device)
        with torch.no_grad():
            out = model(patches_t).mean(dim=0)
        return out.argmax().item()