| """ |
| Train Electrical Outlets image model. |
| FINAL v5: Frozen backbone → partial unfreeze. 5 classes, 1300 images. |
| """ |
| from pathlib import Path |
| import sys |
| import argparse |
| from typing import Dict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
|
|
# Make the repository root importable (for `src.*` packages) when this file
# is run directly as a script rather than via `python -m`.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
|
|
| from src.data.image_dataset import ElectricalOutletsImageDataset, get_image_class_weights |
| from src.models.image_model import ElectricalOutletsImageModel |
|
|
|
|
def load_config(path):
    """Parse a YAML configuration file.

    Args:
        path: Filesystem path of the YAML file.

    Returns:
        The parsed configuration (typically a dict).
    """
    # Local import: yaml is only needed when training is actually invoked.
    import yaml
    with open(path) as handle:
        raw_text = handle.read()
    return yaml.safe_load(raw_text)
|
|
|
|
def focal_loss(logits, targets, alpha=0.25, gamma=2.0, weight=None):
    """Multi-class focal loss (Lin et al., 2017) with optional class weights.

    Fix vs. the previous implementation: when per-class ``weight`` was passed
    to ``F.cross_entropy``, the returned per-sample losses were already
    rescaled, so ``exp(-ce)`` no longer equaled the true target-class
    probability p_t and the focusing term ``(1 - p_t) ** gamma`` was wrong.
    We now derive p_t from the *unweighted* cross-entropy and apply the class
    weights as a separate per-sample factor.  Normalization is unchanged:
    the mean over N samples, matching the original ``reduction="none"`` +
    ``.mean()`` behavior.

    Args:
        logits: ``(N, C)`` unnormalized class scores.
        targets: ``(N,)`` integer class indices.
        alpha: Global scaling factor.
        gamma: Focusing exponent; ``gamma=0`` recovers alpha-scaled CE.
        weight: Optional ``(C,)`` per-class rescaling weights.

    Returns:
        Scalar mean focal loss.
    """
    ce = F.cross_entropy(logits, targets, reduction="none")  # unweighted NLL
    pt = torch.exp(-ce)  # probability assigned to the true class
    focal = alpha * (1 - pt) ** gamma * ce
    if weight is not None:
        focal = focal * weight[targets]  # per-sample class reweighting
    return focal.mean()
|
|
|
|
def per_class_recall(logits, targets, num_classes):
    """Compute recall for each class index from raw logits.

    Args:
        logits: ``(N, C)`` class scores.
        targets: ``(N,)`` integer ground-truth labels.
        num_classes: Number of classes to report.

    Returns:
        Dict mapping class index -> recall; classes absent from ``targets``
        get 0.0.
    """
    predictions = logits.argmax(dim=1)
    result = {}
    for cls in range(num_classes):
        is_cls = targets == cls
        if is_cls.sum() > 0:
            hits = (predictions[is_cls] == cls).float()
            result[cls] = hits.mean().item()
        else:
            result[cls] = 0.0
    return result
|
|
|
|
def run_training(data_root, label_mapping_path, config, weights_dir, device="cuda"):
    """Two-stage training loop for the Electrical Outlets image classifier.

    Stage 1 trains only the classification head on a frozen backbone.
    Stage 2 (optional, gated on Stage-1 quality) partially unfreezes the
    last backbone blocks and fine-tunes at a lower learning rate.  Both
    stages checkpoint the best model to ``weights_dir``.  Optionally fits a
    single temperature-scaling parameter on half the validation set and
    stores it in the checkpoint.

    Args:
        data_root: Directory containing the image data.
        label_mapping_path: Path to the label-mapping file.
        config: Parsed config dict with "data", "training", "augmentation",
            "model", "output" and optional "calibration" sections.
        weights_dir: Directory where the best checkpoint is written.
        device: Torch device string ("cuda" or "cpu").

    Returns:
        Dict with "best_epoch", "best_metric", and "recall_per_class"
        (per-class recall from the last evaluated epoch).
    """
    cfg_data = config["data"]
    cfg_train = config["training"]
    cfg_aug = config["augmentation"]
    cfg_model = config["model"]

    # Training augmentation: geometric + photometric jitter, then normalize
    # with ImageNet statistics (matches the ImageNet-pretrained backbone).
    train_tf = transforms.Compose([
        transforms.Resize(cfg_aug["resize"]),
        transforms.RandomResizedCrop(cfg_aug["crop"], scale=(0.65, 1.0)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.15),  # tensor-space op, must come after ToTensor
    ])
    # Validation: deterministic resize + center crop only.
    val_tf = transforms.Compose([
        transforms.Resize(cfg_aug["resize"]),
        transforms.CenterCrop(cfg_aug["crop"]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Identical split ratios and seed for both datasets so the train/val
    # partition is consistent (splitting is done inside the dataset class).
    train_ds = ElectricalOutletsImageDataset(
        data_root, label_mapping_path, split="train",
        train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
        seed=cfg_data.get("seed", 42), transform=train_tf,
    )
    val_ds = ElectricalOutletsImageDataset(
        data_root, label_mapping_path, split="val",
        train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
        seed=cfg_data.get("seed", 42), transform=val_tf,
    )
    train_loader = DataLoader(train_ds, batch_size=cfg_data["batch_size"], shuffle=True,
                              num_workers=cfg_data.get("num_workers", 4), pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=cfg_data["batch_size"], shuffle=False,
                            num_workers=cfg_data.get("num_workers", 4))

    num_classes = train_ds.num_classes
    print(f"\nTrain: {len(train_ds)}, Val: {len(val_ds)}, Classes: {num_classes}")

    # Optional per-class weights to counter class imbalance.
    class_weights = None
    if cfg_train.get("use_class_weights", True):
        class_weights = get_image_class_weights(label_mapping_path, data_root).to(device)
        print(f"Class weights: {[f'{w:.3f}' for w in class_weights.tolist()]}")

    use_focal = cfg_train.get("use_focal", True)
    # CE fallback when focal loss is disabled; label smoothing regularizes.
    criterion_ce = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

    model = ElectricalOutletsImageModel(
        num_classes=num_classes,
        label_mapping_path=label_mapping_path,
        pretrained=True,
        head_hidden=cfg_model.get("head_hidden", 256),
        head_dropout=cfg_model.get("head_dropout", 0.4),
    ).to(device)

    # Stage 1: freeze the entire backbone; only the head stays trainable.
    for p in model.backbone.parameters():
        p.requires_grad = False

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {trainable:,} trainable / {total_params:,} total ({100*trainable/total_params:.1f}%)")

    epochs = cfg_train["epochs"]
    patience = cfg_train.get("early_stopping_patience", 20)
    lr = cfg_train.get("lr", 3e-3)

    # Optimizer sees only the trainable (head) parameters.
    opt = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, weight_decay=cfg_train.get("weight_decay", 1e-3),
    )
    # OneCycleLR is stepped once per *batch*, hence steps_per_epoch below.
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=lr, epochs=epochs,
        steps_per_epoch=len(train_loader), pct_start=0.15,
    )

    print(f"\n{'='*60}")
    print(f" Stage 1: Frozen backbone, lr={lr}, {epochs} epochs max")
    print(f"{'='*60}")

    best_metric = -1.0
    best_epoch = 0
    wait = 0  # epochs since last improvement (early-stopping counter)
    recall = {}

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
            loss.backward()
            opt.step()
            sched.step()  # per-batch step required by OneCycleLR
            epoch_loss += loss.item()

        # Validation pass: accumulate logits on CPU, then compute metrics.
        model.eval()
        vl, vt = [], []
        with torch.no_grad():
            for x, y in val_loader:
                vl.append(model(x.to(device)).cpu())
                vt.append(y)
        vl, vt = torch.cat(vl), torch.cat(vt)
        recall = per_class_recall(vl, vt, num_classes)
        min_r = min(recall.values())
        macro_r = sum(recall.values()) / num_classes
        val_acc = (vl.argmax(1) == vt).float().mean().item()
        # Model-selection metric: worst-class recall or macro recall.
        metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r

        star = ""
        if metric > best_metric:
            best_metric = metric
            best_epoch = epoch
            wait = 0
            weights_dir.mkdir(parents=True, exist_ok=True)
            # Checkpoint carries the label mappings needed at inference time.
            torch.save({
                "model_state_dict": model.state_dict(),
                "num_classes": num_classes,
                "idx_to_issue_type": model.idx_to_issue_type,
                "idx_to_severity": model.idx_to_severity,
            }, weights_dir / config["output"]["best_name"])
            star = " ★"
        else:
            wait += 1

        print(f"E{epoch:3d} loss={epoch_loss/len(train_loader):.4f} acc={val_acc:.3f} "
              f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}@{best_epoch}{star}")

        if wait >= patience:
            print(f"Early stop @ {epoch}")
            break

    # Stage 2: fine-tune the last backbone blocks — but only if Stage 1
    # reached a minimally useful metric (guards against wasted compute).
    if cfg_train.get("finetune_last_blocks", True) and best_metric > 0.15:
        print(f"\n{'='*60}")
        print(f" Stage 2: Partial unfreeze (last 2 blocks)")
        print(f"{'='*60}")

        # Restart from the best Stage-1 checkpoint, not the last epoch.
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])

        # Re-freeze everything, then unfreeze only the last two feature
        # blocks (names assume a torchvision-style "features.N" layout —
        # TODO confirm against ElectricalOutletsImageModel's backbone).
        for p in model.backbone.parameters():
            p.requires_grad = False
        for name, p in model.backbone.named_parameters():
            if "features.7" in name or "features.8" in name:
                p.requires_grad = True
        for p in model.head.parameters():
            p.requires_grad = True

        ft_lr = cfg_train.get("finetune_lr", 5e-5)
        ft_epochs = cfg_train.get("finetune_epochs", 25)
        opt2 = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=ft_lr, weight_decay=1e-3,
        )
        sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(opt2, T_max=ft_epochs, eta_min=1e-6)
        wait2 = 0  # Stage-2 early-stopping counter (patience hard-coded at 10 below)

        for epoch in range(ft_epochs):
            model.train()
            el = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                opt2.zero_grad()
                logits = model(x)
                loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # stabilize fine-tuning
                opt2.step()
                el += loss.item()
            sched2.step()  # cosine schedule steps once per *epoch* (unlike OneCycle)

            model.eval()
            vl, vt = [], []
            with torch.no_grad():
                for x, y in val_loader:
                    vl.append(model(x.to(device)).cpu())
                    vt.append(y)
            vl, vt = torch.cat(vl), torch.cat(vt)
            recall = per_class_recall(vl, vt, num_classes)
            min_r = min(recall.values())
            macro_r = sum(recall.values()) / num_classes
            val_acc = (vl.argmax(1) == vt).float().mean().item()
            metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r

            star = ""
            # Stage 2 competes against Stage 1's best: only a genuine
            # improvement overwrites the checkpoint.
            if metric > best_metric:
                best_metric = metric
                best_epoch = epoch + 1000  # +1000 tags Stage-2 epochs in logs/results
                wait2 = 0
                torch.save({
                    "model_state_dict": model.state_dict(),
                    "num_classes": num_classes,
                    "idx_to_issue_type": model.idx_to_issue_type,
                    "idx_to_severity": model.idx_to_severity,
                }, weights_dir / config["output"]["best_name"])
                star = " ★"
            else:
                wait2 += 1

            print(f" FT{epoch:3d} loss={el/len(train_loader):.4f} acc={val_acc:.3f} "
                  f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}{star}")
            if wait2 >= 10:
                print(f" FT early stop @ {epoch}")
                break

    # Optional post-hoc calibration: fit a single temperature scalar on the
    # first half of the validation set and store it in the checkpoint.
    if config.get("calibration", {}).get("use_temperature_scaling", False):
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        model.eval()
        cal_size = max(1, int(len(val_ds) * 0.5))
        cl, ct = [], []
        for i in range(cal_size):
            x, y = val_ds[i]
            with torch.no_grad():
                cl.append(model(x.unsqueeze(0).to(device)).cpu())
                ct.append(y)
        cl, ct = torch.cat(cl), torch.tensor(ct)
        # Only `temp` has a gradient; the cached logits `cl` are constants.
        temp = nn.Parameter(torch.ones(1) * 1.5)
        opt_c = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
        def eval_c():
            # LBFGS closure: NLL of temperature-scaled logits.
            opt_c.zero_grad()
            l = F.cross_entropy(cl / temp, ct)
            l.backward()
            return l
        opt_c.step(eval_c)
        ckpt["temperature"] = temp.item()
        torch.save(ckpt, weights_dir / config["output"]["best_name"])
        print(f"Temperature T={temp.item():.4f}")

    print(f"\n{'='*60}")
    print(f" DONE — Best: {best_metric:.4f}")
    per_cls = " | ".join([f"C{c}={r:.2f}" for c, r in recall.items()])
    print(f" Recall: {per_cls}")
    print(f"{'='*60}\n")

    return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
|
|
|
|
def main():
    """CLI entry point: parse arguments, run training, write the report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/image_train_config.yaml")
    parser.add_argument("--data_root", default=None)
    parser.add_argument("--weights_dir", default="weights")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    config = load_config(ROOT / args.config)
    # CLI --data_root overrides the config's data root.
    if args.data_root:
        data_root = Path(args.data_root)
    else:
        data_root = ROOT / config["data"]["root"]
    label_mapping_path = ROOT / config["data"]["label_mapping"]
    weights_dir = ROOT / args.weights_dir

    results = run_training(data_root, label_mapping_path, config, weights_dir, args.device)

    # Assemble the markdown report, then write it in one pass.
    issue_names = ["burn_overheating", "cracked_faceplate", "loose_outlet", "normal", "water_exposed"]
    report_lines = [
        "# Image Model Report (Electrical Outlets)\n\n",
        f"- Best metric: {results['best_metric']:.4f}\n",
        f"- Classes: 5 (burn, cracked, loose, normal, water)\n\n",
        "## Per-class recall\n\n",
    ]
    for c, r in results.get("recall_per_class", {}).items():
        name = issue_names[c] if c < len(issue_names) else f"class_{c}"
        report_lines.append(f"- {name}: {r:.4f}\n")

    report_path = ROOT / "docs" / config["output"]["report_name"]
    report_path.parent.mkdir(parents=True, exist_ok=True)
    with open(report_path, "w") as f:
        f.writelines(report_lines)
    print("Report:", report_path)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|