File size: 3,877 Bytes
52efd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Builders for the four architectures we compare.

Loaded via importlib so the parent dirs of the original model files are not
inserted into sys.path (otherwise their per-model `dataset.py` would shadow
this experiment's `dataset.py`).

Models:
    segnet         — pv_panel_models/cnn_model/cnn_segmenter.py        (SegNet)
    unet           — pv_panel_models/unet_model/unet_model.py          (U-Net)
    segformer_b0   — pv_panel_models/vit_model/segformer_model.py      (SegFormer mit-b0)
    segformer_b5   — pv_panel_models/segformer_b5_model/segformer_model.py  (SegFormer mit-b5)

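NOTE: SegNet's `forward()` already applies a sigmoid, whereas UNet/SegFormer
return raw logits. Every builder returns `(model, loss, output_is_prob)`; the
flag is True only for SegNet, and the trainer uses `output_is_prob=True` for
SegNet's metrics step.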

SegNet's loss is reproduced inline (BCELoss + Dice on probabilities) since
pv_panel_models/cnn_model/train.py uses a sibling-relative import that doesn't
survive being loaded by importlib without a sys.path tweak.
"""
import importlib.util
import sys
from pathlib import Path

import torch
import torch.nn as nn

# Repository root: parents[2] of this file's resolved path, i.e. the
# grandparent of the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Directory holding the original per-architecture model packages.
PV_DIR = REPO_ROOT.joinpath("pv_panel_models")


def _load(module_name: str, file_path: Path):
    """Execute the Python source file at ``file_path`` as module ``module_name``.

    Loading by path (instead of a regular import) keeps the model file's parent
    directory off ``sys.path``, as explained in the module docstring.

    The new module is registered in ``sys.modules`` *before* its code runs,
    following the importlib "importing a source file directly" recipe. Without
    that registration, anything that resolves a class through
    ``sys.modules[cls.__module__]`` cannot find the module. That includes
    pickling / ``torch.save`` of whole model objects and ``dataclasses``. If
    executing the file raises, the half-initialised module is removed again so
    a broken module is never left behind.

    Args:
        module_name: Name to register the module under. It should be unique so
            it cannot replace an unrelated entry in ``sys.modules``.
        file_path: Path of the ``.py`` file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for ``file_path``.
        Any exception raised while executing the file's code propagates.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"could not load {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror importlib's own behaviour: don't leave a partially executed
        # module registered after a failed load.
        sys.modules.pop(module_name, None)
        raise
    return module


# Execute each original model file by path under a private `_pv_*` module name.
# This runs the model files' top-level code when this module is imported.
_segnet_mod = _load(
    "_pv_segnet_model",
    PV_DIR.joinpath("cnn_model", "cnn_segmenter.py"),
)
_unet_mod = _load(
    "_pv_unet_model",
    PV_DIR.joinpath("unet_model", "unet_model.py"),
)
_segformer_b0_mod = _load(
    "_pv_segformer_b0_model",
    PV_DIR.joinpath("vit_model", "segformer_model.py"),
)
_segformer_b5_mod = _load(
    "_pv_segformer_b5_model",
    PV_DIR.joinpath("segformer_b5_model", "segformer_model.py"),
)


# SegNet expects probabilities (its forward applies sigmoid).
# Mirrors pv_panel_models/cnn_model/train.py:CombinedLoss exactly.
class _SegNetDiceLoss(nn.Module):
    """Soft Dice loss evaluated on probabilities (SegNet's sigmoid output).

    Both tensors are flattened to 1-D, so the Dice coefficient is computed over
    every element of the batch at once rather than per sample. The module
    returns ``1 - dice``.

    Args:
        smooth: Constant added to numerator and denominator. It keeps the ratio
            defined when ``pred`` and ``target`` both sum to zero.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # reshape() rather than view(): view() raises on non-contiguous inputs
        # (e.g. tensors produced by permute()/transpose()), while reshape()
        # yields the same flattened values for any memory layout.
        pred = pred.reshape(-1)
        target = target.reshape(-1)
        intersection = (pred * target).sum()
        dice = (2.0 * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
        return 1 - dice


class _SegNetCombinedLoss(nn.Module):
    """Weighted BCE + soft-Dice loss, both evaluated on probabilities.

    ``loss = bce_weight * BCE(pred, target) + (1 - bce_weight) * Dice(pred, target)``

    Args:
        bce_weight: Weight of the BCE term; the Dice term receives
            ``1 - bce_weight``.
    """

    def __init__(self, bce_weight=0.5):
        super().__init__()
        # Plain BCELoss (not BCEWithLogitsLoss): SegNet's forward() already
        # maps its output into [0, 1].
        self.bce = nn.BCELoss()
        self.dice = _SegNetDiceLoss()
        self.bce_weight = bce_weight

    def forward(self, pred, target):
        bce_term = self.bce(pred, target)
        dice_term = self.dice(pred, target)
        dice_weight = 1 - self.bce_weight
        return self.bce_weight * bce_term + dice_weight * dice_term


def build_segnet():
    """Create SegNet plus the BCE+Dice loss that operates on probabilities.

    Returns:
        Tuple ``(model, loss, output_is_prob)``. ``output_is_prob`` is True
        because SegNet's ``forward()`` applies the sigmoid itself.
    """
    net = _segnet_mod.SegNet(in_channels=3, out_channels=1)
    criterion = _SegNetCombinedLoss(bce_weight=0.5)
    outputs_are_probabilities = True  # sigmoid already applied in forward()
    return net, criterion, outputs_are_probabilities


def build_unet():
    """Create a 3-in / 1-out U-Net with the CombinedLoss from its own module.

    Returns:
        Tuple ``(model, loss, output_is_prob)``. ``output_is_prob`` is False
        because U-Net emits raw logits.
    """
    return (
        _unet_mod.UNet(in_channels=3, out_channels=1),
        _unet_mod.CombinedLoss(bce_weight=0.5),
        False,
    )


def build_segformer_b0():
    """Create the SegFormer model with ``pretrained_name="nvidia/mit-b0"``.

    Returns:
        Tuple ``(model, loss, output_is_prob)``. ``output_is_prob`` is False
        because SegFormer emits raw logits.
    """
    backbone_id = "nvidia/mit-b0"
    net = _segformer_b0_mod.SegformerModel(pretrained_name=backbone_id, num_classes=1)
    criterion = _segformer_b0_mod.CombinedLoss(bce_weight=0.5)
    return net, criterion, False


def build_segformer_b5():
    """Create the SegFormer model with ``pretrained_name="nvidia/mit-b5"``.

    Returns:
        Tuple ``(model, loss, output_is_prob)``. ``output_is_prob`` is False
        because SegFormer emits raw logits.
    """
    backbone_id = "nvidia/mit-b5"
    net = _segformer_b5_mod.SegformerModel(pretrained_name=backbone_id, num_classes=1)
    criterion = _segformer_b5_mod.CombinedLoss(bce_weight=0.5)
    return net, criterion, False


# Model key -> zero-argument builder. Every builder returns
# ``(model, loss, output_is_prob)``.
MODEL_REGISTRY = dict(
    segnet=build_segnet,
    unet=build_unet,
    segformer_b0=build_segformer_b0,
    segformer_b5=build_segformer_b5,
)

# Human-readable display name for each model key (same keys as MODEL_REGISTRY).
PRETTY_NAME = dict(
    segnet="SegNet (CNN)",
    unet="U-Net",
    segformer_b0="SegFormer-B0",
    segformer_b5="SegFormer-B5",
)