Antoine1091 commited on
Commit
a585f5a
·
verified ·
1 Parent(s): bebbee0

Upload folder using huggingface_hub

Browse files
__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet: Integrating Shallow and Deep Networks for Efficient Ultra-high Resolution Segmentation
3
+
4
+ A standalone PyTorch implementation.
5
+ """
6
+
7
+ from .models import ISDNet
8
+ from .datasets import FLAIRDataset
9
+ from .config import (
10
+ DATA_ROOT,
11
+ STDC_PRETRAIN_PATH,
12
+ BATCH_SIZE_PER_GPU,
13
+ NUM_WORKERS,
14
+ BASE_LR,
15
+ WEIGHT_DECAY,
16
+ NUM_EPOCHS,
17
+ NUM_CLASSES,
18
+ CROP_SIZE,
19
+ DOWN_RATIO,
20
+ IGNORE_INDEX,
21
+ SAVE_INTERVAL,
22
+ )
23
+
24
+ __version__ = "1.0.0"
25
+ __all__ = [
26
+ "ISDNet",
27
+ "FLAIRDataset",
28
+ ]
config.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet Configuration
3
+ """
4
+
5
+ # Data paths
6
+ DATA_ROOT = "/ccast/FLAIR1024_optimal"
7
+ STDC_PRETRAIN_PATH = "STDCNet813M_73.91.tar"
8
+
9
+ # Training hyperparameters
10
+ BATCH_SIZE_PER_GPU = 16
11
+ NUM_WORKERS = 4
12
+ BASE_LR = 1e-3
13
+ WEIGHT_DECAY = 0.0005
14
+ NUM_EPOCHS = 80
15
+
16
+ # Model configuration
17
+ NUM_CLASSES = 15 # Classes 0-14 only
18
+ CROP_SIZE = 512
19
+ DOWN_RATIO = 4
20
+ IGNORE_INDEX = 255 # For classes >= 15
21
+
22
+ # Checkpointing
23
+ SAVE_INTERVAL = 5
datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet datasets
3
+ """
4
+
5
+ from .flair import FLAIRDataset
6
+
7
+ __all__ = ["FLAIRDataset"]
datasets/flair.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FLAIR French Land Cover Dataset
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ from PIL import Image
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+
12
class FLAIRDataset(Dataset):
    """
    FLAIR French Land Cover dataset.

    15 classes (0-14), classes >= 15 are mapped to ignore_index (255).

    Args:
        data_root: Path to dataset root (expects <split>/img and <split>/msk)
        split: 'train', 'valid', or 'test'
        crop_size: Size of random/center crop (falsy value disables cropping)
        augment: Whether to apply augmentations (auto-disabled for non-train splits)
        ignore_index: Label value to use for ignored classes
    """

    # ImageNet normalization (0-255 pixel scale)
    MEAN = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    STD = np.array([58.395, 57.12, 57.375], dtype=np.float32)

    # Class names; list index == label id
    CLASSES = [
        'building', 'pervious', 'impervious', 'bare_soil', 'water',
        'coniferous', 'deciduous', 'brushwood', 'vineyard', 'herbaceous',
        'agricultural', 'plowed_land', 'swimming_pool', 'snow', 'greenhouse'
    ]

    def __init__(self, data_root, split='train', crop_size=512, augment=True, ignore_index=255):
        self.data_root = data_root
        self.split = split
        self.crop_size = crop_size
        # Augmentations only ever apply to the training split.
        self.augment = augment and split == 'train'
        self.ignore_index = ignore_index

        self.img_dir = os.path.join(data_root, split, 'img')
        self.msk_dir = os.path.join(data_root, split, 'msk')
        self.img_files = sorted(os.listdir(self.img_dir))

    def __len__(self):
        return len(self.img_files)

    def _photometric_distortion(self, img):
        """Apply photometric distortion (brightness, contrast, saturation, hue).

        Expects a float32 RGB image in the 0-255 range; returns the distorted
        image clipped back to [0, 255].
        """
        # Random brightness
        if np.random.rand() > 0.5:
            delta = np.random.uniform(-32, 32)
            img = img + delta

        # Random contrast
        if np.random.rand() > 0.5:
            alpha = np.random.uniform(0.5, 1.5)
            img = img * alpha

        # Convert to HSV for saturation and hue
        img_uint8 = np.clip(img, 0, 255).astype(np.uint8)
        img_hsv = np.array(Image.fromarray(img_uint8).convert('HSV')).astype(np.float32)

        # Random saturation
        if np.random.rand() > 0.5:
            img_hsv[:, :, 1] = img_hsv[:, :, 1] * np.random.uniform(0.5, 1.5)

        # Random hue (wraps around the 8-bit hue circle)
        if np.random.rand() > 0.5:
            img_hsv[:, :, 0] = (img_hsv[:, :, 0] + np.random.uniform(-18, 18)) % 256

        # Convert back to RGB
        img_hsv = np.clip(img_hsv, 0, 255).astype(np.uint8)
        img = np.array(Image.fromarray(img_hsv, mode='HSV').convert('RGB')).astype(np.float32)

        return np.clip(img, 0, 255)

    def _random_rotate(self, img, msk):
        """Random rotation by 0, 90, 180, or 270 degrees (image and mask together)."""
        k = np.random.choice([0, 1, 2, 3])
        if k > 0:
            img = np.rot90(img, k).copy()
            msk = np.rot90(msk, k).copy()
        return img, msk

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        msk_path = os.path.join(self.msk_dir, self.img_files[idx].replace('_RGBI_', '_LABEL-COSIA_'))

        img = np.array(Image.open(img_path)).astype(np.float32)[:, :, :3]
        msk = np.array(Image.open(msk_path)).astype(np.int64)

        # Remap classes: keep 0..len(CLASSES)-1 (0-14), map the rest to ignore_index
        msk[msk >= len(self.CLASSES)] = self.ignore_index

        # Apply photometric distortion BEFORE normalization (it assumes 0-255 input)
        if self.augment:
            img = self._photometric_distortion(img)

        # Normalize
        img = (img - self.MEAN) / self.STD

        # Random/center crop.
        # BUGFIX: require BOTH dimensions to fit the crop — previously only the
        # height was checked, so a narrow image (w < crop_size <= h) made
        # np.random.randint raise / produced a negative center-crop offset.
        if self.crop_size and img.shape[0] >= self.crop_size and img.shape[1] >= self.crop_size:
            h, w = img.shape[:2]
            if self.augment:
                # Try to find a crop with good class coverage (cat_max_ratio logic)
                for _ in range(10):
                    top = np.random.randint(0, h - self.crop_size + 1)
                    left = np.random.randint(0, w - self.crop_size + 1)
                    crop_msk = msk[top:top+self.crop_size, left:left+self.crop_size]
                    valid_msk = crop_msk[crop_msk != self.ignore_index]
                    if len(valid_msk) > 0:
                        unique, counts = np.unique(valid_msk, return_counts=True)
                        if len(unique) > 1:
                            max_ratio = counts.max() / counts.sum()
                            if max_ratio < 0.75:
                                break
                img = img[top:top+self.crop_size, left:left+self.crop_size]
                msk = msk[top:top+self.crop_size, left:left+self.crop_size]
            else:
                # Center crop for validation
                top = (h - self.crop_size) // 2
                left = (w - self.crop_size) // 2
                img = img[top:top+self.crop_size, left:left+self.crop_size]
                msk = msk[top:top+self.crop_size, left:left+self.crop_size]

        # Random rotation
        if self.augment and np.random.rand() > 0.5:
            img, msk = self._random_rotate(img, msk)

        # Random horizontal flip
        if self.augment and np.random.rand() > 0.5:
            img = np.fliplr(img).copy()
            msk = np.fliplr(msk).copy()

        # (C, H, W) float image tensor and (H, W) int64 label tensor
        return torch.from_numpy(img.transpose(2, 0, 1).astype(np.float32)), torch.from_numpy(msk)
models/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet models
3
+ """
4
+
5
+ from .isdnet import ISDNet
6
+ from .modules import ConvX, AddBottleneck, CatBottleneck, ShallowNet, Lap_Pyramid_Conv
7
+ from .heads import ASPPModule, ISDHead, RefineASPPHead
8
+
9
+ __all__ = [
10
+ "ISDNet",
11
+ "ConvX",
12
+ "AddBottleneck",
13
+ "CatBottleneck",
14
+ "ShallowNet",
15
+ "Lap_Pyramid_Conv",
16
+ "ASPPModule",
17
+ "ISDHead",
18
+ "RefineASPPHead",
19
+ ]
models/heads.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet decoder heads: ASPP, ISDHead, RefineASPPHead
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from mmcv.cnn import ConvModule
9
+
10
+ from .modules import ShallowNet, Lap_Pyramid_Conv
11
+ from ..utils import batch_mm_loop
12
+
13
+
14
class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling: one conv branch per dilation rate.

    Dilation 1 uses a plain 1x1 conv; larger rates use padded dilated 3x3 convs.
    """

    def __init__(self, dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg):
        branches = []
        for rate in dilations:
            kernel = 1 if rate == 1 else 3
            pad = 0 if rate == 1 else rate
            branches.append(
                ConvModule(
                    in_ch, ch, kernel,
                    dilation=rate,
                    padding=pad,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg
                )
            )
        super().__init__(branches)

    def forward(self, x):
        # One output per dilation branch, all on the same input.
        outs = []
        for branch in self:
            outs.append(branch(x))
        return outs
33
+
34
+
35
class SegmentationHead(nn.Module):
    """3x3 conv refinement followed by a 1x1 classifier producing n_classes logits."""

    def __init__(self, conv_cfg, norm_cfg, act_cfg, in_ch, mid_ch, n_classes, **kw):
        super().__init__()
        self.conv = ConvModule(in_ch, mid_ch, 3, 1, 1,
                               conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.out = nn.Conv2d(mid_ch, n_classes, 1, bias=True)

    def forward(self, x):
        hidden = self.conv(x)
        return self.out(hidden)
46
+
47
+
48
class SRDecoder(nn.Module):
    """Super-resolution decoder for feature alignment loss.

    Upsamples a (B, ch, H, W) feature map by prod(up_lists) through three
    upsample+conv stages, then projects to a 3-channel reconstruction.

    Args:
        conv_cfg/norm_cfg/act_cfg: mmcv ConvModule configuration.
        ch: Input channel count.
        up_lists: Per-stage upsampling factors (three stages).
            FIX: default is now a tuple — the previous list default was a
            mutable default argument (shared across calls).
    """

    def __init__(self, conv_cfg, norm_cfg, act_cfg, ch=128, up_lists=(2, 2, 2)):
        super().__init__()
        self.up1 = nn.Upsample(scale_factor=up_lists[0])
        self.conv1 = ConvModule(ch, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up2 = nn.Upsample(scale_factor=up_lists[1])
        self.conv2 = ConvModule(ch // 2, ch // 2, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up3 = nn.Upsample(scale_factor=up_lists[2])
        self.conv3 = ConvModule(ch // 2, ch, 3, 1, 1,
                                conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        # Projects the upsampled feature to a 3-channel image reconstruction.
        self.conv_sr = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, 3)

    def forward(self, x, fa=False):
        """Return the reconstruction; with fa=True also return the pre-head
        feature map (consumed by the feature-alignment loss)."""
        feats = self.conv3(self.up3(self.conv2(self.up2(self.conv1(self.up1(x))))))
        if fa:
            return feats, self.conv_sr(feats)
        return self.conv_sr(feats)
69
+
70
+
71
class ChannelAtt(nn.Module):
    """Channel attention: returns a refined feature map plus a per-channel
    attention vector of shape (B, out_ch, 1, 1)."""

    def __init__(self, in_ch, out_ch, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.conv = ConvModule(in_ch, out_ch, 3, 1, 1,
                               conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.conv1x1 = ConvModule(out_ch, out_ch, 1, 1, 0,
                                  conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, x):
        feat = self.conv(x)
        # Global average pool over the spatial dims, then a 1x1 projection.
        pooled = feat.mean(dim=(2, 3), keepdim=True)
        return feat, self.conv1x1(pooled)
84
+
85
+
86
class RelationAwareFusion(nn.Module):
    """
    Relation-aware fusion module.

    Fuses shallow (spatial) and deep (context) features using
    cross-attention mechanism.

    Args:
        ch: Channel width of the attention vectors and fused output.
        conv_cfg/norm_cfg/act_cfg: mmcv ConvModule configuration.
        ext: Channel multiplier of the incoming spatial feature
            (sp_feat is expected to have ch * ext channels).
        r: Group count used to fold the (b, ch) attention vectors into
            (b, r, ch//r) before the cross-affinity product.
            NOTE(review): sp_mlp/co_mlp take an input of size ch * 2, which
            matches the flattened (r, r) affinity only when r**2 == 2 * ch
            (true for the defaults r=16, ch=128) — confirm before changing
            either value.
    """

    def __init__(self, ch, conv_cfg, norm_cfg, act_cfg, ext=2, r=16):
        super().__init__()
        self.r = r
        # Learnable gates initialised to zero, so the cross-attention terms
        # start disabled and fade in during training.
        self.g1 = nn.Parameter(torch.zeros(1))
        self.g2 = nn.Parameter(torch.zeros(1))
        self.sp_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch)
        )
        self.sp_att = ChannelAtt(ch * ext, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_mlp = nn.Sequential(
            nn.Linear(ch * 2, ch),
            nn.ReLU(),
            nn.Linear(ch, ch)
        )
        self.co_att = ChannelAtt(ch, ch, conv_cfg, norm_cfg, act_cfg)
        self.co_head = ConvModule(ch, ch, 3, 1, 1,
                                  conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.smooth = ConvModule(ch, ch, 3, 1, 1,
                                 conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=None)

    def forward(self, sp_feat, co_feat):
        """Return (spatial feature, upsampled context feature, fused feature)."""
        # Per-branch features plus (b, ch, 1, 1) channel-attention vectors.
        s_f, s_a = self.sp_att(sp_feat)
        c_f, c_a = self.co_att(co_feat)
        b, c = s_a.shape[:2]

        # Use loop-based batch mm to avoid CUBLAS strided batched issues.
        # Affinity: (b, r, c//r) @ (b, c//r, r) -> (b, r, r), flattened.
        s_a_reshaped = s_a.view(b, self.r, c // self.r)
        c_a_reshaped = c_a.view(b, self.r, c // self.r).permute(0, 2, 1)
        aff = batch_mm_loop(s_a_reshaped, c_a_reshaped).view(b, -1)

        # Re-weight each branch's attention by an MLP of the affinity,
        # gated by g1/g2, then squash to (0, 1).
        re_s = torch.sigmoid(s_a + self.g1 * F.relu(self.sp_mlp(aff)).unsqueeze(-1).unsqueeze(-1))
        re_c = torch.sigmoid(c_a + self.g2 * F.relu(self.co_mlp(aff)).unsqueeze(-1).unsqueeze(-1))

        # Bring the attended context feature up to the spatial feature's size.
        c_f = self.co_head(
            F.interpolate(c_f * re_c, s_f.shape[2:], mode='bilinear', align_corners=False)
        )
        return s_f, c_f, self.smooth(s_f * re_s + c_f)
133
+
134
+
135
class Reducer(nn.Module):
    """1x1 conv + SyncBN + ReLU that shrinks the channel dimension."""

    def __init__(self, in_ch=512, reduce=128):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, reduce, 1, bias=False)
        self.bn = nn.SyncBatchNorm(reduce)

    def forward(self, x):
        reduced = self.conv(x)
        normed = self.bn(reduced)
        return F.relu(normed)
145
+
146
+
147
class ISDHead(nn.Module):
    """
    ISD decoder head.

    Combines shallow STDC features with deep backbone features
    using relation-aware fusion at multiple scales.

    Args:
        in_ch: Accepted for interface compatibility; not read by __init__.
        ch: Decoder channel width.
        num_classes: Number of output classes.
        down_ratio: Accepted but not read here (downsampling is done by the caller).
        prev_ch: Accepted but not read here.
        conv_cfg/norm_cfg/act_cfg: mmcv ConvModule configuration.
        dropout: Dropout2d probability before the main classifier (0 disables).
        reduce: If True, insert a Reducer in front of the deep feature.
        stdc_pretrain: Path to pretrained STDC weights for the shallow branch.
    """

    def __init__(self, in_ch, ch, num_classes, down_ratio, prev_ch,
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, reduce=False, stdc_pretrain=''):
        super().__init__()
        self.ch = ch
        # Fusion at two scales; ext is the channel multiplier of the
        # corresponding shallow feature relative to ch.
        self.fuse8 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=2)
        self.fuse16 = RelationAwareFusion(ch, conv_cfg, norm_cfg, act_cfg, ext=4)
        self.sr_dec = SRDecoder(conv_cfg, norm_cfg, act_cfg, ch, [4, 2, 2])
        # 6-channel input: concat of Laplacian level 0 and upsampled level 1.
        self.stdc = ShallowNet(in_channels=6, pretrain_model=stdc_pretrain)
        self.lap = Lap_Pyramid_Conv(num_high=2)
        self.seg_aux16 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg_aux8 = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.seg = SegmentationHead(conv_cfg, norm_cfg, act_cfg, ch, ch // 2, num_classes)
        self.reduce = Reducer() if reduce else None
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs, prev_output, train_flag=True):
        """Run the shallow branch and fuse it with the deep features.

        Args:
            inputs: Full-resolution image tensor (B, 3, H, W) — presumed RGB,
                matching Lap_Pyramid_Conv's 3-channel kernel.
            prev_output: Sequence whose first element is the deep feature map.
            train_flag: When True, also return auxiliary logits and the
                reconstruction / feature-alignment loss dicts.

        Returns:
            train_flag=True: (out, aux_a, aux_b,
                              {'recon_losses': ...}, {'fa_loss': ...})
            train_flag=False: main segmentation logits only.
        """
        # Laplacian pyramid decomposition
        pyr = self.lap.pyramid_decom(inputs)
        pyr1_up = F.interpolate(pyr[1], pyr[0].shape[2:], mode='bilinear', align_corners=False)
        high_in = torch.cat([pyr[0], pyr1_up], dim=1)

        # Shallow features
        s8, s16 = self.stdc(high_in)

        # Deep features
        deep = self.reduce(prev_output[0]) if self.reduce else prev_output[0]

        # Multi-scale fusion: deepest scale first, its output feeds the next.
        _, a16, f16 = self.fuse16(s16, deep)
        _, a8, f8 = self.fuse8(s8, f16)

        # Segmentation output
        out = self.seg(self.drop(f8) if self.drop else f8)

        if train_flag:
            feats, sr_out = self.sr_dec(deep, True)
            # Reconstruction target: Laplacian base plus upsampled level-1 detail.
            target = pyr[0] + pyr1_up
            if sr_out.shape[2:] != target.shape[2:]:
                sr_out = F.interpolate(sr_out, target.shape[2:], mode='bilinear', align_corners=False)
            # NOTE(review): seg_aux16 is applied to a8 (from fuse8) and
            # seg_aux8 to a16 (from fuse16) — the names look swapped; confirm
            # against the training-loss wiring before renaming anything.
            return (out,
                    self.seg_aux16(a8),
                    self.seg_aux8(a16),
                    {'recon_losses': F.mse_loss(sr_out, target) * 0.1},
                    {'fa_loss': self._fa(deep, feats)})
        return out

    def _fa(self, seg_f, sr_f, eps=1e-6):
        """Feature alignment loss.

        L1 distance between the channel self-similarity matrices of the
        segmentation and super-resolution features; the SR-side similarity is
        detached, so gradients reach only the segmentation branch.
        """
        if seg_f.shape[2:] != sr_f.shape[2:]:
            sr_f = F.interpolate(sr_f, seg_f.shape[2:], mode='bilinear', align_corners=False)
        sf = torch.flatten(seg_f, 2)
        srf = torch.flatten(sr_f, 2)
        # L2-normalise each channel's flattened spatial vector.
        sf = sf / (sf.norm(p=2, dim=2, keepdim=True) + eps)
        srf = srf / (srf.norm(p=2, dim=2, keepdim=True) + eps)
        # Use loop-based batch mm for CUBLAS compatibility
        sf_t = sf.permute(0, 2, 1)
        srf_t = srf.permute(0, 2, 1)
        return F.l1_loss(batch_mm_loop(sf_t, sf), batch_mm_loop(srf_t, srf).detach())
214
+
215
+
216
class RefineASPPHead(nn.Module):
    """
    ASPP-based decoder head for deep path.

    Processes low-resolution backbone features with
    atrous spatial pyramid pooling, returning both the class logits and the
    pre-classifier feature map (wrapped in a list).
    """

    def __init__(self, in_ch, ch, num_classes, dilations=(1, 12, 24, 36),
                 conv_cfg=None, norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU'),
                 dropout=0.1, in_index=-1):
        super().__init__()
        self.in_index = in_index
        # Image-level context branch: global pool + 1x1 projection.
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(in_ch, ch, 1, conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
        )
        self.aspp = ASPPModule(dilations, in_ch, ch, conv_cfg, norm_cfg, act_cfg)
        self.bottle = ConvModule(
            (len(dilations) + 1) * ch, ch, 3, padding=1,
            conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg
        )
        self.seg = nn.Conv2d(ch, num_classes, 1)
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, inputs):
        if isinstance(inputs, (list, tuple)):
            feat_in = inputs[self.in_index]
        else:
            feat_in = inputs
        size = feat_in.shape[2:]
        # Global-context branch first, then the dilated branches.
        branches = [F.interpolate(self.pool(feat_in), size, mode='bilinear', align_corners=False)]
        branches += self.aspp(feat_in)
        fused = self.bottle(torch.cat(branches, dim=1))
        logits_in = self.drop(fused) if self.drop else fused
        return self.seg(logits_in), [fused]
models/isdnet.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet: Integrating Shallow and Deep Networks for Efficient Ultra-high Resolution Segmentation
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from mmcv.cnn import ConvModule
9
+ import timm
10
+
11
+ from .heads import ISDHead, RefineASPPHead
12
+
13
+
14
class ISDNet(nn.Module):
    """
    ISDNet model for ultra-high resolution segmentation.

    Combines a deep ResNet backbone with a shallow STDC-like network
    to efficiently process both global context and local details.

    Args:
        num_classes: Number of segmentation classes
        backbone: Backbone model name (from timm)
        ch: Base channel number for decoder
        down_ratio: Downsampling ratio for deep path
        dilations: ASPP dilation rates
        pretrained: Use pretrained backbone weights
        stdc_pretrain: Path to pretrained STDC weights
    """

    def __init__(self, num_classes=15, backbone='resnet18', ch=128,
                 down_ratio=4, dilations=(1, 12, 24, 36),
                 pretrained=True, stdc_pretrain=''):
        super().__init__()
        self.ds = down_ratio

        # Backbone (deep path): runs on a 1/down_ratio-resolution copy of the input
        self.bb = timm.create_model(backbone, pretrained=pretrained, features_only=True)
        bb_ch = self.bb.feature_info.channels()
        print(f"Backbone channels: {bb_ch}")

        # Deep decoder (ASPP) over the deepest backbone stage
        self.dec = RefineASPPHead(bb_ch[-1], ch, num_classes, dilations, in_index=-1)

        # Shallow decoder (ISD head) over the full-resolution image
        self.ref = ISDHead(3, ch, num_classes, down_ratio, ch, stdc_pretrain=stdc_pretrain)

        # Auxiliary head on the penultimate backbone stage
        self.aux = nn.Sequential(
            ConvModule(bb_ch[-2], 64, 3, padding=1,
                       norm_cfg=dict(type='SyncBN'), act_cfg=dict(type='ReLU')),
            nn.Dropout2d(0.1),
            nn.Conv2d(64, num_classes, 1)
        )

    def forward(self, img, return_loss=True):
        """
        Forward pass.

        Args:
            img: Input image tensor (B, C, H, W)
            return_loss: If True, return dict with all outputs for loss computation
                         If False, return only final segmentation output

        Returns:
            If return_loss=True: Dict with 'out', 'out_deep', 'out_aux16', 'out_aux8',
                                 'aux_out', 'losses_re', 'losses_fa'
            If return_loss=False: Segmentation logits (B, num_classes, H, W)
        """
        full_size = img.shape[2:]
        small_size = [s // self.ds for s in full_size]

        # Deep path operates on the downsampled image
        feats = self.bb(F.interpolate(img, small_size, mode='bilinear', align_corners=False))
        out_g, prev = self.dec(feats)

        def up(t):
            # Resize any logits map back to the input resolution.
            return F.interpolate(t, full_size, mode='bilinear', align_corners=False)

        if not return_loss:
            # Inference: only shallow path output
            return up(self.ref(img, prev, False))

        # Full training forward with all auxiliary outputs
        out_r, a16, a8, l_re, l_fa = self.ref(img, prev, True)
        return {
            'out': up(out_r),
            'out_deep': up(out_g),
            'out_aux16': up(a16),
            'out_aux8': up(a8),
            'aux_out': up(self.aux(feats[-2])),
            'losses_re': l_re,
            'losses_fa': l_fa,
        }
models/modules.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet building blocks: STDC-like modules and Laplacian pyramid
3
+ """
4
+
5
+ import os
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.nn import init
11
+
12
+
13
class ConvX(nn.Module):
    """kxk conv (bias-free) -> SyncBN -> in-place ReLU, padding preserves size."""

    def __init__(self, in_planes, out_planes, kernel=3, stride=1):
        super().__init__()
        pad = kernel // 2
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel, stride=stride,
                              padding=pad, bias=False)
        self.bn = nn.SyncBatchNorm(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.relu(y)
28
+
29
+
30
class AddBottleneck(nn.Module):
    """STDC AddBottleneck: concatenated multi-width convs plus a residual addition."""

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride

        if stride == 2:
            # Depthwise stride-2 conv applied after the first block conv,
            # plus a strided projection for the residual branch.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2)
            )
            self.skip = nn.Sequential(
                nn.Conv2d(in_planes, in_planes, 3, 2, 1, groups=in_planes, bias=False),
                nn.SyncBatchNorm(in_planes),
                nn.Conv2d(in_planes, out_planes, 1, bias=False),
                nn.SyncBatchNorm(out_planes)
            )
            stride = 1

        # Channel plan: out/2, out/4, out/8, ... with the final stage keeping
        # its width so the concatenation totals out_planes.
        for i in range(block_num):
            if i == 0:
                stage = ConvX(in_planes, out_planes // 2, kernel=1)
            elif i == 1 and block_num == 2:
                stage = ConvX(out_planes // 2, out_planes // 2, stride=stride)
            elif i == 1:
                stage = ConvX(out_planes // 2, out_planes // 4, stride=stride)
            elif i < block_num - 1:
                stage = ConvX(out_planes // (2 ** i), out_planes // (2 ** (i + 1)))
            else:
                stage = ConvX(out_planes // (2 ** i), out_planes // (2 ** i))
            self.conv_list.append(stage)

    def forward(self, x):
        feats = []
        feat = x
        for i, stage in enumerate(self.conv_list):
            feat = stage(feat)
            # Spatial downsampling happens once, right after the first conv.
            if i == 0 and self.stride == 2:
                feat = self.avd_layer(feat)
            feats.append(feat)
        residual = self.skip(x) if self.stride == 2 else x
        return torch.cat(feats, dim=1) + residual
82
+
83
+
84
class CatBottleneck(nn.Module):
    """STDC CatBottleneck: multi-width conv chain fused by concatenation."""

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        self.conv_list = nn.ModuleList()
        self.stride = stride

        if stride == 2:
            # Depthwise stride-2 conv for the conv chain; average pooling
            # downsamples the first-stage output used at the head of the concat.
            self.avd_layer = nn.Sequential(
                nn.Conv2d(out_planes // 2, out_planes // 2, 3, 2, 1,
                          groups=out_planes // 2, bias=False),
                nn.SyncBatchNorm(out_planes // 2)
            )
            self.skip = nn.AvgPool2d(3, 2, 1)
            stride = 1

        # Channel plan mirrors AddBottleneck: out/2, out/4, ... summing to out_planes.
        for i in range(block_num):
            if i == 0:
                stage = ConvX(in_planes, out_planes // 2, kernel=1)
            elif i == 1 and block_num == 2:
                stage = ConvX(out_planes // 2, out_planes // 2, stride=stride)
            elif i == 1:
                stage = ConvX(out_planes // 2, out_planes // 4, stride=stride)
            elif i < block_num - 1:
                stage = ConvX(out_planes // (2 ** i), out_planes // (2 ** (i + 1)))
            else:
                stage = ConvX(out_planes // (2 ** i), out_planes // (2 ** i))
            self.conv_list.append(stage)

    def forward(self, x):
        first = self.conv_list[0](x)
        feats = []
        feat = first
        for i, stage in enumerate(self.conv_list[1:]):
            if i == 0 and self.stride == 2:
                feat = stage(self.avd_layer(first))
            else:
                feat = stage(feat)
            feats.append(feat)
        # Head of the concat is the (possibly pooled) first-stage output.
        head = self.skip(first) if self.stride == 2 else first
        return torch.cat([head] + feats, dim=1)
138
+
139
+
140
class ShallowNet(nn.Module):
    """
    STDC-like shallow network for high-resolution feature extraction.

    Produces features at 1/8 and 1/16 of the input resolution.

    Args:
        base: Base channel number
        in_channels: Input channels (3 for RGB, 6 for pyramid concat)
        layers: Number of blocks per stage
        block_num: Number of convs per block
        type: 'cat' for CatBottleneck, 'add' for AddBottleneck
        pretrain_model: Path to pretrained STDC weights
    """

    def __init__(self, base=64, in_channels=3, layers=[2, 2], block_num=4,
                 type="cat", pretrain_model=''):
        super().__init__()
        block = CatBottleneck if type == "cat" else AddBottleneck
        self.in_channels = in_channels

        # Stem: two stride-2 convs take the input to 1/4 resolution.
        features = [
            ConvX(in_channels, base // 2, 3, 2),
            ConvX(base // 2, base, 3, 2)
        ]

        # Each stage starts with a stride-2 block (channel expansion),
        # followed by stride-1 blocks at the same width.
        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(
                        block(base * int(math.pow(2, i + 1)),
                              base * int(math.pow(2, i + 2)), block_num, 2)
                    )
                else:
                    features.append(
                        block(base * int(math.pow(2, i + 2)),
                              base * int(math.pow(2, i + 2)), block_num, 1)
                    )

        self.features = nn.Sequential(*features)
        # Stage views over the same underlying modules (slices share parameters
        # with self.features); names reflect the cumulative downsampling factor.
        self.x2 = nn.Sequential(self.features[:1])
        self.x4 = nn.Sequential(self.features[1:2])
        self.x8 = nn.Sequential(self.features[2:4])
        self.x16 = nn.Sequential(self.features[4:6])

        if pretrain_model and os.path.exists(pretrain_model):
            print(f'Loading pretrain model {pretrain_model}')
            # SECURITY NOTE(review): weights_only=False unpickles arbitrary
            # objects — only ever load checkpoints from trusted sources.
            # Assumes the checkpoint is a dict with a 'state_dict' entry.
            sd = torch.load(pretrain_model, weights_only=False)["state_dict"]
            ssd = self.state_dict()
            for k, v in sd.items():
                # Duplicate the stem's RGB weights to cover extra input
                # channels (e.g. the 6-channel pyramid concat).
                if k == 'features.0.conv.weight' and in_channels != 3:
                    v = torch.cat([v, v], dim=1)
                if k in ssd:
                    ssd.update({k: v})
            self.load_state_dict(ssd, strict=False)
        else:
            # No pretrained weights: Kaiming init for convs, unit BN affine.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    init.kaiming_normal_(m.weight, mode='fan_out')
                elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                    init.constant_(m.weight, 1)
                    init.constant_(m.bias, 0)

    def forward(self, x):
        # x2/x4 are intermediate stages; only the 1/8 and 1/16 maps are returned.
        x2 = self.x2(x)
        x4 = self.x4(x2)
        x8 = self.x8(x4)
        x16 = self.x16(x8)
        return x8, x16
209
+
210
+
211
class Lap_Pyramid_Conv(nn.Module):
    """
    Laplacian pyramid decomposition.

    Extracts high-frequency details at multiple scales.

    Args:
        num_high: Number of high-frequency levels to produce.
        gauss_chl: Channel count of the decomposed images (the 5x5 kernel is
            replicated per channel for depthwise filtering).
    """

    def __init__(self, num_high=3, gauss_chl=3):
        super().__init__()
        self.num_high = num_high
        self.gauss_chl = gauss_chl

        # 5x5 binomial (Gaussian-like) kernel, normalised to sum to 1.
        k = torch.tensor([
            [1., 4., 6., 4., 1],
            [4., 16., 24., 16., 4.],
            [6., 24., 36., 24., 6.],
            [4., 16., 24., 16., 4.],
            [1., 4., 6., 4., 1.]
        ]) / 256.
        # One kernel copy per channel -> depthwise conv weight (C, 1, 5, 5).
        self.register_buffer('kernel', k.repeat(gauss_chl, 1, 1, 1))

    def conv_gauss(self, img, k):
        # Reflect-pad by 2 so the 5x5 depthwise blur preserves spatial size.
        return F.conv2d(F.pad(img, (2, 2, 2, 2), mode='reflect'), k, groups=img.shape[1])

    def downsample(self, x):
        # Decimate: keep every other row and column.
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        # Zero-insertion upsampling via reshape/permute: interleave zeros
        # along width, transpose, interleave along the other axis, transpose
        # back. The final blur uses 4*kernel to compensate for the 3/4 of
        # samples that are zero.
        cc = torch.cat([x, torch.zeros_like(x)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
        cc = cc.permute(0, 1, 3, 2)
        cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3],
                                        x.shape[2] * 2, device=x.device)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
        return self.conv_gauss(cc.permute(0, 1, 3, 2), 4 * self.kernel)

    def pyramid_decom(self, img):
        """Decompose image into Laplacian pyramid (high-frequency residuals).

        Returns num_high residual images, finest first. Note the final
        low-frequency base is NOT included in the returned list.
        """
        current = img
        pyr = []
        for _ in range(self.num_high):
            down = self.downsample(self.conv_gauss(current, self.kernel))
            up = self.upsample(down)
            # Odd input sizes leave `up` off by one pixel; resize to match.
            if up.shape[2:] != current.shape[2:]:
                up = F.interpolate(up, current.shape[2:])
            pyr.append(current - up)
            current = down
        return pyr
utils/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ISDNet utilities
3
+ """
4
+
5
+ from .distributed import (
6
+ setup_distributed,
7
+ cleanup_distributed,
8
+ print_rank0,
9
+ batch_mm_loop,
10
+ )
11
+
12
+ __all__ = [
13
+ "setup_distributed",
14
+ "cleanup_distributed",
15
+ "print_rank0",
16
+ "batch_mm_loop",
17
+ ]
utils/distributed.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Distributed training utilities
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+
10
def setup_distributed():
    """Initialise torch.distributed from torchrun-style environment variables.

    Returns:
        (rank, world_size, local_rank); single-process defaults (0, 1, 0)
        when no RANK variable is set. The NCCL process group is created only
        when world_size > 1.
    """
    env = os.environ
    if 'RANK' not in env:
        return 0, 1, 0

    rank = int(env['RANK'])
    world_size = int(env['WORLD_SIZE'])
    local_rank = int(env['LOCAL_RANK'])

    if world_size > 1:
        dist.init_process_group('nccl')
        torch.cuda.set_device(local_rank)

    return rank, world_size, local_rank
26
+
27
+
28
def cleanup_distributed():
    """Tear down the process group, if one was ever created."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
32
+
33
+
34
def print_rank0(msg, rank=0):
    """Emit msg on the main process only; other ranks stay silent."""
    if rank != 0:
        return
    print(msg)
38
+
39
+
40
def batch_mm_loop(a, b):
    """
    Batch matrix multiply implemented as per-sample torch.mm calls.
    Avoids CUBLAS strided batched routines which have issues on L40S/CUDA 12.8/PyTorch 2.10.

    Args:
        a: Tensor of shape (batch, m, k)
        b: Tensor of shape (batch, k, n)

    Returns:
        Tensor of shape (batch, m, n)
    """
    # Iterating a 3-D tensor yields its (m, k) / (k, n) slices in batch order.
    products = [torch.mm(lhs, rhs) for lhs, rhs in zip(a, b)]
    return torch.stack(products, dim=0)
weights/isdnet_flair_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:233b4a931fe370f395d0ce60d636036eefc35e596b09b1acfa54950d7f1d89e1
3
+ size 142441755