File size: 5,775 Bytes
c3f8b3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations

from pathlib import Path
from typing import Any

import torch
from torch import nn

from .utils import get_logger


# Module-level logger obtained from the project's ``get_logger`` helper.
LOGGER = get_logger(__name__)

# Supported deep models, keyed by the project's model key.
#   timm_name     -- architecture name passed to ``timm.create_model``; note that
#                    ``create_model`` overwrites it with the alias that worked.
#   aliases       -- optional ordered list of timm names to try instead of
#                    ``timm_name`` (presumably to cope with timm renames such as
#                    ``legacy_xception`` -- confirm against the timm versions used).
#   display_name  -- human-readable label.
#   family        -- coarse architecture family ("cnn" or "transformer").
MODEL_REGISTRY: dict[str, dict[str, Any]] = dict(
    mobilenet_v3=dict(timm_name="mobilenetv3_small_100", display_name="MobileNetV3 Small", family="cnn"),
    resnet50=dict(timm_name="resnet50", display_name="ResNet50", family="cnn"),
    efficientnet_b0=dict(timm_name="efficientnet_b0", display_name="EfficientNetB0", family="cnn"),
    densenet121=dict(timm_name="densenet121", display_name="DenseNet121", family="cnn"),
    xception=dict(
        timm_name="xception",
        aliases=["xception", "legacy_xception"],
        display_name="Xception",
        family="cnn",
    ),
    vit_small=dict(timm_name="vit_small_patch16_224", display_name="ViT Small", family="transformer"),
)


def create_model(model_key: str, config: dict[str, Any], pretrained: bool | None = None) -> nn.Module:
    """Build a two-class timm model for a registered ``model_key``.

    The timm names listed under the spec's ``"aliases"`` (or only its
    ``"timm_name"`` when no aliases are given) are tried in order. When
    pretrained weights are requested, every alias is tried with
    ``pretrained=True`` first; only if none of them succeeds does the function
    fall back to randomly initialised weights, again trying every alias in order.

    Args:
        model_key: Key into ``MODEL_REGISTRY``.
        config: Project config. ``config["models"]["pretrained"]`` (default
            ``True``) is read only when ``pretrained`` is ``None``.
        pretrained: Explicit override for whether to load pretrained weights.

    Returns:
        The model, tagged with ``_egg_pretrained_loaded`` (whether pretrained
        weights were actually loaded) and ``_egg_timm_name`` (the timm name
        that was used). As a side effect, ``MODEL_REGISTRY[model_key]["timm_name"]``
        is overwritten with that timm name.

    Raises:
        ValueError: ``model_key`` is not registered.
        ImportError: ``timm`` is not installed.
        RuntimeError: no alias could be created, even with random weights.
    """
    if model_key not in MODEL_REGISTRY:
        raise ValueError(f"Unknown deep model key: {model_key}")
    try:
        import timm
    except ImportError as exc:
        raise ImportError("Install timm to train deep learning models: pip install timm") from exc

    spec = MODEL_REGISTRY[model_key]
    use_pretrained = bool(config["models"].get("pretrained", True) if pretrained is None else pretrained)
    # De-duplicate the candidate names while keeping their declared order.
    names_to_try = list(dict.fromkeys(spec.get("aliases", [spec["timm_name"]])))
    # Prefer pretrained weights from *any* alias over random init of an earlier alias.
    weight_modes = (True, False) if use_pretrained else (False,)
    last_exc: Exception | None = None
    for load_weights in weight_modes:
        if use_pretrained and not load_weights:
            LOGGER.warning(
                "Could not load pretrained weights for %s under any of %s. Trying random init.",
                model_key,
                names_to_try,
            )
        for timm_name in names_to_try:
            try:
                model = timm.create_model(timm_name, pretrained=load_weights, num_classes=2)
            except Exception as exc:  # timm may raise assorted exception types here
                last_exc = exc
                if load_weights:
                    LOGGER.warning(
                        "Could not load pretrained weights for %s/%s (%s).",
                        model_key,
                        timm_name,
                        exc,
                    )
                continue
            # Record which alias worked; this mutates the shared registry entry.
            spec["timm_name"] = timm_name
            setattr(model, "_egg_pretrained_loaded", load_weights)
            setattr(model, "_egg_timm_name", timm_name)
            return model
    raise RuntimeError(f"Could not create timm model for {model_key}: {last_exc}") from last_exc


def freeze_backbone_except_head(model: nn.Module) -> None:
    """Freeze every parameter of ``model`` except those of the classification head.

    A parameter belongs to the head when one of the dot-separated components of
    its qualified name is exactly ``classifier``, ``head``, ``fc`` or
    ``last_linear`` (compared case-insensitively). Matching whole components
    rather than substrings keeps e.g. ``blocks.0.mlp.fc1.weight`` (``fc`` is a
    substring of ``fc1``) and ``conv_head.weight`` frozen, while ``head.weight``,
    ``fc.bias`` or ``classifier.weight`` become trainable.

    If no parameter matches, a warning is logged and every parameter is left
    trainable.

    Args:
        model: Module whose parameters' ``requires_grad`` flags are set in place.
    """
    head_tokens = frozenset({"classifier", "head", "fc", "last_linear"})
    trainable = 0
    for name, param in model.named_parameters():
        is_head = not head_tokens.isdisjoint(name.lower().split("."))
        param.requires_grad = is_head
        if is_head:
            trainable += param.numel()
    if trainable == 0:
        LOGGER.warning("Could not identify classifier head; leaving all parameters trainable.")
        for param in model.parameters():
            param.requires_grad = True


def count_parameters(model: nn.Module) -> dict[str, int]:
    """Count the parameters of ``model``.

    Returns a dict with ``total_parameters`` (all parameter elements) and
    ``trainable_parameters`` (elements of parameters with ``requires_grad``).
    """
    total_count = 0
    trainable_count = 0
    for tensor in model.parameters():
        elements = tensor.numel()
        total_count += elements
        if tensor.requires_grad:
            trainable_count += elements
    return {"total_parameters": int(total_count), "trainable_parameters": int(trainable_count)}


def checkpoint_payload(
    model: nn.Module,
    model_key: str,
    config: dict[str, Any],
    history: list[dict[str, Any]],
    best_metric: float,
) -> dict[str, Any]:
    """Assemble the dictionary stored as a deep-learning checkpoint.

    The payload bundles a snapshot of the weights with the metadata needed to
    rebuild and use the model: registry key, timm name, family, class names,
    image size, decision threshold, training history, config and parameter
    counts.

    ``model.state_dict()`` returns tensors that share storage with the live
    parameters, so the weights are deep-copied here. This means a payload kept
    in memory (for example as a best-so-far checkpoint) does not change when
    training continues. ``history`` is likewise copied into a new list.

    Args:
        model: Model to snapshot; its ``_egg_timm_name`` / ``_egg_pretrained_loaded``
            tags are used when present.
        model_key: Key into ``MODEL_REGISTRY``.
        config: Project config; ``preprocessing.image_size`` is required and
            ``evaluation.threshold`` defaults to 0.5.
        history: Per-epoch records to store alongside the weights.
        best_metric: Best validation metric so far, stored under ``best_val_f1``.

    Returns:
        The checkpoint dictionary.

    Raises:
        KeyError: ``model_key`` is not registered or a required config key is missing.
    """
    import copy

    spec = MODEL_REGISTRY[model_key]
    return {
        "model_name": model_key,
        "model_type": "deep_learning",
        "model_key": model_key,
        "timm_name": getattr(model, "_egg_timm_name", spec["timm_name"]),
        "family": spec["family"],
        "pretrained_loaded": bool(getattr(model, "_egg_pretrained_loaded", False)),
        # Snapshot, not a view of the live parameters.
        "state_dict": copy.deepcopy(model.state_dict()),
        "class_names": ["Not Damaged", "Damaged"],
        "positive_class": "Damaged",
        "image_size": int(config["preprocessing"]["image_size"]),
        "threshold": float(config["evaluation"].get("threshold", 0.5)),
        "history": list(history),
        "best_val_f1": float(best_metric),
        "config": config,
        **count_parameters(model),
    }


def load_torch_checkpoint(checkpoint_path: str | Path, map_location: str | torch.device = "cpu") -> dict[str, Any]:
    """Load project checkpoints across PyTorch versions.

    PyTorch 2.6 defaults ``torch.load`` to ``weights_only=True``. These
    checkpoints intentionally include metadata/config dictionaries, so they
    need the legacy full checkpoint loading path.

    Security: ``weights_only=False`` unpickles arbitrary objects, which can
    execute arbitrary code. Only load checkpoints from trusted sources.

    Args:
        checkpoint_path: Path (or file-like object accepted by ``torch.load``).
        map_location: Device mapping forwarded to ``torch.load``.

    Returns:
        The deserialised checkpoint object.
    """
    try:
        return torch.load(checkpoint_path, map_location=map_location, weights_only=False)
    except TypeError as exc:
        # Older PyTorch releases do not accept ``weights_only``; only that case is
        # retried. Any other TypeError (e.g. raised while unpickling) is a real
        # error and must not be masked by a second load attempt.
        if "weights_only" not in str(exc):
            raise
    return torch.load(checkpoint_path, map_location=map_location)


def load_checkpoint_model(checkpoint_path: str, config: dict[str, Any] | None = None) -> tuple[nn.Module, dict[str, Any]]:
    """Rebuild a model from a saved checkpoint and switch it to eval mode.

    The checkpoint is loaded onto the CPU. A truthy ``config`` argument takes
    precedence; otherwise the config stored in the checkpoint is used. The
    architecture is recreated without pretrained weights and then filled from
    the checkpoint's ``state_dict``.

    Returns:
        ``(model, checkpoint)`` -- the model in eval mode and the raw checkpoint dict.

    Raises:
        ValueError: neither a config argument nor a stored config is available.
    """
    payload = load_torch_checkpoint(checkpoint_path, map_location="cpu")
    effective_config = config if config else payload.get("config")
    if effective_config is None:
        raise ValueError("Checkpoint does not include a config; provide one explicitly.")
    network = create_model(payload["model_key"], effective_config, pretrained=False)
    network.load_state_dict(payload["state_dict"])
    # nn.Module.eval() returns the module itself.
    return network.eval(), payload