Upload folder using huggingface_hub
Browse files- __init__.py +28 -0
- config.py +23 -0
- datasets/__init__.py +7 -0
- datasets/flair.py +140 -0
- models/__init__.py +19 -0
- models/heads.py +246 -0
- models/isdnet.py +101 -0
- models/modules.py +258 -0
- utils/__init__.py +17 -0
- utils/distributed.py +56 -0
- weights/isdnet_flair_best.pth +3 -0
__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet: Integrating Shallow and Deep Networks for Efficient Ultra-high Resolution Segmentation
|
| 3 |
+
|
| 4 |
+
A standalone PyTorch implementation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .models import ISDNet
|
| 8 |
+
from .datasets import FLAIRDataset
|
| 9 |
+
from .config import (
|
| 10 |
+
DATA_ROOT,
|
| 11 |
+
STDC_PRETRAIN_PATH,
|
| 12 |
+
BATCH_SIZE_PER_GPU,
|
| 13 |
+
NUM_WORKERS,
|
| 14 |
+
BASE_LR,
|
| 15 |
+
WEIGHT_DECAY,
|
| 16 |
+
NUM_EPOCHS,
|
| 17 |
+
NUM_CLASSES,
|
| 18 |
+
CROP_SIZE,
|
| 19 |
+
DOWN_RATIO,
|
| 20 |
+
IGNORE_INDEX,
|
| 21 |
+
SAVE_INTERVAL,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
__version__ = "1.0.0"
|
| 25 |
+
__all__ = [
|
| 26 |
+
"ISDNet",
|
| 27 |
+
"FLAIRDataset",
|
| 28 |
+
]
|
config.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet Configuration
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
# Data paths
|
| 6 |
+
DATA_ROOT = "/ccast/FLAIR1024_optimal"
|
| 7 |
+
STDC_PRETRAIN_PATH = "STDCNet813M_73.91.tar"
|
| 8 |
+
|
| 9 |
+
# Training hyperparameters
|
| 10 |
+
BATCH_SIZE_PER_GPU = 16
|
| 11 |
+
NUM_WORKERS = 4
|
| 12 |
+
BASE_LR = 1e-3
|
| 13 |
+
WEIGHT_DECAY = 0.0005
|
| 14 |
+
NUM_EPOCHS = 80
|
| 15 |
+
|
| 16 |
+
# Model configuration
|
| 17 |
+
NUM_CLASSES = 15 # Classes 0-14 only
|
| 18 |
+
CROP_SIZE = 512
|
| 19 |
+
DOWN_RATIO = 4
|
| 20 |
+
IGNORE_INDEX = 255 # For classes >= 15
|
| 21 |
+
|
| 22 |
+
# Checkpointing
|
| 23 |
+
SAVE_INTERVAL = 5
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet datasets
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .flair import FLAIRDataset
|
| 6 |
+
|
| 7 |
+
__all__ = ["FLAIRDataset"]
|
datasets/flair.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FLAIR French Land Cover Dataset
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FLAIRDataset(Dataset):
    """FLAIR French Land Cover dataset.

    15 classes (0-14); any label >= 15 is remapped to ``ignore_index``.

    Args:
        data_root: Path to the dataset root (expects ``<root>/<split>/img`` and
            ``<root>/<split>/msk`` directories).
        split: 'train', 'valid', or 'test'.
        crop_size: Side length of the random/center crop (falsy disables cropping).
        augment: Whether to apply augmentations (auto-disabled for non-train splits).
        ignore_index: Label value to use for ignored classes.
    """

    # ImageNet normalization statistics on the 0-255 pixel scale
    MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

    # Class names (index == label id)
    CLASSES = [
        'building', 'pervious', 'impervious', 'bare_soil', 'water',
        'coniferous', 'deciduous', 'brushwood', 'vineyard', 'herbaceous',
        'agricultural', 'plowed_land', 'swimming_pool', 'snow', 'greenhouse'
    ]

    def __init__(self, data_root, split='train', crop_size=512, augment=True, ignore_index=255):
        self.data_root = data_root
        self.split = split
        self.crop_size = crop_size
        # Augmentation only makes sense for training data
        self.augment = augment and split == 'train'
        self.ignore_index = ignore_index

        self.img_dir = os.path.join(data_root, split, 'img')
        self.msk_dir = os.path.join(data_root, split, 'msk')
        self.img_files = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.img_files)

    def _photometric_distortion(self, img):
        """Apply photometric distortion (brightness, contrast, saturation, hue).

        Operates on float32 RGB pixels in the 0-255 range; returns the same.
        """
        # Random brightness
        if np.random.rand() > 0.5:
            delta = np.random.uniform(-32, 32)
            img = img + delta

        # Random contrast
        if np.random.rand() > 0.5:
            alpha = np.random.uniform(0.5, 1.5)
            img = img * alpha

        # Convert to HSV for saturation and hue jitter
        img_uint8 = np.clip(img, 0, 255).astype(np.uint8)
        img_hsv = np.array(Image.fromarray(img_uint8).convert('HSV')).astype(np.float32)

        # Random saturation
        if np.random.rand() > 0.5:
            img_hsv[:, :, 1] = img_hsv[:, :, 1] * np.random.uniform(0.5, 1.5)

        # Random hue (PIL hue channel spans 0-255, so wrap modulo 256)
        if np.random.rand() > 0.5:
            img_hsv[:, :, 0] = (img_hsv[:, :, 0] + np.random.uniform(-18, 18)) % 256

        # Convert back to RGB
        img_hsv = np.clip(img_hsv, 0, 255).astype(np.uint8)
        img = np.array(Image.fromarray(img_hsv, mode='HSV').convert('RGB')).astype(np.float32)

        return np.clip(img, 0, 255)

    def _random_rotate(self, img, msk):
        """Random rotation by 0, 90, 180, or 270 degrees (applied jointly)."""
        k = np.random.choice([0, 1, 2, 3])
        if k > 0:
            # .copy() gives contiguous arrays (rot90 returns a view)
            img = np.rot90(img, k).copy()
            msk = np.rot90(msk, k).copy()
        return img, msk

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        msk_path = os.path.join(self.msk_dir, self.img_files[idx].replace('_RGBI_', '_LABEL-COSIA_'))

        # Keep only the first 3 channels (drop NIR from RGBI imagery)
        img = np.array(Image.open(img_path)).astype(np.float32)[:, :, :3]
        msk = np.array(Image.open(msk_path)).astype(np.int64)

        # Remap classes: keep 0-14, map >=15 to ignore_index
        msk[msk >= 15] = self.ignore_index

        # Apply photometric distortion BEFORE normalization (works on 0-255 pixels)
        if self.augment:
            img = self._photometric_distortion(img)

        # Normalize with ImageNet statistics
        img = (img - self.MEAN) / self.STD

        # Random/center crop.
        # BUGFIX: the original gated cropping on height only; both dimensions
        # must be at least crop_size for the crop indices below to be valid.
        if self.crop_size and img.shape[0] >= self.crop_size and img.shape[1] >= self.crop_size:
            h, w = img.shape[:2]
            if self.augment:
                # Try to find a crop with good class coverage (cat_max_ratio logic):
                # reject crops dominated (>75%) by a single class, up to 10 attempts.
                for _ in range(10):
                    top = np.random.randint(0, h - self.crop_size + 1)
                    left = np.random.randint(0, w - self.crop_size + 1)
                    crop_msk = msk[top:top+self.crop_size, left:left+self.crop_size]
                    valid_msk = crop_msk[crop_msk != self.ignore_index]
                    if len(valid_msk) > 0:
                        unique, counts = np.unique(valid_msk, return_counts=True)
                        if len(unique) > 1:
                            max_ratio = counts.max() / counts.sum()
                            if max_ratio < 0.75:
                                break
                img = img[top:top+self.crop_size, left:left+self.crop_size]
                msk = msk[top:top+self.crop_size, left:left+self.crop_size]
            else:
                # Deterministic center crop for validation/test
                top = (h - self.crop_size) // 2
                left = (w - self.crop_size) // 2
                img = img[top:top+self.crop_size, left:left+self.crop_size]
                msk = msk[top:top+self.crop_size, left:left+self.crop_size]

        # Random rotation
        if self.augment and np.random.rand() > 0.5:
            img, msk = self._random_rotate(img, msk)

        # Random horizontal flip
        if self.augment and np.random.rand() > 0.5:
            img = np.fliplr(img).copy()
            msk = np.fliplr(msk).copy()

        # HWC float32 -> CHW tensor; mask stays int64 HW
        return torch.from_numpy(img.transpose(2, 0, 1).astype(np.float32)), torch.from_numpy(msk)
|
models/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet models
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .isdnet import ISDNet
|
| 6 |
+
from .modules import ConvX, AddBottleneck, CatBottleneck, ShallowNet, Lap_Pyramid_Conv
|
| 7 |
+
from .heads import ASPPModule, ISDHead, RefineASPPHead
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"ISDNet",
|
| 11 |
+
"ConvX",
|
| 12 |
+
"AddBottleneck",
|
| 13 |
+
"CatBottleneck",
|
| 14 |
+
"ShallowNet",
|
| 15 |
+
"Lap_Pyramid_Conv",
|
| 16 |
+
"ASPPModule",
|
| 17 |
+
"ISDHead",
|
| 18 |
+
"RefineASPPHead",
|
| 19 |
+
]
|
models/heads.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet decoder heads: ASPP, ISDHead, RefineASPPHead
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from mmcv.cnn import ConvModule
|
| 9 |
+
|
| 10 |
+
from .modules import ShallowNet, Lap_Pyramid_Conv
|
| 11 |
+
from ..utils import batch_mm_loop
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling: one conv branch per dilation rate."""

    def __init__(self, dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg):
        branches = []
        for rate in dilations:
            # rate 1 uses a plain 1x1 conv; larger rates use dilated 3x3 convs
            branches.append(ConvModule(
                in_ch, ch,
                1 if rate == 1 else 3,
                dilation=rate,
                padding=0 if rate == 1 else rate,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ))
        super().__init__(branches)

    def forward(self, x):
        # Each branch processes the same input independently
        outputs = []
        for branch in self:
            outputs.append(branch(x))
        return outputs
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SegmentationHead(nn.Module):
    """3x3 conv refinement followed by a 1x1 per-pixel classifier."""

    def __init__(self, conv_cfg, norm_cfg, act_cfg, in_ch, mid_ch, n_classes, **kw):
        super().__init__()
        self.conv = ConvModule(in_ch, mid_ch, 3, 1, 1,
                               conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.out = nn.Conv2d(mid_ch, n_classes, 1, bias=True)

    def forward(self, x):
        hidden = self.conv(x)
        return self.out(hidden)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class SRDecoder(nn.Module):
    """Super-resolution decoder used for the feature-alignment loss.

    Upsamples decoder features three times (by ``up_lists`` factors) and
    reconstructs a 3-channel image via ``conv_sr``.

    Args:
        conv_cfg, norm_cfg, act_cfg: mmcv ConvModule configuration dicts.
        ch: Channel width of the input/output features.
        up_lists: Upsampling factor for each of the three stages.
    """

    # Fix: the original used a mutable default argument (list); a tuple is
    # safe and backward compatible since up_lists is only indexed.
    def __init__(self, conv_cfg, norm_cfg, act_cfg, ch=128, up_lists=(2, 2, 2)):
        super().__init__()
        self.up1 = nn.Upsample(scale_factor=up_lists[0])
        self.conv1 = ConvModule(ch, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up2 = nn.Upsample(scale_factor=up_lists[1])
        self.conv2 = ConvModule(ch // 2, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up3 = nn.Upsample(scale_factor=up_lists[2])
        self.conv3 = ConvModule(ch // 2, ch, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # 3 output channels: reconstructs an RGB image from the features
        self.conv_sr = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, 3)

    def forward(self, x, fa=False):
        """Decode features; with ``fa=True`` also return the pre-SR features."""
        h = self.conv1(self.up1(x))
        h = self.conv2(self.up2(h))
        feats = self.conv3(self.up3(h))
        recon = self.conv_sr(feats)
        if fa:
            return feats, recon
        return recon
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ChannelAtt(nn.Module):
    """Channel attention: 3x3 conv features plus a global-pooled 1x1 descriptor."""

    def __init__(self, in_ch, out_ch, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.conv = ConvModule(in_ch, out_ch, 3, 1, 1,
                               conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # no activation on the attention descriptor
        self.conv1x1 = ConvModule(out_ch, out_ch, 1, 1, 0,
                                  conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        feat = self.conv(x)
        pooled = feat.mean(dim=(2, 3), keepdim=True)  # global average pool -> (B, C, 1, 1)
        return feat, self.conv1x1(pooled)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class RelationAwareFusion(nn.Module):
    """Relation-aware fusion of shallow (spatial) and deep (context) features.

    A channel-affinity matrix between the two branches drives MLPs that
    modulate each branch's channel attention before the weighted merge.
    """

    def __init__(self, ch, conv_cfg, norm_cfg, act_cfg, ext=2, r=16):
        super().__init__()
        self.r = r
        # learnable gates, zero-initialised so fusion starts as plain attention
        self.g1 = nn.Parameter(torch.zeros(1))
        self.g2 = nn.Parameter(torch.zeros(1))
        self.sp_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        self.sp_att = ChannelAtt(ch * ext, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        self.co_att = ChannelAtt(ch, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_head = ConvModule(ch, ch, 3, 1, 1,
                                  conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.smooth = ConvModule(ch, ch, 3, 1, 1,
                                 conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, sp_feat, co_feat):
        s_f, s_a = self.sp_att(sp_feat)
        c_f, c_a = self.co_att(co_feat)
        b, c = s_a.shape[:2]

        # Channel affinity between the two attention descriptors; a
        # loop-based batch mm avoids CUBLAS strided-batched issues.
        s_mat = s_a.view(b, self.r, c // self.r)
        c_mat = c_a.view(b, self.r, c // self.r).permute(0, 2, 1)
        aff = batch_mm_loop(s_mat, c_mat).view(b, -1)

        # Gated, relation-modulated channel attentions
        re_s = torch.sigmoid(s_a + self.g1 * F.relu(self.sp_mlp(aff)).unsqueeze(-1).unsqueeze(-1))
        re_c = torch.sigmoid(c_a + self.g2 * F.relu(self.co_mlp(aff)).unsqueeze(-1).unsqueeze(-1))

        # Upsample the attended context branch to the spatial branch's size
        c_up = F.interpolate(c_f * re_c, s_f.shape[2:], mode='bilinear', align_corners=False)
        c_f = self.co_head(c_up)
        return s_f, c_f, self.smooth(s_f * re_s + c_f)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class Reducer(nn.Module):
    """1x1 conv + BN + ReLU that shrinks the channel count."""

    def __init__(self, in_ch=512, reduce=128):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, reduce, 1, bias=False)
        self.bn = nn.SyncBatchNorm(reduce)

    def forward(self, x):
        reduced = self.conv(x)
        return F.relu(self.bn(reduced))
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ISDHead(nn.Module):
    """ISD decoder head.

    Extracts shallow STDC features from a 6-channel Laplacian-pyramid input
    and fuses them with deep backbone features via relation-aware fusion at
    two scales; during training it additionally produces auxiliary logits
    and reconstruction/feature-alignment losses.
    """

    def __init__(self, in_ch, ch, num_classes, down_ratio, prev_ch,
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, reduce=False, stdc_pretrain=''):
        super().__init__()
        self.ch = ch
        self.fuse8 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=2)
        self.fuse16 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=4)
        self.sr_dec = SRDecoder(conv_cfg, norm_cfg, act_cfg, ch, [4, 2, 2])
        # 6 input channels: pyramid level-0 residual + upsampled level-1
        self.stdc = ShallowNet(in_channels=6, pretrain_model=stdc_pretrain)
        self.lap = Lap_Pyramid_Conv(num_high=2)
        # NOTE(review): in forward(), seg_aux16 is applied to the 1/8-scale
        # features and seg_aux8 to the 1/16-scale ones; the apparent swap is
        # preserved as-is — confirm against training code / checkpoints.
        self.seg_aux16 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg_aux8 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.reduce = Reducer() if reduce else None
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs, prev_output, train_flag=True):
        # Laplacian pyramid decomposition of the full-resolution input
        pyr = self.lap.pyramid_decom(inputs)
        pyr1_up = F.interpolate(pyr[1], pyr[0].shape[2:], mode='bilinear', align_corners=False)
        high_in = torch.cat([pyr[0], pyr1_up], dim=1)

        # Shallow features at 1/8 and 1/16 resolution
        s8, s16 = self.stdc(high_in)

        # Deep features (optionally channel-reduced)
        deep = self.reduce(prev_output[0]) if self.reduce else prev_output[0]

        # Multi-scale fusion: 1/16 first, then refine at 1/8
        _, a16, f16 = self.fuse16(s16, deep)
        _, a8, f8 = self.fuse8(s8, f16)

        # Final segmentation logits
        out = self.seg(self.drop(f8) if self.drop else f8)

        if not train_flag:
            return out

        # Training extras: SR reconstruction and feature-alignment losses
        feats, sr_out = self.sr_dec(deep, True)
        target = pyr[0] + pyr1_up
        if sr_out.shape[2:] != target.shape[2:]:
            sr_out = F.interpolate(sr_out, target.shape[2:], mode='bilinear', align_corners=False)
        return (out,
                self.seg_aux16(a8),
                self.seg_aux8(a16),
                {'recon_losses': F.mse_loss(sr_out, target) * 0.1},
                {'fa_loss': self._fa(deep, feats)})

    def _fa(self, seg_f, sr_f, eps=1e-6):
        """Feature-alignment loss between channel self-similarity matrices."""
        if seg_f.shape[2:] != sr_f.shape[2:]:
            sr_f = F.interpolate(sr_f, seg_f.shape[2:], mode='bilinear', align_corners=False)
        sf = torch.flatten(seg_f, 2)
        srf = torch.flatten(sr_f, 2)
        # L2-normalize each channel's spatial vector
        sf = sf / (sf.norm(p=2, dim=2, keepdim=True) + eps)
        srf = srf / (srf.norm(p=2, dim=2, keepdim=True) + eps)
        # Loop-based batch mm for CUBLAS compatibility; SR side is detached
        # so it acts as a fixed target.
        sf_t = sf.permute(0, 2, 1)
        srf_t = srf.permute(0, 2, 1)
        return F.l1_loss(batch_mm_loop(sf_t, sf), batch_mm_loop(srf_t, srf).detach())
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class RefineASPPHead(nn.Module):
    """ASPP decoder head for the low-resolution deep path.

    Combines an image-level context branch with atrous spatial pyramid
    pooling, then classifies; also returns pre-classifier features for the
    shallow refinement head.
    """

    def __init__(self, in_ch, ch, num_classes, dilations=(1, 12, 24, 36),
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, in_index=-1):
        super().__init__()
        self.in_index = in_index
        # Image-level context: global pool + 1x1 projection
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(in_ch, ch, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        )
        self.aspp = ASPPModule(dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg)
        self.bottle = ConvModule(
            (len(dilations) + 1) * ch, ch, 3, padding=1,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.seg = nn.Conv2d(ch, num_classes, 1)
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs):
        # Accept either a feature list (pick in_index) or a single tensor
        x = inputs[self.in_index] if isinstance(inputs, (list, tuple)) else inputs
        pooled = F.interpolate(self.pool(x), x.shape[2:], mode='bilinear', align_corners=False)
        branches = [pooled] + self.aspp(x)
        feat = self.bottle(torch.cat(branches, dim=1))
        logits = self.seg(self.drop(feat) if self.drop else feat)
        return logits, [feat]
|
models/isdnet.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet: Integrating Shallow and Deep Networks for Efficient Ultra-high Resolution Segmentation
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from mmcv.cnn import ConvModule
|
| 9 |
+
import timm
|
| 10 |
+
|
| 11 |
+
from .heads import ISDHead, RefineASPPHead
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ISDNet(nn.Module):
    """ISDNet model for ultra-high resolution segmentation.

    Combines a deep ResNet backbone (run on a downsampled image) with a
    shallow STDC-like network (run at full resolution) to capture both
    global context and local detail efficiently.

    Args:
        num_classes: Number of segmentation classes.
        backbone: Backbone model name (from timm).
        ch: Base channel number for the decoder.
        down_ratio: Downsampling ratio for the deep path.
        dilations: ASPP dilation rates.
        pretrained: Use pretrained backbone weights.
        stdc_pretrain: Path to pretrained STDC weights.
    """

    def __init__(self, num_classes=15, backbone='resnet18', ch=128,
                 down_ratio=4, dilations=(1, 12, 24, 36),
                 pretrained=True, stdc_pretrain=''):
        super().__init__()
        self.ds = down_ratio

        # Deep-path backbone
        self.bb = timm.create_model(backbone, pretrained=pretrained, features_only=True)
        bb_ch = self.bb.feature_info.channels()
        print(f"Backbone channels: {bb_ch}")

        # ASPP decoder on the deepest backbone stage
        self.dec = RefineASPPHead(bb_ch[-1], ch, num_classes, dilations, in_index=-1)

        # Shallow refinement head on the full-resolution image
        self.ref = ISDHead(3, ch, num_classes, down_ratio, ch, stdc_pretrain=stdc_pretrain)

        # Auxiliary classifier on the penultimate backbone stage
        self.aux = nn.Sequential(
            ConvModule(bb_ch[-2], 64, 3, padding=1,
                       norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU')),
            nn.Dropout2d(0.1),
            nn.Conv2d(64, num_classes, 1)
        )

    def forward(self, img, return_loss=True):
        """
        Forward pass.

        Args:
            img: Input image tensor (B, C, H, W).
            return_loss: If True, return a dict with all outputs needed for
                loss computation; if False, return only the final logits.

        Returns:
            If return_loss=True: Dict with 'out', 'out_deep', 'out_aux16',
                'out_aux8', 'aux_out', 'losses_re', 'losses_fa'.
            If return_loss=False: Segmentation logits (B, num_classes, H, W).
        """
        # Deep path runs on a down_ratio-times smaller image
        small = F.interpolate(
            img,
            [s // self.ds for s in img.shape[2:]],
            mode='bilinear',
            align_corners=False
        )
        x = self.bb(small)

        # Deep-path logits and pre-classifier features
        out_g, prev = self.dec(x)

        if not return_loss:
            # Inference: only the refined shallow-path output
            return F.interpolate(
                self.ref(img, prev, False),
                img.shape[2:],
                mode='bilinear',
                align_corners=False
            )

        # Full training forward with all auxiliary outputs
        out_r, a16, a8, l_re, l_fa = self.ref(img, prev, True)
        sz = img.shape[2:]

        def _up(t):
            # upsample any logits to the input resolution
            return F.interpolate(t, sz, mode='bilinear', align_corners=False)

        return {
            'out': _up(out_r),
            'out_deep': _up(out_g),
            'out_aux16': _up(a16),
            'out_aux8': _up(a8),
            'aux_out': _up(self.aux(x[-2])),
            'losses_re': l_re,
            'losses_fa': l_fa
        }
|
models/modules.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet building blocks: STDC-like modules and Laplacian pyramid
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.nn import init
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ConvX(nn.Module):
    """Conv -> SyncBN -> ReLU building block."""

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super().__init__()
        # 'same' padding for odd kernels via kernel // 2
        self.conv = nn.Conv2d(
            in_planes, out_planes,
            kernel_size=kernel, stride=stride,
            padding=kernel // 2, bias=False
        )
        self.bn = nn.SyncBatchNorm(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.relu(y)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class AddBottleneck(nn.Module):
    """STDC bottleneck with residual-addition fusion.

    Sub-branches progressively halve channels; their concatenation totals
    out_planes and is added to a (possibly projected) shortcut.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride

        if stride == 2:
            # depthwise stride-2 conv applied after the first 1x1 conv
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2)
            )
            # projection shortcut matching spatial size and channel count
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, 3, 2, 1, groups=in_planes, bias=False),
                nn.SyncBatchNorm(in_planes),
                nn.Conv2d(in_planes, out_planes, 1, bias=False),
                nn.SyncBatchNorm(out_planes)
            )
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
            elif idx == 1:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
            elif idx < block_num - 1:
                # intermediate stages halve the channel count
                self.conv_list.append(
                    ConvX(out_planes // (2 ** idx), out_planes // (2 ** (idx + 1)))
                )
            else:
                # final stage keeps its channel count
                self.conv_list.append(
                    ConvX(out_planes // (2 ** idx), out_planes // (2 ** idx))
                )

    def forward(self, x):
        outputs = []
        cur = x
        for idx, conv in enumerate(self.conv_list):
            cur = conv(cur)
            if idx == 0 and self.stride == 2:
                cur = self.avd_layer(cur)
            outputs.append(cur)

        shortcut = self.skip(x) if self.stride == 2 else x
        return torch.cat(outputs, dim=1) + shortcut
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CatBottleneck(nn.Module):
    """STDC bottleneck with concatenation fusion.

    The first 1x1 conv's output (or its pooled version when striding) is
    concatenated with the outputs of all subsequent stages.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride

        if stride == 2:
            # depthwise stride-2 conv applied to the first branch's output
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2)
            )
            self.skip = nn.AvgPool2d(3, 2, 1)  # parameter-free downsampling shortcut
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                self.conv_list.append(ConvX(in_planes, out_planes // 2, kernel=1))
            elif idx == 1 and block_num == 2:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 2, stride=stride))
            elif idx == 1:
                self.conv_list.append(ConvX(out_planes // 2, out_planes // 4, stride=stride))
            elif idx < block_num - 1:
                # intermediate stages halve the channel count
                self.conv_list.append(
                    ConvX(out_planes // (2 ** idx), out_planes // (2 ** (idx + 1)))
                )
            else:
                # final stage keeps its channel count
                self.conv_list.append(
                    ConvX(out_planes // (2 ** idx), out_planes // (2 ** idx))
                )

    def forward(self, x):
        out1 = self.conv_list[0](x)

        outputs = []
        cur = out1
        for idx, conv in enumerate(self.conv_list[1:]):
            if idx == 0 and self.stride == 2:
                cur = conv(self.avd_layer(out1))
            else:
                cur = conv(cur)
            outputs.append(cur)

        head = self.skip(out1) if self.stride == 2 else out1
        return torch.cat([head] + outputs, dim=1)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class ShallowNet(nn.Module):
    """
    STDC-like shallow network for high-resolution feature extraction.

    Stage outputs (relative to input resolution):
        x2  : stride-2 stem conv  -> base // 2 channels
        x4  : stride-2 conv       -> base channels
        x8  : first block stage   -> base * 4 channels
        x16 : second block stage  -> base * 8 channels

    Args:
        base: Base channel number.
        in_channels: Input channels (3 for RGB, 6 for pyramid concat).
        layers: Number of blocks per stage. Default is (2, 2); any
            sequence of ints works. (Changed from a mutable list default
            to a tuple — same values, avoids the shared-mutable-default
            pitfall; callers passing lists are unaffected.)
        block_num: Number of convs per block.
        type: 'cat' for CatBottleneck, anything else for AddBottleneck.
            (Shadows the builtin, but kept for API compatibility.)
        pretrain_model: Path to pretrained STDC weights; random init is
            used when the path is empty or missing.
    """

    def __init__(self, base=64, in_channels=3, layers=(2, 2), block_num=4,
                 type="cat", pretrain_model=''):
        super().__init__()
        block = CatBottleneck if type == "cat" else AddBottleneck
        self.in_channels = in_channels

        # Stem: two stride-2 ConvX layers (1/2 and 1/4 resolution).
        features = [
            ConvX(in_channels, base // 2, 3, 2),
            ConvX(base // 2, base, 3, 2)
        ]

        # Bottleneck stages: the first block of each stage downsamples
        # (stride 2) and doubles/quadruples channels; the rest keep shape.
        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(
                        block(base * int(math.pow(2, i + 1)),
                              base * int(math.pow(2, i + 2)), block_num, 2)
                    )
                else:
                    features.append(
                        block(base * int(math.pow(2, i + 2)),
                              base * int(math.pow(2, i + 2)), block_num, 1)
                    )

        self.features = nn.Sequential(*features)
        # Stage views share modules with self.features (no parameter
        # duplication; state_dict exposes both key prefixes).
        # NOTE(review): the fixed slices below assume layers == (2, 2);
        # other stage counts would need different indices — TODO confirm.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])

        if pretrain_model and os.path.exists(pretrain_model):
            print(f'Loading pretrain model {pretrain_model}')
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from a trusted source.
            sd = torch.load(pretrain_model, weights_only=False)["state_dict"]
            ssd = self.state_dict()
            for k, v in sd.items():
                # Stem was pretrained on RGB; duplicate its weights along the
                # input-channel axis for 6-channel input.
                # Assumes in_channels == 6 when != 3 — TODO confirm callers.
                if k == 'features.0.conv.weight' and in_channels != 3:
                    v = torch.cat([v, v], dim=1)
                if k in ssd:
                    ssd.update({k: v})
            self.load_state_dict(ssd, strict=False)
        else:
            # Random init: Kaiming for convs, unit-scale for norm layers.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)

    def forward(self, x):
        """Return (x8, x16): 1/8- and 1/16-resolution feature maps."""
        x2 = self.x2(x)
        x4 = self.x4(x2)
        x8 = self.x8(x4)
        x16 = self.x16(x8)
        return x8, x16
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class Lap_Pyramid_Conv(nn.Module):
    """
    Laplacian pyramid decomposition.

    Extracts high-frequency detail residuals at ``num_high`` successive
    scales using a fixed 5x5 Gaussian kernel applied depthwise.
    """

    def __init__(self, num_high=3, gauss_chl=3):
        """
        Args:
            num_high: Number of high-frequency pyramid levels to produce.
            gauss_chl: Number of image channels; the Gaussian kernel is
                replicated per channel for the grouped (depthwise) conv.
        """
        super().__init__()
        self.num_high = num_high
        self.gauss_chl = gauss_chl

        # Standard 5x5 binomial (Gaussian) kernel, normalized to sum to 1.
        k = torch.tensor([
            [1., 4., 6., 4., 1],
            [4., 16., 24., 16., 4.],
            [6., 24., 36., 24., 6.],
            [4., 16., 24., 16., 4.],
            [1., 4., 6., 4., 1.]
        ]) / 256.
        # Shape (gauss_chl, 1, 5, 5); registered as a buffer so it follows
        # .to()/.cuda()/.half()/.double() with the module.
        self.register_buffer('kernel', k.repeat(gauss_chl, 1, 1, 1))

    def conv_gauss(self, img, k):
        """Depthwise Gaussian blur with reflect padding (shape-preserving)."""
        return F.conv2d(F.pad(img, (2, 2, 2, 2), mode='reflect'), k, groups=img.shape[1])

    def downsample(self, x):
        """Drop every other row and column (2x decimation, no filtering)."""
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        """2x zero-interleaved upsampling followed by Gaussian smoothing.

        Zeros are interleaved between samples along both spatial axes, then
        the result is blurred with ``4 * kernel`` (the factor compensates
        for the energy lost to the inserted zeros).
        """
        cc = torch.cat([x, torch.zeros_like(x)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
        cc = cc.permute(0, 1, 3, 2)
        # Fix: match x's dtype as well as device. torch.zeros defaults to
        # float32, which made torch.cat raise for half/double inputs.
        cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3],
                                        x.shape[2] * 2, dtype=x.dtype,
                                        device=x.device)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
        return self.conv_gauss(cc.permute(0, 1, 3, 2), 4 * self.kernel)

    def pyramid_decom(self, img):
        """Decompose image into Laplacian pyramid (high-frequency residuals).

        Returns a list of ``num_high`` residual tensors, finest scale first.
        The final low-frequency image is NOT returned, so the output alone
        is not sufficient for exact reconstruction.
        """
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.downsample(self.conv_gauss(current, self.kernel))
            up = self.upsample(down)
            if up.shape[2:] != current.shape[2:]:
                # Odd input sizes: the upsampled result can overshoot by one
                # pixel; resize back to match before taking the residual.
                up = F.interpolate(up, current.shape[2:])
            pyr.append(current - up)
            current = down
        return pyr
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ISDNet utilities
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .distributed import (
|
| 6 |
+
setup_distributed,
|
| 7 |
+
cleanup_distributed,
|
| 8 |
+
print_rank0,
|
| 9 |
+
batch_mm_loop,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"setup_distributed",
|
| 14 |
+
"cleanup_distributed",
|
| 15 |
+
"print_rank0",
|
| 16 |
+
"batch_mm_loop",
|
| 17 |
+
]
|
utils/distributed.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Distributed training utilities
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import torch.distributed as dist
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def setup_distributed():
    """Initialize distributed training from torchrun-style env vars.

    Reads RANK / WORLD_SIZE / LOCAL_RANK when RANK is set; otherwise
    falls back to single-process defaults. The NCCL process group is
    initialized only when world_size > 1.

    Returns:
        Tuple of ints (rank, world_size, local_rank).
    """
    env = os.environ
    if 'RANK' in env:
        rank = int(env['RANK'])
        world_size = int(env['WORLD_SIZE'])
        local_rank = int(env['LOCAL_RANK'])
    else:
        # Single-process fallback (no launcher).
        rank, world_size, local_rank = 0, 1, 0

    if world_size > 1:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)

    return rank, world_size, local_rank
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def cleanup_distributed():
    """Cleanup distributed training.

    Destroys the default process group if one was initialized; otherwise
    a no-op. Guarded by ``dist.is_available()`` so that PyTorch builds
    without distributed support do not fail here.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def print_rank0(msg, rank=0):
    """Emit *msg* only on rank 0, keeping multi-process logs readable."""
    if rank != 0:
        return
    print(msg)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def batch_mm_loop(a, b):
    """
    Batch matrix multiply using a loop over the batch dimension.

    Deliberately avoids CUBLAS strided batched routines (torch.bmm),
    which have issues on L40S/CUDA 12.8/PyTorch 2.10 — do not "optimize"
    this back to bmm.

    Args:
        a: Tensor of shape (batch, m, k)
        b: Tensor of shape (batch, k, n)

    Returns:
        Tensor of shape (batch, m, n); an empty (0, m, n) tensor when
        batch == 0 (torch.stack would raise on an empty list).
    """
    if a.shape[0] == 0:
        # Preserve dtype/device via new_empty instead of crashing in stack.
        return a.new_empty(0, a.shape[1], b.shape[2])
    return torch.stack([torch.mm(ai, bi) for ai, bi in zip(a, b)], dim=0)
|
weights/isdnet_flair_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:233b4a931fe370f395d0ce60d636036eefc35e596b09b1acfa54950d7f1d89e1
|
| 3 |
+
size 142441755
|