# Electrical Outlets diagnostic pipeline v1.0
# Author: Asadrizvi64 — commit 5666923
"""
Train Electrical Outlets audio model. Spectrogram CNN, class weights, per-class recall, early stopping.
"""
from pathlib import Path
import sys
import argparse
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from src.data.audio_dataset import ElectricalOutletsAudioDataset
from src.models.audio_model import ElectricalOutletsAudioModel
def load_config(config_path: Path) -> dict:
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The parsed configuration (a dict for the configs used here).
    """
    import yaml  # local import: yaml only needed when training
    # Explicit encoding avoids platform-dependent default codecs.
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def _wave_to_mel(waveform: torch.Tensor, n_mels: int, n_fft: int, hop: int, win: int,
                 sample_rate: int = 16000) -> torch.Tensor:
    """Convert a waveform tensor to a log-mel spectrogram.

    Args:
        waveform: Input audio tensor.
        n_mels: Number of mel filterbank channels.
        n_fft: FFT size.
        hop: Hop length between STFT frames.
        win: STFT window length.
        sample_rate: Sample rate of ``waveform``. Defaults to 16000, which
            was previously hard-coded; callers using a different
            ``data.sample_rate`` should pass it explicitly so the mel
            filterbank matches the audio.

    Returns:
        Log-scaled mel spectrogram tensor.
    """
    import torchaudio
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=n_fft, hop_length=hop, win_length=win, n_mels=n_mels,
    )(waveform)
    # Clamp before log so silent frames don't produce -inf.
    log_mel = torch.log(mel.clamp(min=1e-5))
    return log_mel
def per_class_recall(logits: torch.Tensor, targets: torch.Tensor, num_classes: int) -> Dict[int, float]:
    """Compute validation recall per class from raw logits.

    Classes with no examples in ``targets`` get a recall of 0.0.

    Args:
        logits: (N, num_classes) unnormalized scores.
        targets: (N,) integer class labels.
        num_classes: Total number of classes to report.

    Returns:
        Mapping from class index to recall in [0, 1].
    """
    predicted = logits.argmax(dim=1)
    scores: Dict[int, float] = {}
    for label in range(num_classes):
        selected = targets == label
        if not selected.any():
            scores[label] = 0.0
        else:
            scores[label] = (predicted[selected] == label).float().mean().item()
    return scores
def run_training(
    data_root: Path,
    label_mapping_path: Path,
    config: dict,
    weights_dir: Path,
    device: str = "cuda",
) -> dict:
    """Train the Electrical Outlets audio classifier with early stopping.

    Fits a spectrogram CNN on the train split, evaluates per-class recall on
    the validation split after every epoch, keeps the checkpoint with the
    best macro recall, and optionally fits a temperature scalar afterwards
    for confidence calibration.

    Args:
        data_root: Root directory containing the audio dataset.
        label_mapping_path: Path to the label mapping consumed by both the
            dataset and the model.
        config: Parsed YAML config with "data", "model", "training" and
            "output" sections plus optional "spectrogram" / "calibration".
        weights_dir: Directory where the best checkpoint is written.
        device: Torch device string, e.g. "cuda" or "cpu".

    Returns:
        Dict with "best_epoch", "best_metric" (best validation macro recall)
        and "recall_per_class" from the last evaluated epoch.
    """
    # --- data / spectrogram hyperparameters from config ---
    train_ratio = config["data"]["train_ratio"]
    val_ratio = config["data"]["val_ratio"]
    seed = config["data"].get("seed", 42)
    batch_size = config["data"]["batch_size"]
    num_workers = config["data"].get("num_workers", 0)
    spec_cfg = config.get("spectrogram", {})
    n_mels = spec_cfg.get("n_mels", 64)
    n_fft = spec_cfg.get("n_fft", 512)
    hop = spec_cfg.get("hop_length", 256)
    win = spec_cfg.get("win_length", 512)
    # NOTE(review): _wave_to_mel hard-codes a 16 kHz sample rate while
    # data.sample_rate below is configurable — confirm the two stay in sync.
    def to_mel(x):
        # Waveform -> log-mel spectrogram; applied as the dataset transform.
        return _wave_to_mel(x, n_mels, n_fft, hop, win)
    train_ds = ElectricalOutletsAudioDataset(
        data_root, label_mapping_path, split="train",
        train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
        target_length_sec=config["data"].get("target_length_sec", 5.0),
        sample_rate=config["data"].get("sample_rate", 16000),
    )
    val_ds = ElectricalOutletsAudioDataset(
        data_root, label_mapping_path, split="val",
        train_ratio=train_ratio, val_ratio=val_ratio, seed=seed, transform=to_mel,
        target_length_sec=config["data"].get("target_length_sec", 5.0),
        sample_rate=config["data"].get("sample_rate", 16000),
    )
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    num_classes = train_ds.num_classes
    model = ElectricalOutletsAudioModel(
        num_classes=num_classes,
        label_mapping_path=label_mapping_path,
        n_mels=config["model"].get("n_mels", 64),
        time_steps=config["model"].get("time_steps", 128),
    ).to(device)
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=config["training"]["lr"],
        weight_decay=config["training"].get("weight_decay", 1e-4),
    )
    # NOTE(review): the module docstring mentions class weights, but this
    # loss is unweighted — confirm whether weighting was dropped on purpose.
    criterion = nn.CrossEntropyLoss()
    epochs = config["training"]["epochs"]
    patience = config["training"].get("early_stopping_patience", 12)
    best_metric = -1.0  # any real macro recall (>= 0) beats this on epoch 0
    best_epoch = 0
    wait = 0  # epochs since last improvement (early-stopping counter)
    recall = {}
    for epoch in range(epochs):
        # --- one training epoch ---
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            opt.step()
        # --- validation: collect logits over the whole val split ---
        model.eval()
        val_logits, val_targets = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device)
                val_logits.append(model(x).cpu())
                val_targets.append(y)
        val_logits = torch.cat(val_logits, dim=0)
        val_targets = torch.cat(val_targets, dim=0)
        recall = per_class_recall(val_logits, val_targets, num_classes)
        min_recall = min(recall.values())  # worst class; logged only, not used for selection
        macro_recall = sum(recall.values()) / num_classes
        metric = macro_recall  # early stopping / checkpointing track macro recall
        if metric > best_metric:
            best_metric = metric
            best_epoch = epoch
            wait = 0
            # Persist the best checkpoint together with the label mappings so
            # inference can decode predictions without the original dataset.
            weights_dir.mkdir(parents=True, exist_ok=True)
            torch.save({
                "model_state_dict": model.state_dict(),
                "num_classes": num_classes,
                "idx_to_label": model.idx_to_label,
                "idx_to_issue_type": model.idx_to_issue_type,
                "idx_to_severity": model.idx_to_severity,
            }, weights_dir / config["output"]["best_name"])
        else:
            wait += 1
        print(f"Epoch {epoch} min_recall={min_recall:.4f} macro_recall={macro_recall:.4f} best={best_metric:.4f}")
        if wait >= patience:
            print("Early stopping at epoch", epoch)
            break
    if config.get("calibration", {}).get("use_temperature_scaling", False):
        # --- temperature scaling on part of the validation split ---
        # Reload the best checkpoint so calibration matches the saved weights.
        model.load_state_dict(torch.load(weights_dir / config["output"]["best_name"], map_location=device)["model_state_dict"])
        model.eval()
        n_val = len(val_ds)
        cal_size = max(1, int(n_val * config["calibration"].get("val_fraction_for_calibration", 0.5)))
        # NOTE(review): calibration reuses the first cal_size val samples in
        # dataset order — the same data used for model selection; confirm
        # this overlap is acceptable for the intended deployment.
        cal_logits, cal_targets = [], []
        for i in range(cal_size):
            x, y = val_ds[i]
            x = x.unsqueeze(0).to(device)
            with torch.no_grad():
                cal_logits.append(model(x).cpu())
            cal_targets.append(y)
        cal_logits = torch.cat(cal_logits, dim=0)
        cal_targets = torch.tensor(cal_targets)
        # Fit a single temperature by minimizing NLL of the scaled logits;
        # LBFGS requires a closure that re-evaluates loss and gradients.
        temp = nn.Parameter(torch.ones(1) * 1.5)
        opt_cal = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
        def eval_cal():
            opt_cal.zero_grad()
            loss = F.cross_entropy(cal_logits / temp, cal_targets)
            loss.backward()
            return loss
        opt_cal.step(eval_cal)
        # Store the fitted temperature alongside the checkpoint weights.
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location="cpu")
        ckpt["temperature"] = temp.item()
        torch.save(ckpt, weights_dir / config["output"]["best_name"])
    return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
def main():
    """CLI entry point: parse arguments, train, and write a markdown report."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", default="config/audio_train_config.yaml")
    ap.add_argument("--data_root", default=None)
    ap.add_argument("--weights_dir", default="weights")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    opts = ap.parse_args()

    # Resolve everything relative to the project root (parent of this script's dir).
    project_root = Path(__file__).resolve().parent.parent
    cfg = load_config(project_root / opts.config)
    dataset_root = Path(opts.data_root) if opts.data_root else project_root / cfg["data"]["root"]
    mapping_path = project_root / cfg["data"]["label_mapping"]
    out_dir = project_root / opts.weights_dir

    summary = run_training(dataset_root, mapping_path, cfg, out_dir, opts.device)

    # Assemble the markdown report, then write it in one shot.
    report_file = project_root / "docs" / cfg["output"]["report_name"]
    report_file.parent.mkdir(parents=True, exist_ok=True)
    report_lines = [
        "# Audio Model Report (Electrical Outlets)\n\n",
        "- **Preliminary model.** 100 samples is very small; recommend collecting more data.\n",
        f"- Best epoch: {summary['best_epoch']}, best metric: {summary['best_metric']:.4f}\n\n",
        "## Per-class recall (validation)\n\n",
    ]
    report_lines.extend(
        f"- Class {c}: {r:.4f}\n" for c, r in summary.get("recall_per_class", {}).items()
    )
    report_lines.append(
        "\n## Limitations\n- Small dataset; use audio as support in fusion. Do not rely on audio-only for critical decisions.\n"
    )
    with open(report_file, "w") as f:
        f.writelines(report_lines)
    print("Report written to", report_file)
if __name__ == "__main__":
    main()