# Asadrizvi64 — Electrical Outlets diagnostic pipeline v1.0 (commit 5666923)
"""
Train Electrical Outlets image model.
FINAL v5: Frozen backbone → partial unfreeze. 5 classes, 1300 images.
"""
from pathlib import Path
import sys
import argparse
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from src.data.image_dataset import ElectricalOutletsImageDataset, get_image_class_weights
from src.models.image_model import ElectricalOutletsImageModel
def load_config(path):
    """Load and parse a YAML config file.

    Args:
        path: Path to the YAML file.

    Returns:
        The parsed document (typically a dict of config sections).
    """
    import yaml
    # Explicit encoding: the platform default is not UTF-8 everywhere
    # (e.g. Windows), which would break configs with non-ASCII text.
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
def focal_loss(logits, targets, alpha=0.25, gamma=2.0, weight=None):
    """Focal loss (Lin et al., 2017) for class-imbalanced classification.

    Args:
        logits: (N, C) raw class scores.
        targets: (N,) integer class labels.
        alpha: Global scaling factor applied to every sample.
        gamma: Focusing exponent; larger values down-weight easy examples.
        weight: Optional (C,) per-class weights applied to the CE term.

    Returns:
        Scalar tensor: mean focal loss over the batch.
    """
    # Per-sample CE, optionally class-weighted — this is the term we scale.
    ce = F.cross_entropy(logits, targets, reduction="none", weight=weight)
    # BUGFIX: pt must be the model's true predicted probability of the target
    # class, so it is derived from the *unweighted* CE. Previously pt was
    # exp(-weighted_ce), which let class weights distort the focusing factor
    # (1 - pt) ** gamma in addition to the CE magnitude.
    ce_raw = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce_raw)
    return (alpha * (1 - pt) ** gamma * ce).mean()
def per_class_recall(logits, targets, num_classes):
    """Compute recall per class from raw logits.

    Args:
        logits: (N, C) raw class scores.
        targets: (N,) integer class labels.
        num_classes: Total number of classes C.

    Returns:
        Dict mapping class index -> recall; classes with no ground-truth
        samples get 0.0.
    """
    predictions = logits.argmax(dim=1)
    result = {}
    for cls in range(num_classes):
        is_cls = targets == cls
        if is_cls.sum() > 0:
            result[cls] = (predictions[is_cls] == cls).float().mean().item()
        else:
            result[cls] = 0.0
    return result
def run_training(data_root, label_mapping_path, config, weights_dir, device="cuda"):
    """Two-stage training loop for the Electrical Outlets image classifier.

    Stage 1 trains the classification head with a frozen backbone. Stage 2
    (optional) reloads the best stage-1 checkpoint, unfreezes the last two
    backbone blocks ("features.7"/"features.8"), and fine-tunes at a lower
    learning rate. Optionally fits a temperature-scaling parameter on the
    first half of the validation split and stores it in the checkpoint.
    The best checkpoint (by min- or macro-recall, per config) is written to
    weights_dir / config["output"]["best_name"].

    Args:
        data_root: Root directory of the image dataset.
        label_mapping_path: Path to the label-mapping file used by both the
            dataset and the model.
        config: Parsed config dict with "data", "training", "augmentation",
            "model", "output" and optional "calibration" sections.
        weights_dir: Directory where the best checkpoint is saved.
        device: Torch device string (default "cuda").

    Returns:
        Dict with "best_epoch" (stage-2 epochs offset by +1000),
        "best_metric", and "recall_per_class" from the last validation pass.
    """
    cfg_data = config["data"]
    cfg_train = config["training"]
    cfg_aug = config["augmentation"]
    cfg_model = config["model"]
    # Transforms
    # Heavy train-time augmentation; Normalize uses ImageNet statistics to
    # match the pretrained backbone.
    train_tf = transforms.Compose([
        transforms.Resize(cfg_aug["resize"]),
        transforms.RandomResizedCrop(cfg_aug["crop"], scale=(0.65, 1.0)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.15),  # operates on tensors, so placed after ToTensor
    ])
    # Validation: deterministic resize + center crop only.
    val_tf = transforms.Compose([
        transforms.Resize(cfg_aug["resize"]),
        transforms.CenterCrop(cfg_aug["crop"]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Datasets
    # Same seed/ratios for both splits so the train/val partition is consistent.
    train_ds = ElectricalOutletsImageDataset(
        data_root, label_mapping_path, split="train",
        train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
        seed=cfg_data.get("seed", 42), transform=train_tf,
    )
    val_ds = ElectricalOutletsImageDataset(
        data_root, label_mapping_path, split="val",
        train_ratio=cfg_data["train_ratio"], val_ratio=cfg_data["val_ratio"],
        seed=cfg_data.get("seed", 42), transform=val_tf,
    )
    train_loader = DataLoader(train_ds, batch_size=cfg_data["batch_size"], shuffle=True,
                              num_workers=cfg_data.get("num_workers", 4), pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=cfg_data["batch_size"], shuffle=False,
                            num_workers=cfg_data.get("num_workers", 4))
    num_classes = train_ds.num_classes
    print(f"\nTrain: {len(train_ds)}, Val: {len(val_ds)}, Classes: {num_classes}")
    # Class weights
    class_weights = None
    if cfg_train.get("use_class_weights", True):
        class_weights = get_image_class_weights(label_mapping_path, data_root).to(device)
        print(f"Class weights: {[f'{w:.3f}' for w in class_weights.tolist()]}")
    # Either focal loss or label-smoothed CE is used per batch, never both.
    use_focal = cfg_train.get("use_focal", True)
    criterion_ce = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    # Model
    model = ElectricalOutletsImageModel(
        num_classes=num_classes,
        label_mapping_path=label_mapping_path,
        pretrained=True,
        head_hidden=cfg_model.get("head_hidden", 256),
        head_dropout=cfg_model.get("head_dropout", 0.4),
    ).to(device)
    # ══════════════════════════════════════════════
    # STAGE 1: Frozen backbone — train head only
    # ══════════════════════════════════════════════
    for p in model.backbone.parameters():
        p.requires_grad = False
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Params: {trainable:,} trainable / {total_params:,} total ({100*trainable/total_params:.1f}%)")
    epochs = cfg_train["epochs"]
    patience = cfg_train.get("early_stopping_patience", 20)
    lr = cfg_train.get("lr", 3e-3)
    # Only head parameters are trainable here (backbone frozen above).
    opt = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, weight_decay=cfg_train.get("weight_decay", 1e-3),
    )
    # OneCycleLR is stepped once per batch, hence steps_per_epoch.
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=lr, epochs=epochs,
        steps_per_epoch=len(train_loader), pct_start=0.15,
    )
    print(f"\n{'='*60}")
    print(f" Stage 1: Frozen backbone, lr={lr}, {epochs} epochs max")
    print(f"{'='*60}")
    best_metric = -1.0
    best_epoch = 0
    wait = 0  # epochs since the last improvement (early-stopping counter)
    recall = {}
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
            loss.backward()
            opt.step()
            sched.step()  # per-batch step (OneCycleLR)
            epoch_loss += loss.item()
        # Validate
        model.eval()
        vl, vt = [], []  # accumulated validation logits / targets
        with torch.no_grad():
            for x, y in val_loader:
                vl.append(model(x.to(device)).cpu())
                vt.append(y)
        vl, vt = torch.cat(vl), torch.cat(vt)
        recall = per_class_recall(vl, vt, num_classes)
        min_r = min(recall.values())
        macro_r = sum(recall.values()) / num_classes
        val_acc = (vl.argmax(1) == vt).float().mean().item()
        # Early-stopping metric: worst-class recall or macro recall, per config.
        metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r
        star = ""
        if metric > best_metric:
            best_metric = metric
            best_epoch = epoch
            wait = 0
            weights_dir.mkdir(parents=True, exist_ok=True)
            # Checkpoint carries the label mappings so inference can decode
            # predictions without the original mapping file.
            torch.save({
                "model_state_dict": model.state_dict(),
                "num_classes": num_classes,
                "idx_to_issue_type": model.idx_to_issue_type,
                "idx_to_severity": model.idx_to_severity,
            }, weights_dir / config["output"]["best_name"])
            star = " ★"
        else:
            wait += 1
        print(f"E{epoch:3d} loss={epoch_loss/len(train_loader):.4f} acc={val_acc:.3f} "
              f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}@{best_epoch}{star}")
        if wait >= patience:
            print(f"Early stop @ {epoch}")
            break
    # ══════════════════════════════════════════════
    # STAGE 2: Unfreeze last 2 backbone blocks
    # ══════════════════════════════════════════════
    # Only fine-tune if stage 1 produced something better than near-chance
    # (0.15 threshold) — otherwise unfreezing would fine-tune noise.
    if cfg_train.get("finetune_last_blocks", True) and best_metric > 0.15:
        print(f"\n{'='*60}")
        print(f" Stage 2: Partial unfreeze (last 2 blocks)")
        print(f"{'='*60}")
        # Restart from the best stage-1 weights, not the last epoch's.
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        for p in model.backbone.parameters():
            p.requires_grad = False
        # NOTE(review): "features.7"/"features.8" assumes the backbone's
        # module naming (e.g. a torchvision features trunk) — confirm against
        # ElectricalOutletsImageModel.
        for name, p in model.backbone.named_parameters():
            if "features.7" in name or "features.8" in name:
                p.requires_grad = True
        # Head stays trainable
        for p in model.head.parameters():
            p.requires_grad = True
        ft_lr = cfg_train.get("finetune_lr", 5e-5)
        ft_epochs = cfg_train.get("finetune_epochs", 25)
        opt2 = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=ft_lr, weight_decay=1e-3,
        )
        sched2 = torch.optim.lr_scheduler.CosineAnnealingLR(opt2, T_max=ft_epochs, eta_min=1e-6)
        wait2 = 0  # stage-2 early-stopping counter (fixed patience of 10 below)
        for epoch in range(ft_epochs):
            model.train()
            el = 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                opt2.zero_grad()
                logits = model(x)
                loss = focal_loss(logits, y, weight=class_weights) if use_focal else criterion_ce(logits, y)
                loss.backward()
                # Clip gradients — the newly unfrozen backbone layers can
                # produce large updates early in fine-tuning.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt2.step()
                el += loss.item()
            sched2.step()  # per-epoch step (CosineAnnealingLR)
            model.eval()
            vl, vt = [], []
            with torch.no_grad():
                for x, y in val_loader:
                    vl.append(model(x.to(device)).cpu())
                    vt.append(y)
            vl, vt = torch.cat(vl), torch.cat(vt)
            recall = per_class_recall(vl, vt, num_classes)
            min_r = min(recall.values())
            macro_r = sum(recall.values()) / num_classes
            val_acc = (vl.argmax(1) == vt).float().mean().item()
            metric = min_r if cfg_train.get("early_stopping_metric") == "val_min_recall" else macro_r
            star = ""
            # Stage 2 must beat the stage-1 best to overwrite the checkpoint.
            if metric > best_metric:
                best_metric = metric
                best_epoch = epoch + 1000  # +1000 marks the epoch as stage-2
                wait2 = 0
                torch.save({
                    "model_state_dict": model.state_dict(),
                    "num_classes": num_classes,
                    "idx_to_issue_type": model.idx_to_issue_type,
                    "idx_to_severity": model.idx_to_severity,
                }, weights_dir / config["output"]["best_name"])
                star = " ★"
            else:
                wait2 += 1
            print(f" FT{epoch:3d} loss={el/len(train_loader):.4f} acc={val_acc:.3f} "
                  f"min_r={min_r:.3f} macro={macro_r:.3f} best={best_metric:.3f}{star}")
            if wait2 >= 10:
                print(f" FT early stop @ {epoch}")
                break
    # Temperature scaling
    # Fits a single scalar T on half the val split (first cal_size samples)
    # so that logits / T produce better-calibrated probabilities; T is stored
    # in the checkpoint for inference-time use.
    if config.get("calibration", {}).get("use_temperature_scaling", False):
        ckpt = torch.load(weights_dir / config["output"]["best_name"], map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        model.eval()
        cal_size = max(1, int(len(val_ds) * 0.5))
        cl, ct = [], []  # calibration logits / targets
        for i in range(cal_size):
            x, y = val_ds[i]
            with torch.no_grad():
                cl.append(model(x.unsqueeze(0).to(device)).cpu())
            ct.append(y)
        cl, ct = torch.cat(cl), torch.tensor(ct)
        temp = nn.Parameter(torch.ones(1) * 1.5)
        opt_c = torch.optim.LBFGS([temp], lr=0.01, max_iter=50)
        # LBFGS requires a closure that re-evaluates the loss.
        def eval_c():
            opt_c.zero_grad()
            l = F.cross_entropy(cl / temp, ct)
            l.backward()
            return l
        opt_c.step(eval_c)
        ckpt["temperature"] = temp.item()
        torch.save(ckpt, weights_dir / config["output"]["best_name"])
        print(f"Temperature T={temp.item():.4f}")
    print(f"\n{'='*60}")
    print(f" DONE — Best: {best_metric:.4f}")
    per_cls = " | ".join([f"C{c}={r:.2f}" for c, r in recall.items()])
    print(f" Recall: {per_cls}")
    print(f"{'='*60}\n")
    return {"best_epoch": best_epoch, "best_metric": best_metric, "recall_per_class": recall}
def main():
    """CLI entry point: parse arguments, run training, write a markdown report."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/image_train_config.yaml")
    parser.add_argument("--data_root", default=None)
    parser.add_argument("--weights_dir", default="weights")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()
    root = ROOT
    config = load_config(root / args.config)
    # CLI --data_root overrides the config's data root.
    if args.data_root:
        data_root = Path(args.data_root)
    else:
        data_root = root / config["data"]["root"]
    label_mapping_path = root / config["data"]["label_mapping"]
    weights_dir = root / args.weights_dir
    results = run_training(data_root, label_mapping_path, config, weights_dir, args.device)
    # Assemble the report contents first, then write everything in one call.
    issue_names = ["burn_overheating", "cracked_faceplate", "loose_outlet", "normal", "water_exposed"]
    report_lines = [
        "# Image Model Report (Electrical Outlets)\n\n",
        f"- Best metric: {results['best_metric']:.4f}\n",
        "- Classes: 5 (burn, cracked, loose, normal, water)\n\n",
        "## Per-class recall\n\n",
    ]
    for c, r in results.get("recall_per_class", {}).items():
        name = issue_names[c] if c < len(issue_names) else f"class_{c}"
        report_lines.append(f"- {name}: {r:.4f}\n")
    report_path = root / "docs" / config["output"]["report_name"]
    report_path.parent.mkdir(parents=True, exist_ok=True)
    with open(report_path, "w") as f:
        f.write("".join(report_lines))
    print("Report:", report_path)
if __name__ == "__main__":
main()