File size: 5,775 Bytes
c3f8b3c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from torch import nn
from .utils import get_logger
LOGGER = get_logger(__name__)
# Supported deep-learning backbones, keyed by the project-level model key.
# Every entry carries:
#   "timm_name"    - timm architecture name handed to ``timm.create_model``
#   "aliases"      - (optional) candidate timm names that create_model tries in
#                    order; presumably covers timm renaming Xception to
#                    ``legacy_xception`` across versions (confirm)
#   "display_name" - human-readable label (not read anywhere in this module)
#   "family"       - coarse architecture family, copied into checkpoints
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "mobilenet_v3": {
        "timm_name": "mobilenetv3_small_100",
        "display_name": "MobileNetV3 Small",
        "family": "cnn",
    },
    "resnet50": {
        "timm_name": "resnet50",
        "display_name": "ResNet50",
        "family": "cnn",
    },
    "efficientnet_b0": {
        "timm_name": "efficientnet_b0",
        "display_name": "EfficientNetB0",
        "family": "cnn",
    },
    "densenet121": {
        "timm_name": "densenet121",
        "display_name": "DenseNet121",
        "family": "cnn",
    },
    "xception": {
        "timm_name": "xception",
        "aliases": ["xception", "legacy_xception"],
        "display_name": "Xception",
        "family": "cnn",
    },
    "vit_small": {
        "timm_name": "vit_small_patch16_224",
        "display_name": "ViT Small",
        "family": "transformer",
    },
}
def create_model(model_key: str, config: dict[str, Any], pretrained: bool | None = None) -> nn.Module:
    """Build a two-class timm model for a registered ``model_key``.

    Candidate timm names are the entry's ``timm_name`` followed by any
    ``aliases`` (duplicates removed, order kept). For each candidate the
    requested weights are tried first; if pretrained loading fails, the same
    architecture is retried with random initialisation before moving on to
    the next candidate.

    The shared ``MODEL_REGISTRY`` entry is never modified. The name that
    actually succeeded is recorded on the returned model instead.

    Args:
        model_key: Key into ``MODEL_REGISTRY``.
        config: Project config. ``config["models"]["pretrained"]`` (default
            ``True``, also used when the ``models`` section is absent) is only
            consulted when ``pretrained`` is ``None``.
        pretrained: Explicit override for requesting pretrained weights.

    Returns:
        The model, tagged with ``_egg_pretrained_loaded`` (whether pretrained
        weights were actually loaded) and ``_egg_timm_name`` (the timm name
        that was instantiated).

    Raises:
        ValueError: ``model_key`` is not registered.
        ImportError: ``timm`` is not installed.
        RuntimeError: no candidate timm name could be instantiated; chained
            from the last underlying exception.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown deep model key: {model_key}")
    try:
        import timm
    except ImportError as exc:
        raise ImportError("Install timm to train deep learning models: pip install timm") from exc

    spec = MODEL_REGISTRY[model_key]
    if pretrained is None:
        # Tolerate a missing/empty "models" section; the per-key default already exists.
        pretrained = (config.get("models") or {}).get("pretrained", True)
    use_pretrained = bool(pretrained)

    # Primary name first, then aliases; dict.fromkeys de-duplicates preserving order.
    candidates = list(dict.fromkeys([spec["timm_name"], *spec.get("aliases", [])]))

    last_exc: Exception | None = None
    for timm_name in candidates:
        try:
            model = timm.create_model(timm_name, pretrained=use_pretrained, num_classes=2)
            pretrained_loaded = use_pretrained
            break
        except Exception as exc:  # unknown names, hub/network errors: no common base class
            last_exc = exc
        if not use_pretrained:
            continue
        # Pretrained weights failed (or the name is unknown); retry this name from scratch.
        LOGGER.warning(
            "Could not load pretrained weights for %s/%s (%s). Trying random init.",
            model_key,
            timm_name,
            last_exc,
        )
        try:
            model = timm.create_model(timm_name, pretrained=False, num_classes=2)
            pretrained_loaded = False
            break
        except Exception as random_exc:
            last_exc = random_exc
    else:
        raise RuntimeError(f"Could not create timm model for {model_key}: {last_exc}") from last_exc

    # Record what was actually built on the model itself rather than mutating the
    # module-level registry, which would make later calls depend on call history.
    setattr(model, "_egg_pretrained_loaded", pretrained_loaded)
    setattr(model, "_egg_timm_name", timm_name)
    return model
def freeze_backbone_except_head(
    model: nn.Module,
    *,
    head_tokens: tuple[str, ...] = ("classifier", "head", "fc", "last_linear"),
) -> None:
    """Freeze every parameter except those that belong to the classifier head.

    A parameter is treated as part of the head when one of the dot-separated
    components of its (lower-cased) name equals a token in ``head_tokens``, or
    carries the token as an underscore-separated prefix or suffix
    (``conv_head``, ``fc_norm``). Matching whole components avoids false
    positives from plain substring matching, e.g. ViT's
    ``blocks.N.mlp.fc1``/``fc2`` layers being taken for the ``fc`` head and
    left trainable in every transformer block.

    If no parameter matches, a warning is logged and the whole model is left
    trainable.

    Args:
        model: Module whose ``requires_grad`` flags are updated in place.
        head_tokens: Name components identifying head layers (keyword-only).
    """
    tokens = tuple(token.lower() for token in head_tokens)

    def _is_head_component(part: str) -> bool:
        # One dotted component of a parameter name, e.g. "conv_head" in "conv_head.weight".
        return any(
            part == token or part.startswith(token + "_") or part.endswith("_" + token)
            for token in tokens
        )

    for param in model.parameters():
        param.requires_grad = False

    trainable = 0
    for name, param in model.named_parameters():
        if any(_is_head_component(part) for part in name.lower().split(".")):
            param.requires_grad = True
            trainable += param.numel()

    if trainable == 0:
        LOGGER.warning("Could not identify classifier head; leaving all parameters trainable.")
        for param in model.parameters():
            param.requires_grad = True
def count_parameters(model: nn.Module) -> dict[str, int]:
    """Count scalar parameter elements in ``model``.

    Returns a dict with ``total_parameters`` (all elements across
    ``model.parameters()``) and ``trainable_parameters`` (only those whose
    tensor has ``requires_grad`` set).
    """
    total_count = 0
    trainable_count = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_count += size
        if tensor.requires_grad:
            trainable_count += size
    return {
        "total_parameters": int(total_count),
        "trainable_parameters": int(trainable_count),
    }
def checkpoint_payload(
    model: nn.Module,
    model_key: str,
    config: dict[str, Any],
    history: list[dict[str, Any]],
    best_metric: float,
) -> dict[str, Any]:
    """Assemble the checkpoint dictionary for a trained deep model.

    The payload bundles the weights (``state_dict``) with the metadata needed
    to rebuild and interpret them:
    - registry identity (``model_key``, ``timm_name``, ``family``);
    - whether pretrained weights were loaded;
    - the fixed class labels and positive class;
    - the preprocessing image size and evaluation threshold (default 0.5);
    - the training ``history``;
    - ``best_metric``, stored under ``best_val_f1``;
    - the full ``config``;
    - parameter counts.

    Raises:
        KeyError: ``model_key`` is not in ``MODEL_REGISTRY``, or ``config``
            lacks the ``preprocessing`` / ``evaluation`` sections or
            ``preprocessing.image_size``.
    """
    registry_entry = MODEL_REGISTRY[model_key]

    payload: dict[str, Any] = {}
    # Both keys carry model_key (presumably kept for readers of either key).
    payload["model_name"] = model_key
    payload["model_type"] = "deep_learning"
    payload["model_key"] = model_key
    # Prefer the name recorded on the model at creation; fall back to the registry.
    payload["timm_name"] = getattr(model, "_egg_timm_name", registry_entry["timm_name"])
    payload["family"] = registry_entry["family"]
    payload["pretrained_loaded"] = bool(getattr(model, "_egg_pretrained_loaded", False))
    # NOTE(review): state_dict() returns tensors that share storage with the live
    # parameters. If this payload is held in memory while training continues
    # (instead of being saved right away), those weights will keep changing —
    # confirm how callers use it.
    payload["state_dict"] = model.state_dict()
    payload["class_names"] = ["Not Damaged", "Damaged"]
    payload["positive_class"] = "Damaged"
    payload["image_size"] = int(config["preprocessing"]["image_size"])
    payload["threshold"] = float(config["evaluation"].get("threshold", 0.5))
    payload["history"] = history
    payload["best_val_f1"] = float(best_metric)
    payload["config"] = config
    # Parameter counts go last, matching the original dict-literal merge order.
    payload.update(count_parameters(model))
    return payload
def load_torch_checkpoint(checkpoint_path: str | Path, map_location: str | torch.device = "cpu") -> dict[str, Any]:
"""Load project checkpoints across PyTorch versions.
PyTorch 2.6 defaults ``torch.load`` to ``weights_only=True``. These
checkpoints intentionally include metadata/config dictionaries, so they
need the legacy full checkpoint loading path.
"""
try:
return torch.load(checkpoint_path, map_location=map_location, weights_only=False)
except TypeError:
return torch.load(checkpoint_path, map_location=map_location)
def load_checkpoint_model(checkpoint_path: str, config: dict[str, Any] | None = None) -> tuple[nn.Module, dict[str, Any]]:
    """Rebuild a model in eval mode from a saved checkpoint.

    The checkpoint is loaded onto the CPU and must provide ``model_key`` and
    ``state_dict`` (the keys ``checkpoint_payload`` writes). The architecture
    is recreated without pretrained weights, the stored weights are loaded
    with ``load_state_dict``'s default strict matching, and the module is
    switched to eval mode.

    Args:
        checkpoint_path: Location of the checkpoint file.
        config: Config to use; any falsy value (``None`` or an empty dict)
            falls back to the config embedded in the checkpoint.

    Returns:
        ``(model, checkpoint)`` — the restored module and the raw loaded dict.

    Raises:
        ValueError: neither a truthy ``config`` nor an embedded config exists.
    """
    payload = load_torch_checkpoint(checkpoint_path, map_location="cpu")
    effective_config = config if config else payload.get("config")
    if effective_config is None:
        raise ValueError("Checkpoint does not include a config; provide one explicitly.")
    restored = create_model(payload["model_key"], effective_config, pretrained=False)
    restored.load_state_dict(payload["state_dict"])
    restored.eval()
    return restored, payload
|