""" Builders for the resolution study. Same 4 architectures, same loss recipe, loaded via importlib so this experiment's dataset.py is not shadowed. """ import importlib.util from pathlib import Path import torch import torch.nn as nn REPO_ROOT = Path(__file__).resolve().parents[2] PV_DIR = REPO_ROOT / "pv_panel_models" def _load(module_name: str, file_path: Path): spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module _segnet_mod = _load("_pv_segnet_model", PV_DIR / "cnn_model" / "cnn_segmenter.py") _unet_mod = _load("_pv_unet_model", PV_DIR / "unet_model" / "unet_model.py") _segformer_b0_mod = _load("_pv_segformer_b0_model", PV_DIR / "vit_model" / "segformer_model.py") _segformer_b5_mod = _load("_pv_segformer_b5_model", PV_DIR / "segformer_b5_model" / "segformer_model.py") class _SegNetDiceLoss(nn.Module): def __init__(self, smooth=1e-6): super().__init__() self.smooth = smooth def forward(self, pred, target): pred = pred.view(-1) target = target.view(-1) intersection = (pred * target).sum() dice = (2.0 * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth) return 1 - dice class _SegNetCombinedLoss(nn.Module): def __init__(self, bce_weight=0.5): super().__init__() self.bce = nn.BCELoss() self.dice = _SegNetDiceLoss() self.bce_weight = bce_weight def forward(self, pred, target): return self.bce_weight * self.bce(pred, target) + (1 - self.bce_weight) * self.dice(pred, target) def build_segnet(): return _segnet_mod.SegNet(in_channels=3, out_channels=1), _SegNetCombinedLoss(bce_weight=0.5), True def build_unet(): return _unet_mod.UNet(in_channels=3, out_channels=1), _unet_mod.CombinedLoss(bce_weight=0.5), False def build_segformer_b0(): return ( _segformer_b0_mod.SegformerModel(pretrained_name="nvidia/mit-b0", num_classes=1), _segformer_b0_mod.CombinedLoss(bce_weight=0.5), False, ) def build_segformer_b5(): return ( _segformer_b5_mod.SegformerModel(pretrained_name="nvidia/mit-b5", num_classes=1), _segformer_b5_mod.CombinedLoss(bce_weight=0.5), False, ) MODEL_REGISTRY = { "segnet": build_segnet, "unet": build_unet, "segformer_b0": build_segformer_b0, "segformer_b5": build_segformer_b5, } PRETTY_NAME = { "segnet": "SegNet (CNN)", "unet": "U-Net", "segformer_b0": "SegFormer-B0", "segformer_b5": "SegFormer-B5", }