| """ |
| Train Electrical Outlets audio model. Spectrogram CNN, class weights, per-class recall, early stopping. |
| """ |
| from pathlib import Path |
| import sys |
| import argparse |
| from typing import Dict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| ROOT = Path(__file__).resolve().parent.parent |
| sys.path.insert(0, str(ROOT)) |
|
|
| from src.data.audio_dataset import ElectricalOutletsAudioDataset |
| from src.models.audio_model import ElectricalOutletsAudioModel |
|
|
|
|
def load_config(config_path: Path) -> dict:
    """Read a YAML training configuration file and return it as a dict."""
    import yaml

    raw_text = Path(config_path).read_text()
    return yaml.safe_load(raw_text)
|
|
|
|
def _wave_to_mel(
    waveform: torch.Tensor,
    n_mels: int,
    n_fft: int,
    hop: int,
    win: int,
    sample_rate: int = 16000,
) -> torch.Tensor:
    """Convert a raw waveform into a log-mel spectrogram.

    Args:
        waveform: Audio tensor (as produced by the dataset; presumably
            mono at ``sample_rate`` Hz — confirm against the dataset loader).
        n_mels: Number of mel filterbank bins.
        n_fft: FFT window size.
        hop: Hop length between STFT frames.
        win: STFT window length.
        sample_rate: Sample rate used by the mel filterbank. Previously
            hard-coded to 16000; kept as the default so existing callers
            are unchanged, but configurable to match ``data.sample_rate``.

    Returns:
        Natural-log mel spectrogram tensor.
    """
    import torchaudio

    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=n_fft, hop_length=hop, win_length=win, n_mels=n_mels,
    )
    mel = mel_transform(waveform)
    # Clamp before log so silent frames (mel == 0) do not produce -inf.
    return torch.log(mel.clamp(min=1e-5))
|
|
|
|
def per_class_recall(logits: torch.Tensor, targets: torch.Tensor, num_classes: int) -> Dict[int, float]:
    """Compute recall per class from raw logits.

    A class with no examples in ``targets`` is reported with recall 0.0.
    """
    predictions = logits.argmax(dim=1)
    scores: Dict[int, float] = {}
    for cls in range(num_classes):
        in_class = targets == cls
        if int(in_class.sum()) > 0:
            hit_rate = (predictions[in_class] == cls).float().mean().item()
        else:
            hit_rate = 0.0
        scores[cls] = hit_rate
    return scores
|
|
|
|
def run_training(
    data_root: Path,
    label_mapping_path: Path,
    config: dict,
    weights_dir: Path,
    device: str = "cuda",
) -> dict:
    """Train the audio model with early stopping and optional temperature scaling.

    Builds train/val datasets with a shared split seed, trains a CNN on
    log-mel spectrograms, checkpoints whenever validation macro recall
    improves, stops after ``early_stopping_patience`` epochs without
    improvement, then optionally fits a temperature on part of the val
    set and stores it in the checkpoint.

    Args:
        data_root: Root directory of the audio data.
        label_mapping_path: Path to the label-mapping file consumed by
            both the dataset and the model.
        config: Parsed YAML config; reads the ``data``, ``spectrogram``,
            ``model``, ``training``, ``calibration`` and ``output`` sections.
        weights_dir: Directory where the best checkpoint is written
            (created on first improvement).
        device: Torch device string for model and batches.

    Returns:
        Dict with ``best_epoch``, ``best_metric`` (best macro recall) and
        ``recall_per_class``.

    NOTE(review): the module docstring mentions class weights, but the
    criterion below is a plain ``CrossEntropyLoss()`` with no ``weight=``
    argument — confirm whether class weighting was intended.
    """
    # --- config extraction (with defaults for optional keys) ---
    train_ratio = config["data"]["train_ratio"]
    val_ratio = config["data"]["val_ratio"]
    seed = config["data"].get("seed", 42)
    batch_size = config["data"]["batch_size"]
    num_workers = config["data"].get("num_workers", 0)
    spec_cfg = config.get("spectrogram", {})
    n_mels = spec_cfg.get("n_mels", 64)
    n_fft = spec_cfg.get("n_fft", 512)
    hop = spec_cfg.get("hop_length", 256)
    win = spec_cfg.get("win_length", 512)


    # Waveform -> log-mel transform applied per sample by the dataset.
    def to_mel(x):
        return _wave_to_mel(x, n_mels, n_fft, hop, win)


    # Train/val datasets share ratios and seed so their splits are disjoint.
    train_ds = ElectricalOutletsAudioDataset(
        data_root, label_mapping_path, split="train",
        train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
        target_length_sec=config["data"].get("target_length_sec", 5.0),
        sample_rate=config["data"].get("sample_rate", 16000),
    )
    val_ds = ElectricalOutletsAudioDataset(
        data_root, label_mapping_path, split="val",
        train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
        target_length_sec=config["data"].get("target_length_sec", 5.0),
        sample_rate=config["data"].get("sample_rate", 16000),
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)


    num_classes = train_ds.num_classes
    model = ElectricalOutletsAudioModel(
        num_classes=num_classes,
        label_mapping_path=label_mapping_path,
        n_mels=config["model"].get("n_mels", 64),
        time_steps=config["model"].get("time_steps", 128),
    ).to(device)
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=config["training"]["lr"],
        weight_decay=config["training"].get("weight_decay", 1e-4),
    )
    criterion = nn.CrossEntropyLoss()
    epochs = config["training"]["epochs"]
    patience = config["training"].get("early_stopping_patience", 12)
    # Early-stopping state: best macro recall seen, its epoch, and epochs waited.
    best_metric = -1.0
    best_epoch = 0
    wait = 0
    recall = {}


    for epoch in range(epochs):
        # --- one training epoch ---
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            opt.step()


        # --- full validation pass (logits gathered on CPU) ---
        model.eval()
        val_logits, val_targets = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device)
                val_logits.append(model(x).cpu())
                val_targets.append(y)
        val_logits = torch.cat(val_logits, dim=0)
        val_targets = torch.cat(val_targets, dim=0)
        recall = per_class_recall(val_logits, val_targets, num_classes)
        # min_recall is only logged; model selection uses macro recall.
        min_recall = min(recall.values())
        macro_recall = sum(recall.values()) / num_classes
        metric = macro_recall
        if metric > best_metric:
            # New best: reset patience and checkpoint model + label mappings.
            best_metric = metric
            best_epoch = epoch
            wait = 0
            weights_dir.mkdir(parents=True, exist_ok=True)
            torch.save({
                "model_state_dict": model.state_dict(),
                "num_classes": num_classes,
                "idx_to_label": model.idx_to_label,
                "idx_to_issue_type": model.idx_to_issue_type,
                "idx_to_severity": model.idx_to_severity,
            }, weights_dir / config["output"]["best_name"])
        else:
            wait += 1
        print(f"Epoch {epoch} min_recall={min_recall:.4f} macro_recall={macro_recall:.4f} best={best_metric:.4f}")
        if wait >= patience:
            print("Early stopping at epoch", epoch)
            break


    # --- optional post-hoc temperature scaling on the best checkpoint ---
    if config.get("calibration", {}).get("use_temperature_scaling", False):
        model.load_state_dict(torch.load(weights_dir / config["output"]["best_name"], map_location=device)["model_state_dict"])
        model.eval()
        n_val = len(val_ds)
        # NOTE(review): calibration reuses the FIRST fraction of the same val
        # split used for model selection — confirm this overlap is intended.
        cal_size = max(1, int(n_val * config["calibration"].get("val_fraction_for_calibration", 0.5)))
        cal_logits, cal_targets = [], []
        for i in range(cal_size):
            x, y = val_ds[i]
            x = x.unsqueeze(0).to(device)
            with torch.no_grad():
                cal_logits.append(model(x).cpu())
            cal_targets.append(y)
        cal_logits = torch.cat(cal_logits, dim=0)
        cal_targets = torch.tensor(cal_targets)
        # Single scalar temperature fitted by LBFGS to minimize NLL of
        # temperature-scaled logits; logits are detached, only temp trains.
        temp = nn.Parameter(torch.ones(1) * 1.5)
        opt_cal = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
        def eval_cal():
            opt_cal.zero_grad()
            loss = F.cross_entropy(cal_logits / temp, cal_targets)
            loss.backward()
            return loss
        opt_cal.step(eval_cal)
        # Store the fitted temperature alongside the best weights.
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location="cpu")
        ckpt["temperature"] = temp.item()
        torch.save(ckpt, weights_dir / config["output"]["best_name"])


    # NOTE(review): ``recall`` here is from the LAST evaluated epoch, which may
    # differ from the best epoch's per-class recall — confirm this is intended.
    return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
|
|
|
|
def main():
    """CLI entry point: parse args, run training, and write a markdown report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/audio_train_config.yaml")
    parser.add_argument("--data_root", default=None)
    parser.add_argument("--weights_dir", default="weights")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    root = Path(__file__).resolve().parent.parent
    config = load_config(root / args.config)
    if args.data_root:
        data_root = Path(args.data_root)
    else:
        data_root = root / config["data"]["root"]
    label_mapping_path = root / config["data"]["label_mapping"]
    weights_dir = root / args.weights_dir
    results = run_training(data_root, label_mapping_path, config, weights_dir, args.device)

    # Assemble the report in memory, then write it in one pass.
    report_path = root / "docs" / config["output"]["report_name"]
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_lines = [
        "# Audio Model Report (Electrical Outlets)\n\n",
        "- **Preliminary model.** 100 samples is very small; recommend collecting more data.\n",
        f"- Best epoch: {results['best_epoch']}, best metric: {results['best_metric']:.4f}\n\n",
        "## Per-class recall (validation)\n\n",
    ]
    for cls, rec in results.get("recall_per_class", {}).items():
        report_lines.append(f"- Class {cls}: {rec:.4f}\n")
    report_lines.append("\n## Limitations\n- Small dataset; use audio as support in fusion. Do not rely on audio-only for critical decisions.\n")
    with open(report_path, "w") as f:
        f.writelines(report_lines)
    print("Report written to", report_path)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|