budijuarto's picture
Upload src/egg_damage/dl_models.py
c3f8b3c verified
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from torch import nn
from .utils import get_logger
# Module-level logger; used below for the pretrained-weight fallback and head-detection warnings.
LOGGER = get_logger(__name__)
# Supported deep models, keyed by the project-level model key.
#
# Fields of each entry:
#   timm_name    -- architecture name handed to ``timm.create_model``.
#   aliases      -- optional ordered list of timm names to try instead of ``timm_name``.
#   display_name -- human-readable label.
#   family       -- architecture family tag ("cnn" or "transformer").
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "mobilenet_v3": {
        "timm_name": "mobilenetv3_small_100",
        "display_name": "MobileNetV3 Small",
        "family": "cnn",
    },
    "resnet50": {
        "timm_name": "resnet50",
        "display_name": "ResNet50",
        "family": "cnn",
    },
    "efficientnet_b0": {
        "timm_name": "efficientnet_b0",
        "display_name": "EfficientNetB0",
        "family": "cnn",
    },
    "densenet121": {
        "timm_name": "densenet121",
        "display_name": "DenseNet121",
        "family": "cnn",
    },
    "xception": {
        "timm_name": "xception",
        # Tried in order by create_model; both spellings are listed as candidates.
        "aliases": ["xception", "legacy_xception"],
        "display_name": "Xception",
        "family": "cnn",
    },
    "vit_small": {
        "timm_name": "vit_small_patch16_224",
        "display_name": "ViT Small",
        "family": "transformer",
    },
}
def create_model(model_key: str, config: dict[str, Any], pretrained: bool | None = None) -> nn.Module:
    """Instantiate the timm network registered under ``model_key`` with a two-class head.

    Candidate timm names (the registry entry's ``aliases`` when present, otherwise
    its ``timm_name``) are tried in order. When pretrained weights were requested
    and creation fails, the same name is retried with random initialisation before
    moving on to the next candidate.

    Args:
        model_key: Key into ``MODEL_REGISTRY``.
        config: Project config. ``config["models"]["pretrained"]`` (default ``True``)
            is consulted only when ``pretrained`` is ``None``.
        pretrained: Explicit override for whether to request pretrained weights.

    Returns:
        The created module, tagged with ``_egg_pretrained_loaded`` (whether the
        pretrained request succeeded) and ``_egg_timm_name`` (the name that worked).

    Raises:
        ValueError: ``model_key`` is not registered.
        ImportError: ``timm`` is not installed.
        RuntimeError: no candidate name could be instantiated.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown deep model key: {model_key}")
    try:
        import timm
    except ImportError as exc:
        raise ImportError("Install timm to train deep learning models: pip install timm") from exc

    spec = MODEL_REGISTRY[model_key]
    use_pretrained = bool(config["models"].get("pretrained", True) if pretrained is None else pretrained)
    # dict.fromkeys de-duplicates the candidates while keeping their order.
    names_to_try = list(dict.fromkeys(spec.get("aliases", [spec["timm_name"]])))

    # The resolved name is kept local: writing it back into MODEL_REGISTRY would
    # silently change shared module state for every later caller.
    resolved_name = spec["timm_name"]
    pretrained_loaded = False
    last_exc: Exception | None = None
    for timm_name in names_to_try:
        try:
            model = timm.create_model(timm_name, pretrained=use_pretrained, num_classes=2)
        except Exception as exc:  # deliberately broad: best-effort fallback across candidates
            last_exc = exc
            if not use_pretrained:
                continue
            LOGGER.warning(
                "Could not load pretrained weights for %s/%s (%s). Trying random init.",
                model_key,
                timm_name,
                exc,
            )
            try:
                model = timm.create_model(timm_name, pretrained=False, num_classes=2)
            except Exception as random_exc:
                last_exc = random_exc
                continue
        else:
            pretrained_loaded = use_pretrained
        resolved_name = timm_name
        break
    else:
        # Loop finished without a successful creation (or there were no candidates).
        raise RuntimeError(f"Could not create timm model for {model_key}: {last_exc}") from last_exc

    setattr(model, "_egg_pretrained_loaded", pretrained_loaded)
    setattr(model, "_egg_timm_name", resolved_name)
    return model
def _is_head_parameter(name: str, head_tokens: frozenset[str]) -> bool:
    """Return True when a dotted parameter name has a component naming a head token."""
    for component in name.lower().split("."):
        # Accept an exact component ("fc", "last_linear") or an underscore-delimited
        # piece of one ("conv_head" -> "head"), but not a mere substring ("fc1").
        if component in head_tokens or not head_tokens.isdisjoint(component.split("_")):
            return True
    return False


def freeze_backbone_except_head(model: nn.Module) -> None:
    """Freeze every parameter except those that belong to the classifier head.

    A parameter counts as part of the head when one of the tokens ``classifier``,
    ``head``, ``fc`` or ``last_linear`` appears as a whole ``.``/``_``-delimited
    component of its name. Plain substring matching is avoided because it also
    hits backbone layers such as ``blocks.N.mlp.fc1`` and would leave them trainable.

    If no head parameter is found, a warning is logged and all parameters are
    left trainable.

    Args:
        model: Module whose ``requires_grad`` flags are updated in place.
    """
    head_tokens = frozenset(("classifier", "head", "fc", "last_linear"))
    trainable = 0
    for name, param in model.named_parameters():
        is_head = _is_head_parameter(name, head_tokens)
        param.requires_grad = is_head
        if is_head:
            trainable += param.numel()
    if trainable == 0:
        LOGGER.warning("Could not identify classifier head; leaving all parameters trainable.")
        for param in model.parameters():
            param.requires_grad = True
def count_parameters(model: nn.Module) -> dict[str, int]:
    """Return the total and trainable parameter counts of ``model``.

    Args:
        model: Module whose parameters are tallied.

    Returns:
        A dict with ``total_parameters`` (all elements) and
        ``trainable_parameters`` (elements of parameters with ``requires_grad``).
    """
    n_total = 0
    n_trainable = 0
    for parameter in model.parameters():
        size = parameter.numel()
        n_total += size
        if parameter.requires_grad:
            n_trainable += size
    return {"total_parameters": int(n_total), "trainable_parameters": int(n_trainable)}
def checkpoint_payload(
    model: nn.Module,
    model_key: str,
    config: dict[str, Any],
    history: list[dict[str, Any]],
    best_metric: float,
) -> dict[str, Any]:
    """Assemble the dictionary that is saved as a deep-learning checkpoint.

    Args:
        model: Trained module; its ``state_dict`` and the ``_egg_*`` attributes
            (when present) are recorded.
        model_key: Key into ``MODEL_REGISTRY``; stored as both ``model_name``
            and ``model_key``.
        config: Project config; ``preprocessing.image_size`` is required and
            ``evaluation.threshold`` defaults to 0.5.
        history: Training history entries, stored as given.
        best_metric: Stored as ``best_val_f1``.

    Returns:
        The checkpoint payload, including the parameter counts of ``model``.
    """
    registry_entry = MODEL_REGISTRY[model_key]
    # Prefer the name recorded on the model; fall back to the registry's name.
    resolved_timm_name = getattr(model, "_egg_timm_name", registry_entry["timm_name"])
    loaded_pretrained = bool(getattr(model, "_egg_pretrained_loaded", False))

    payload: dict[str, Any] = {
        "model_name": model_key,
        "model_type": "deep_learning",
        "model_key": model_key,
        "timm_name": resolved_timm_name,
        "family": registry_entry["family"],
        "pretrained_loaded": loaded_pretrained,
        "state_dict": model.state_dict(),
        "class_names": ["Not Damaged", "Damaged"],
        "positive_class": "Damaged",
        "image_size": int(config["preprocessing"]["image_size"]),
        "threshold": float(config["evaluation"].get("threshold", 0.5)),
        "history": history,
        "best_val_f1": float(best_metric),
        "config": config,
    }
    # Parameter counts go last, matching the original key order.
    payload.update(count_parameters(model))
    return payload
def load_torch_checkpoint(checkpoint_path: str | Path, map_location: str | torch.device = "cpu") -> dict[str, Any]:
    """Load a full project checkpoint regardless of the installed PyTorch version.

    PyTorch 2.6 switched the ``torch.load`` default to ``weights_only=True``.
    Project checkpoints carry metadata and config dictionaries alongside the
    weights, so the legacy full-unpickling path is requested explicitly.

    NOTE: ``weights_only=False`` unpickles arbitrary objects; only load
    checkpoints that come from a trusted source.

    Args:
        checkpoint_path: Location of the checkpoint file.
        map_location: Passed through to ``torch.load``.

    Returns:
        Whatever object was stored in the checkpoint file.
    """
    shared_kwargs = {"map_location": map_location}
    try:
        loaded = torch.load(checkpoint_path, weights_only=False, **shared_kwargs)
    except TypeError:
        # Retry without the keyword for torch.load versions that reject it.
        loaded = torch.load(checkpoint_path, **shared_kwargs)
    return loaded
def load_checkpoint_model(
    checkpoint_path: str | Path, config: dict[str, Any] | None = None
) -> tuple[nn.Module, dict[str, Any]]:
    """Rebuild a model from a project checkpoint and put it in eval mode.

    Args:
        checkpoint_path: Checkpoint file to load (on CPU).
        config: Config to build the model with. When ``None``, the ``config``
            entry stored in the checkpoint is used instead.

    Returns:
        ``(model, checkpoint)``: the module with the stored ``state_dict``
        loaded, and the raw checkpoint dictionary.

    Raises:
        ValueError: no config was passed and the checkpoint does not contain one.
    """
    checkpoint = load_torch_checkpoint(checkpoint_path, map_location="cpu")
    # Compare against None rather than truthiness so that an explicitly passed
    # (even empty) config is respected instead of being treated as missing.
    cfg = config if config is not None else checkpoint.get("config")
    if cfg is None:
        raise ValueError("Checkpoint does not include a config; provide one explicitly.")
    # The stored weights are loaded next, so pretrained weights are not requested.
    model = create_model(checkpoint["model_key"], cfg, pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model, checkpoint