Mohamed-ENNHIRI
Add Tab 7: resolution study (segformer_b0 + U-Net at 192/256/512)
a3200e4
Raw
History Blame Contribute Delete
2.68 kB
"""
Builders for the resolution study. Same 4 architectures, same loss recipe,
loaded via importlib so this experiment's dataset.py is not shadowed.
"""
import importlib.util
import sys
from pathlib import Path

import torch
import torch.nn as nn
# Repository root: this file sits two directories below it
# (REPO_ROOT/<dir>/<dir>/this_file.py).
REPO_ROOT = Path(__file__).resolve().parents[2]
# Folder holding the per-architecture model sub-directories loaded by path.
PV_DIR = REPO_ROOT.joinpath("pv_panel_models")
def _load(module_name: str, file_path: Path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
# Load each architecture's source file by path under its own private module
# name. The B0 and B5 SegFormer files share the file name
# "segformer_model.py", so the distinct names keep their module identities apart.
_segnet_mod = _load(
    "_pv_segnet_model",
    PV_DIR / "cnn_model" / "cnn_segmenter.py",
)
_unet_mod = _load(
    "_pv_unet_model",
    PV_DIR / "unet_model" / "unet_model.py",
)
_segformer_b0_mod = _load(
    "_pv_segformer_b0_model",
    PV_DIR / "vit_model" / "segformer_model.py",
)
_segformer_b5_mod = _load(
    "_pv_segformer_b5_model",
    PV_DIR / "segformer_b5_model" / "segformer_model.py",
)
class _SegNetDiceLoss(nn.Module):
def __init__(self, smooth=1e-6):
super().__init__()
self.smooth = smooth
def forward(self, pred, target):
pred = pred.view(-1)
target = target.view(-1)
intersection = (pred * target).sum()
dice = (2.0 * intersection + self.smooth) / (pred.sum() + target.sum() + self.smooth)
return 1 - dice
class _SegNetCombinedLoss(nn.Module):
    """Weighted blend of binary cross-entropy and soft Dice loss for SegNet.

    loss = bce_weight * BCE(pred, target) + (1 - bce_weight) * Dice(pred, target)

    ``nn.BCELoss`` requires ``pred`` to already hold probabilities in [0, 1]
    (no sigmoid is applied here).
    """

    def __init__(self, bce_weight=0.5):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = _SegNetDiceLoss()
        # Share of the BCE term; the Dice term receives the remainder.
        self.bce_weight = bce_weight

    def forward(self, pred, target):
        bce_term = self.bce(pred, target)
        dice_term = self.dice(pred, target)
        dice_weight = 1 - self.bce_weight
        return self.bce_weight * bce_term + dice_weight * dice_term
def build_segnet():
    """Create the SegNet (CNN) model and its training criterion.

    Returns:
        tuple: ``(model, criterion, flag)``. The model is SegNet with 3 input
        channels and 1 output channel. The criterion is the local BCE + Dice
        blend with equal weights. ``flag`` is ``True``. It presumably tells the
        caller that SegNet already outputs probabilities, since its criterion
        uses ``nn.BCELoss``; confirm against the training loop.
    """
    model = _segnet_mod.SegNet(in_channels=3, out_channels=1)
    criterion = _SegNetCombinedLoss(bce_weight=0.5)
    return model, criterion, True
def build_unet():
    """Create the U-Net model and the loss defined alongside it.

    Returns:
        tuple: ``(model, criterion, flag)``. The model is ``UNet(3 -> 1)`` and
        the criterion is the U-Net module's own ``CombinedLoss`` with
        ``bce_weight=0.5``. ``flag`` is ``False``; its meaning is defined by
        the caller (TODO confirm).
    """
    model = _unet_mod.UNet(in_channels=3, out_channels=1)
    criterion = _unet_mod.CombinedLoss(bce_weight=0.5)
    return model, criterion, False
def build_segformer_b0():
    """Create the SegFormer-B0 model and its module's combined loss.

    ``"nvidia/mit-b0"`` is presumably a Hugging Face Hub checkpoint id, so
    construction may download pretrained weights (confirm).

    Returns:
        tuple: ``(model, criterion, flag)`` with a single-class head,
        ``CombinedLoss(bce_weight=0.5)`` and ``flag=False``.
    """
    model = _segformer_b0_mod.SegformerModel(pretrained_name="nvidia/mit-b0", num_classes=1)
    criterion = _segformer_b0_mod.CombinedLoss(bce_weight=0.5)
    return model, criterion, False
def build_segformer_b5():
    """Create the SegFormer-B5 model and its module's combined loss.

    ``"nvidia/mit-b5"`` is presumably a Hugging Face Hub checkpoint id, so
    construction may download pretrained weights (confirm).

    Returns:
        tuple: ``(model, criterion, flag)`` with a single-class head,
        ``CombinedLoss(bce_weight=0.5)`` and ``flag=False``.
    """
    model = _segformer_b5_mod.SegformerModel(pretrained_name="nvidia/mit-b5", num_classes=1)
    criterion = _segformer_b5_mod.CombinedLoss(bce_weight=0.5)
    return model, criterion, False
# Model key -> zero-argument builder returning (model, criterion, flag).
MODEL_REGISTRY = dict(
    segnet=build_segnet,
    unet=build_unet,
    segformer_b0=build_segformer_b0,
    segformer_b5=build_segformer_b5,
)
# Model key -> human-readable label (same keys as the model registry).
PRETTY_NAME = dict(
    segnet="SegNet (CNN)",
    unet="U-Net",
    segformer_b0="SegFormer-B0",
    segformer_b5="SegFormer-B5",
)