from __future__ import annotations from pathlib import Path from typing import Any import torch from torch import nn from .utils import get_logger LOGGER = get_logger(__name__) MODEL_REGISTRY: dict[str, dict[str, Any]] = { "mobilenet_v3": { "timm_name": "mobilenetv3_small_100", "display_name": "MobileNetV3 Small", "family": "cnn", }, "resnet50": {"timm_name": "resnet50", "display_name": "ResNet50", "family": "cnn"}, "efficientnet_b0": { "timm_name": "efficientnet_b0", "display_name": "EfficientNetB0", "family": "cnn", }, "densenet121": {"timm_name": "densenet121", "display_name": "DenseNet121", "family": "cnn"}, "xception": { "timm_name": "xception", "aliases": ["xception", "legacy_xception"], "display_name": "Xception", "family": "cnn", }, "vit_small": { "timm_name": "vit_small_patch16_224", "display_name": "ViT Small", "family": "transformer", }, } def create_model(model_key: str, config: dict[str, Any], pretrained: bool | None = None) -> nn.Module: if model_key not in MODEL_REGISTRY: raise ValueError(f"Unknown deep model key: {model_key}") try: import timm except ImportError as exc: raise ImportError("Install timm to train deep learning models: pip install timm") from exc spec = MODEL_REGISTRY[model_key] use_pretrained = bool(config["models"].get("pretrained", True) if pretrained is None else pretrained) names_to_try = list(dict.fromkeys(spec.get("aliases", [spec["timm_name"]]))) last_exc: Exception | None = None pretrained_loaded = False for timm_name in names_to_try: try: model = timm.create_model(timm_name, pretrained=use_pretrained, num_classes=2) spec["timm_name"] = timm_name pretrained_loaded = use_pretrained break except Exception as exc: last_exc = exc if use_pretrained: try: LOGGER.warning( "Could not load pretrained weights for %s/%s (%s). Trying random init.", model_key, timm_name, exc, ) model = timm.create_model(timm_name, pretrained=False, num_classes=2) spec["timm_name"] = timm_name pretrained_loaded = False break except Exception as random_exc: last_exc = random_exc continue else: raise RuntimeError(f"Could not create timm model for {model_key}: {last_exc}") from last_exc setattr(model, "_egg_pretrained_loaded", pretrained_loaded) setattr(model, "_egg_timm_name", spec["timm_name"]) return model def freeze_backbone_except_head(model: nn.Module) -> None: for param in model.parameters(): param.requires_grad = False head_tokens = ("classifier", "head", "fc", "last_linear") trainable = 0 for name, param in model.named_parameters(): if any(token in name.lower() for token in head_tokens): param.requires_grad = True trainable += param.numel() if trainable == 0: LOGGER.warning("Could not identify classifier head; leaving all parameters trainable.") for param in model.parameters(): param.requires_grad = True def count_parameters(model: nn.Module) -> dict[str, int]: total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) return {"total_parameters": int(total), "trainable_parameters": int(trainable)} def checkpoint_payload( model: nn.Module, model_key: str, config: dict[str, Any], history: list[dict[str, Any]], best_metric: float, ) -> dict[str, Any]: spec = MODEL_REGISTRY[model_key] return { "model_name": model_key, "model_type": "deep_learning", "model_key": model_key, "timm_name": getattr(model, "_egg_timm_name", spec["timm_name"]), "family": spec["family"], "pretrained_loaded": bool(getattr(model, "_egg_pretrained_loaded", False)), "state_dict": model.state_dict(), "class_names": ["Not Damaged", "Damaged"], "positive_class": "Damaged", "image_size": int(config["preprocessing"]["image_size"]), "threshold": float(config["evaluation"].get("threshold", 0.5)), "history": history, "best_val_f1": float(best_metric), "config": config, **count_parameters(model), } def load_torch_checkpoint(checkpoint_path: str | Path, map_location: str | torch.device = "cpu") -> dict[str, Any]: """Load project checkpoints across PyTorch versions. PyTorch 2.6 defaults ``torch.load`` to ``weights_only=True``. These checkpoints intentionally include metadata/config dictionaries, so they need the legacy full checkpoint loading path. """ try: return torch.load(checkpoint_path, map_location=map_location, weights_only=False) except TypeError: return torch.load(checkpoint_path, map_location=map_location) def load_checkpoint_model(checkpoint_path: str, config: dict[str, Any] | None = None) -> tuple[nn.Module, dict[str, Any]]: checkpoint = load_torch_checkpoint(checkpoint_path, map_location="cpu") cfg = config or checkpoint.get("config") if cfg is None: raise ValueError("Checkpoint does not include a config; provide one explicitly.") model = create_model(checkpoint["model_key"], cfg, pretrained=False) model.load_state_dict(checkpoint["state_dict"]) model.eval() return model, checkpoint