| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Any |
|
|
| import torch |
| from torch import nn |
|
|
| from .utils import get_logger |
|
|
|
|
# Module-wide logger, obtained from the project's ``get_logger`` helper and
# keyed by this module's import name.
LOGGER = get_logger(__name__)
|
|
# Registry of the deep models this module can build, keyed by the project's
# short model key.
#
# Per-entry fields:
#   timm_name    - name handed to ``timm.create_model`` when no aliases exist.
#   aliases      - optional; candidate timm names tried in order by
#                  ``create_model`` (presumably because the name differs
#                  between timm releases - TODO confirm).
#   display_name - not read in this module; presumably a human-readable label.
#   family       - copied into checkpoint payloads by ``checkpoint_payload``.
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "mobilenet_v3": {
        "timm_name": "mobilenetv3_small_100",
        "display_name": "MobileNetV3 Small",
        "family": "cnn",
    },
    "resnet50": {
        "timm_name": "resnet50",
        "display_name": "ResNet50",
        "family": "cnn",
    },
    "efficientnet_b0": {
        "timm_name": "efficientnet_b0",
        "display_name": "EfficientNetB0",
        "family": "cnn",
    },
    "densenet121": {
        "timm_name": "densenet121",
        "display_name": "DenseNet121",
        "family": "cnn",
    },
    "xception": {
        "timm_name": "xception",
        "aliases": ["xception", "legacy_xception"],
        "display_name": "Xception",
        "family": "cnn",
    },
    "vit_small": {
        "timm_name": "vit_small_patch16_224",
        "display_name": "ViT Small",
        "family": "transformer",
    },
}
|
|
|
|
def create_model(model_key: str, config: dict[str, Any], pretrained: bool | None = None) -> nn.Module:
    """Build the timm model registered under ``model_key`` with ``num_classes=2``.

    Candidate timm names are the entry's ``aliases`` (de-duplicated, order
    kept) or, when the entry has none, its single ``timm_name``. Each
    candidate is first created with the requested ``pretrained`` setting; if
    that fails and pretrained weights were requested, a warning is logged and
    the same candidate is retried with ``pretrained=False`` before moving on
    to the next one.

    Args:
        model_key: Key into ``MODEL_REGISTRY``.
        config: Project config; ``config["models"]["pretrained"]`` (default
            ``True``) is consulted only when ``pretrained`` is ``None``.
        pretrained: Explicit override for loading pretrained weights.

    Returns:
        The created module, tagged with ``_egg_pretrained_loaded`` (whether
        pretrained weights were actually loaded) and ``_egg_timm_name`` (the
        candidate name that worked).

    Raises:
        ValueError: ``model_key`` is not in ``MODEL_REGISTRY``.
        ImportError: ``timm`` is not installed.
        RuntimeError: No candidate name could be created; chained from the
            last failure seen.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown deep model key: {model_key}")
    try:
        import timm
    except ImportError as exc:
        raise ImportError("Install timm to train deep learning models: pip install timm") from exc

    entry = MODEL_REGISTRY[model_key]
    if pretrained is None:
        want_pretrained = bool(config["models"].get("pretrained", True))
    else:
        want_pretrained = bool(pretrained)

    # dict.fromkeys drops duplicate names while keeping the declared order.
    candidates = list(dict.fromkeys(entry.get("aliases", [entry["timm_name"]])))

    failure: Exception | None = None
    weights_loaded = False
    for candidate in candidates:
        try:
            network = timm.create_model(candidate, pretrained=want_pretrained, num_classes=2)
            weights_loaded = want_pretrained
        except Exception as first_error:
            failure = first_error
            if not want_pretrained:
                # Nothing weaker to fall back to for this name.
                continue
            try:
                LOGGER.warning(
                    "Could not load pretrained weights for %s/%s (%s). Trying random init.",
                    model_key,
                    candidate,
                    first_error,
                )
                network = timm.create_model(candidate, pretrained=False, num_classes=2)
                weights_loaded = False
            except Exception as second_error:
                failure = second_error
                continue
        # NOTE: the resolved name is written back into the shared registry
        # entry, so it persists for the rest of the process.
        entry["timm_name"] = candidate
        break
    else:
        raise RuntimeError(f"Could not create timm model for {model_key}: {failure}") from failure

    network._egg_pretrained_loaded = weights_loaded
    network._egg_timm_name = entry["timm_name"]
    return network
|
|
|
|
def freeze_backbone_except_head(model: nn.Module) -> None:
    """Freeze every parameter of ``model`` except those of its classifier head.

    Head parameters are recognised by name: a parameter belongs to the head
    when the name of the top-level submodule (or direct parameter) it lives
    under contains one of ``classifier``, ``head``, ``fc`` or ``last_linear``
    (case-insensitive).

    Matching the top-level name rather than the whole dotted path matters
    because short tokens also occur deep inside a backbone: a nested layer
    such as ``blocks.0.mlp.fc1`` contains ``fc`` and would otherwise stay
    trainable. Only when no top-level name matches (for example when the
    network sits inside a wrapper module) is the whole dotted name searched
    instead.

    If no head can be identified at all, a warning is logged and every
    parameter is left trainable.

    Args:
        model: Module whose ``requires_grad`` flags are updated in place.
    """
    head_tokens = ("classifier", "head", "fc", "last_linear")

    def _looks_like_head(scope: str) -> bool:
        lowered = scope.lower()
        return any(token in lowered for token in head_tokens)

    named_params = list(model.named_parameters())
    for _, param in named_params:
        param.requires_grad = False

    # Preferred: decide by the top-level attribute the parameter hangs off.
    head_params = [param for name, param in named_params if _looks_like_head(name.partition(".")[0])]
    if not head_params:
        # Fallback for wrapped models whose head is nested under another name.
        head_params = [param for name, param in named_params if _looks_like_head(name)]

    if not head_params:
        LOGGER.warning("Could not identify classifier head; leaving all parameters trainable.")
        head_params = [param for _, param in named_params]

    for param in head_params:
        param.requires_grad = True
|
|
|
|
def count_parameters(model: nn.Module) -> dict[str, int]:
    """Count the scalar parameters of ``model``.

    Returns:
        A dict with ``total_parameters`` (all elements across all parameters)
        and ``trainable_parameters`` (elements of parameters whose
        ``requires_grad`` flag is set).
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return {"total_parameters": int(total), "trainable_parameters": int(trainable)}
|
|
|
|
def checkpoint_payload(
    model: nn.Module,
    model_key: str,
    config: dict[str, Any],
    history: list[dict[str, Any]],
    best_metric: float,
) -> dict[str, Any]:
    """Assemble the dictionary that is stored as a training checkpoint.

    Args:
        model: Trained module; its ``state_dict()`` and parameter counts are
            stored, and its ``_egg_timm_name`` / ``_egg_pretrained_loaded``
            attributes are used when present.
        model_key: Key into ``MODEL_REGISTRY``; stored as both ``model_name``
            and ``model_key``.
        config: Project config; ``preprocessing.image_size`` and
            ``evaluation.threshold`` (default ``0.5``) are read, and the
            whole dict is stored under ``config``.
        history: Training history, stored unchanged.
        best_metric: Stored as a float under ``best_val_f1``.

    Returns:
        The checkpoint dictionary, ending with the entries produced by
        ``count_parameters(model)``.

    Raises:
        KeyError: ``model_key`` is not registered, or a required config
            section is missing.
    """
    entry = MODEL_REGISTRY[model_key]
    payload: dict[str, Any] = {}

    # Identity of the model.
    payload["model_name"] = model_key
    payload["model_type"] = "deep_learning"
    payload["model_key"] = model_key
    # Fall back to the registry name when the model carries no resolved name.
    payload["timm_name"] = getattr(model, "_egg_timm_name", entry["timm_name"])
    payload["family"] = entry["family"]

    # Weights.
    payload["pretrained_loaded"] = bool(getattr(model, "_egg_pretrained_loaded", False))
    payload["state_dict"] = model.state_dict()

    # Settings stored alongside the weights.
    payload["class_names"] = ["Not Damaged", "Damaged"]
    payload["positive_class"] = "Damaged"
    payload["image_size"] = int(config["preprocessing"]["image_size"])
    payload["threshold"] = float(config["evaluation"].get("threshold", 0.5))

    # Training record.
    payload["history"] = history
    payload["best_val_f1"] = float(best_metric)
    payload["config"] = config

    payload.update(count_parameters(model))
    return payload
|
|
|
|
def load_torch_checkpoint(checkpoint_path: str | Path, map_location: str | torch.device = "cpu") -> dict[str, Any]:
    """Load a project checkpoint regardless of the installed PyTorch version.

    PyTorch 2.6 changed the ``torch.load`` default to ``weights_only=True``.
    Project checkpoints deliberately carry metadata and config dictionaries
    next to the weights, so the full (legacy) loading path is requested
    explicitly. If ``torch.load`` raises ``TypeError`` for that call, the
    load is retried without the ``weights_only`` argument.

    SECURITY: full loading unpickles arbitrary objects and can therefore
    execute code embedded in the file. Only load checkpoints from a trusted
    source.

    Args:
        checkpoint_path: Location of the checkpoint file.
        map_location: Forwarded to ``torch.load``.

    Returns:
        Whatever object the checkpoint file contains.
    """
    load_kwargs = {"map_location": map_location}
    try:
        return torch.load(checkpoint_path, weights_only=False, **load_kwargs)
    except TypeError:
        return torch.load(checkpoint_path, **load_kwargs)
|
|
|
|
def load_checkpoint_model(
    checkpoint_path: str | Path,
    config: dict[str, Any] | None = None,
) -> tuple[nn.Module, dict[str, Any]]:
    """Rebuild a model from a project checkpoint and load its weights.

    The checkpoint is loaded onto the CPU, the architecture is recreated via
    ``create_model`` with ``pretrained=False`` (the stored weights replace
    any initialisation), the saved ``state_dict`` is loaded and the model is
    switched to evaluation mode.

    Args:
        checkpoint_path: Location of the checkpoint file.
        config: Config to build the model with. When ``None``, the config
            stored in the checkpoint is used. An explicitly supplied config
            is always honoured, even if it is empty.

    Returns:
        ``(model, checkpoint)`` - the model in eval mode and the raw
        checkpoint dictionary.

    Raises:
        ValueError: No config was passed and the checkpoint holds none.
        KeyError: The checkpoint lacks ``model_key`` or ``state_dict``.
    """
    checkpoint = load_torch_checkpoint(checkpoint_path, map_location="cpu")
    # Compare against None rather than relying on truthiness: an empty dict
    # passed by the caller is still "a config provided explicitly".
    cfg = config if config is not None else checkpoint.get("config")
    if cfg is None:
        raise ValueError("Checkpoint does not include a config; provide one explicitly.")
    model = create_model(checkpoint["model_key"], cfg, pretrained=False)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model, checkpoint
|
|