Image Segmentation
Committed by antoine.carreaud67 · commit d43c376 · 1 parent: 9367521

Update with new experiments

model/CASWiT_fusion_last_stage_add.py ADDED
@@ -0,0 +1,250 @@
"""
CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation

This module implements the main CASWiT model architecture with dual-branch
high-resolution and low-resolution processing with cross-attention fusion.
"""

import math
from typing import Dict
import torch
import torch.nn as nn
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or (not self.training):
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that enables HR features to attend to LR features.

    Implements pre-norm cross-attention (Q=HR, K/V=LR).

    Args:
        C_hr: Channel dimension of HR features
        C_lr: Channel dimension of LR features
        num_heads: Number of attention heads
        mlp_ratio: MLP expansion ratio
        drop: Dropout rate
        drop_path: Drop path rate
    """
    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through cross-attention fusion block.

        Args:
            x_hr: HR features [B, C_hr, H_hr, W_hr]
            x_lr: LR features [B, C_lr, H_lr, W_lr]

        Returns:
            Fused HR features [B, C_hr, H_hr, W_hr]
        """
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape

        # Flatten to sequences
        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        # Pre-norm
        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        # Residual connection + MLP
        y = q + attn_out
        y = y + self.mlp(y)

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)


class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Dual-branch architecture with:
    - HR branch: Processes high-resolution crops
    - LR branch: Processes low-resolution context
    - Cross-attention fusion at each encoder stage

    Args:
        num_head_xa: Number of cross-attention heads
        num_classes: Number of segmentation classes
        model_name: HuggingFace model identifier for UPerNet-Swin
        mlp_ratio: MLP expansion ratio in fusion blocks
        drop_path: Drop path rate
    """
    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()
        # Load two UPerNet backbones (HR and LR branches)
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # Extract HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
        self.decoder = model_hr.decode_head

        # Extract LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
        self.decoder_lr = model_lr.decode_head

        # Cross-attention blocks at each stage
        # Dimensions: tiny:[96, 192, 384, 768] base:[128, 256, 512, 1024] large:[192, 384, 768, 1536]
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536]
        }
        # Infer dimensions from model name
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]  # default to base

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through CASWiT model.

        Args:
            x_hr: HR input images [B, 3, H_hr, W_hr]
            x_lr: LR input images [B, 3, H_lr, W_lr]

        Returns:
            Dictionary with 'logits_hr' and 'logits_lr' segmentation logits
        """
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr, C_hr = x_hr_seq.shape[1], x_hr_seq.shape[2]
        N_lr, C_lr = x_lr_seq.shape[1], x_lr_seq.shape[2]
        H_hr = W_hr = int(math.sqrt(N_hr))
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        # Process through encoder stages with cross-attention fusion
        for idx, (stage_hr, stage_lr) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr
        )):
            # HR branch blocks
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            # LR branch blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Layer normalization
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            # Reshape to spatial format
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            # Cross-attend HR to LR
            if idx == 3:
                fused_hr = feat_hr + feat_lr
            else:
                fused_hr = feat_hr
                # fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample if stage has it
            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features (for auxiliary supervision)
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}
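Usage note for this variant: fusion happens only at the last encoder stage, by element-wise addition of the stage-4 feature maps, so both branches must produce stage-4 maps of identical shape. A minimal smoke-test sketch (illustrative, not one of the committed files; input sizes are assumptions):

import torch
from model.CASWiT_fusion_last_stage_add import CASWiT

model = CASWiT(num_classes=12, model_name="openmmlab/upernet-swin-tiny").eval()
x_hr = torch.randn(1, 3, 512, 512)  # high-resolution crop
x_lr = torch.randn(1, 3, 512, 512)  # low-resolution context, same size so the stage-4 maps align for the addition
with torch.no_grad():
    out = model(x_hr, x_lr)
# Per-class logits at reduced resolution (the UPerNet head does not upsample to the input size here).
print(out["logits_hr"].shape, out["logits_lr"].shape)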
model/CASWiT_m2f.py ADDED
@@ -0,0 +1,354 @@
"""
CASWiT with Mask2Former heads (HuggingFace).

This file is identical to the original CASWiT implementation except that:
- self.decoder and self.decoder_lr are replaced by a Mask2Former semantic head
  implemented using HuggingFace's Mask2Former pixel decoder + transformer module.

The rest of the model (embeddings, Swin encoder stages, cross-attention fusion) is unchanged.
"""

import math
from typing import Dict, Tuple, List
import torch
import torch.nn as nn

from transformers import UperNetForSemanticSegmentation, Mask2FormerConfig
from transformers.models.mask2former.modeling_mask2former import (
    Mask2FormerPixelDecoder,
    Mask2FormerTransformerModule,
)
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or (not self.training):
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that enables HR features to attend to LR features.
    Implements pre-norm cross-attention (Q=HR, K/V=LR).
    """
    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        B, C_hr, H_hr, W_hr = x_hr.shape

        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        y = q + attn_out
        y = y + self.mlp(y)

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)


class Mask2FormerSemanticHead(nn.Module):
    """
    A minimal Mask2Former "semantic segmentation head" that consumes multi-scale backbone features
    and outputs per-class per-pixel scores.

    Input:
        features: tuple/list of 4 feature maps (stage1..stage4), each [B, C_i, H_i, W_i].
            The spatial strides should typically be [4, 8, 16, 32] relative to the input image.

    Output:
        semantic_scores: [B, num_classes, H_out, W_out], where H_out/W_out match the mask_features
            resolution produced by the Mask2Former pixel decoder (typically stride 4).

    Notes:
        Mask2Former natively predicts:
        - class_queries_logits: [B, Q, num_classes+1] (includes "no object")
        - masks_queries_logits: [B, Q, H_out, W_out]
        For semantic segmentation, a common aggregation is:
            semantic_probs = sum_q softmax(class_logits_q)[c] * sigmoid(mask_logits_q)[h, w]
        Here we return these aggregated per-class *scores* (in [0, 1]) as "logits" for compatibility
        with the original CASWiT API. If you need true logits, apply logit() carefully (numerical stability).
    """
    def __init__(
        self,
        feature_channels: List[int],
        num_classes: int,
        num_queries: int = 100,
        feature_size: int = 256,
        mask_feature_size: int = 256,
        common_stride: int = 4,
    ):
        super().__init__()

        cfg = Mask2FormerConfig(
            num_labels=num_classes,
            num_queries=num_queries,
            feature_size=feature_size,
            mask_feature_size=mask_feature_size,
            common_stride=common_stride,
            feature_strides=[4, 8, 16, 32],
            encoder_layers=1,
            decoder_layers=1,
            num_attention_heads=8,
            dim_feedforward=1024,
            output_auxiliary_logits=False,
            # keep defaults for transformer, heads, etc.
        )

        self.config = cfg
        self.num_classes = num_classes
        self.num_queries = num_queries

        # Pixel decoder consumes backbone channels and produces:
        #   - multi_scale_features (3 levels: 1/8, 1/16, 1/32)
        #   - mask_features (typically 1/4)
        self.pixel_decoder = Mask2FormerPixelDecoder(cfg, feature_channels=feature_channels)

        # Transformer module consumes:
        #   - multi_scale_features (list of 3 tensors)
        #   - mask_features (tensor at stride 4)
        # and returns masks_queries_logits for each decoder layer + intermediate states
        self.transformer_module = Mask2FormerTransformerModule(in_features=cfg.feature_size, config=cfg)

        # Class predictor (same idea as HF Mask2FormerForUniversalSegmentation)
        self.class_predictor = nn.Linear(cfg.hidden_dim, num_classes + 1)

    def forward(self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        if not isinstance(features, (tuple, list)) or len(features) != 4:
            raise ValueError("Mask2FormerSemanticHead expects a tuple/list of 4 feature maps: (stage1, stage2, stage3, stage4).")

        # Expected order: [stage1, stage2, stage3, stage4] (increasing stride).
        # Pixel decoder internally reverses and uses the last 3 feature maps for deformable attention.
        pixel_out = self.pixel_decoder(list(features), return_dict=True)
        multi_scale = list(pixel_out.multi_scale_features)  # 3 levels
        mask_features = pixel_out.mask_features             # stride 4

        dec_out = self.transformer_module(
            multi_scale_features=multi_scale,
            mask_features=mask_features,
            output_hidden_states=True,
            output_attentions=False,
        )

        # Use last decoder layer predictions
        masks_queries_logits = dec_out.masks_queries_logits[-1]  # [B, Q, H, W]

        # Last layer hidden state can be in shape [B, Q, D] OR [Q, B, D] depending on HF internals.
        # For intermediate_hidden_states, HF uses [Q, B, D] (then transposes in their heads).
        # We'll robustly support both:
        hidden = dec_out.last_hidden_state
        if hidden.dim() != 3:
            raise RuntimeError(f"Unexpected last_hidden_state shape: {tuple(hidden.shape)}")

        if hidden.shape[0] == self.num_queries and hidden.shape[1] == masks_queries_logits.shape[0]:
            # [Q, B, D] -> [B, Q, D]
            hidden_bqd = hidden.transpose(0, 1)
        else:
            # assume [B, Q, D]
            hidden_bqd = hidden

        class_queries_logits = self.class_predictor(hidden_bqd)  # [B, Q, C+1]

        # Aggregate to semantic per-class scores at mask resolution:
        # softmax over classes (including no-object), then drop the no-object channel
        class_probs = class_queries_logits.softmax(dim=-1)[..., :-1]  # [B, Q, C]
        mask_probs = masks_queries_logits.sigmoid()                   # [B, Q, H, W]

        # semantic_scores[b,c,h,w] = sum_q class_probs[b,q,c] * mask_probs[b,q,h,w]
        semantic_scores = torch.einsum("bqc,bqhw->bchw", class_probs, mask_probs)

        return semantic_scores


class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Only change vs original: replace self.decoder and self.decoder_lr with Mask2FormerSemanticHead.
    """
    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()

        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # Extract HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms

        # Extract LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms

        # Infer Swin stage dims from model name (same as original)
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536]
        }
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]

        # >>> ONLY MODIFIED PART: decoder / decoder_lr <<<
        self.decoder = Mask2FormerSemanticHead(feature_channels=dims, num_classes=num_classes)
        self.decoder_lr = Mask2FormerSemanticHead(feature_channels=dims, num_classes=num_classes)

        # Cross-attention blocks at each stage
        self.cross_attn_blocks = nn.ModuleList([
            CrossFusionBlock(dim, dim, num_heads=num_head_xa,
                             mlp_ratio=mlp_ratio, drop=0.0, drop_path=drop_path)
            for dim in dims
        ])

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr = x_hr_seq.shape[1]
        N_lr = x_lr_seq.shape[1]
        H_hr = W_hr = int(math.sqrt(N_hr))
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}


def _test_mask2former_head():
    """
    Minimal sanity test: validates that the Mask2FormerSemanticHead consumes
    a (stage1..stage4) feature tuple and returns [B, C, H1, W1] scores.
    """
    torch.manual_seed(0)
    B = 1
    num_classes = 12
    dims = [96, 192, 384, 768]
    H1, W1 = 8, 8

    feats = (
        torch.randn(B, dims[0], H1, W1),
        torch.randn(B, dims[1], H1 // 2, W1 // 2),
        torch.randn(B, dims[2], H1 // 4, W1 // 4),
        torch.randn(B, dims[3], H1 // 8, W1 // 8),
    )

    head = Mask2FormerSemanticHead(feature_channels=dims, num_classes=num_classes, num_queries=50)
    with torch.no_grad():
        out = head(feats)

    assert out.shape == (B, num_classes, H1, W1), f"Unexpected output shape: {out.shape}"
    assert torch.isfinite(out).all(), "NaN/Inf in output"
    return out.shape


if __name__ == "__main__":
    # Run head test
    print("Mask2Former head test output shape:", _test_mask2former_head())
model/CASWiT_segformer.py ADDED
@@ -0,0 +1,290 @@
"""
CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation

This module implements the main CASWiT model architecture with dual-branch
high-resolution and low-resolution processing with cross-attention fusion.
"""

import math
from typing import Dict
import torch
import torch.nn as nn
from transformers import UperNetForSemanticSegmentation, SegformerConfig
from transformers.models.segformer.modeling_segformer import SegformerDecodeHead
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or (not self.training):
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that enables HR features to attend to LR features.

    Implements pre-norm cross-attention (Q=HR, K/V=LR).

    Args:
        C_hr: Channel dimension of HR features
        C_lr: Channel dimension of LR features
        num_heads: Number of attention heads
        mlp_ratio: MLP expansion ratio
        drop: Dropout rate
        drop_path: Drop path rate
    """
    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through cross-attention fusion block.

        Args:
            x_hr: HR features [B, C_hr, H_hr, W_hr]
            x_lr: LR features [B, C_lr, H_lr, W_lr]

        Returns:
            Fused HR features [B, C_hr, H_hr, W_hr]
        """
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape

        # Flatten to sequences
        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        # Pre-norm
        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        # Residual connection + MLP
        y = q + attn_out
        y = y + self.mlp(y)

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)


class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Dual-branch architecture with:
    - HR branch: Processes high-resolution crops
    - LR branch: Processes low-resolution context
    - Cross-attention fusion at each encoder stage

    Args:
        num_head_xa: Number of cross-attention heads
        num_classes: Number of segmentation classes
        model_name: HuggingFace model identifier for UPerNet-Swin
        mlp_ratio: MLP expansion ratio in fusion blocks
        drop_path: Drop path rate
    """
    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()
        # Load two UPerNet backbones (HR and LR branches)
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # Extract HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
        self.decoder = None  # placeholder, set after dims inference

        # Extract LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
        self.decoder_lr = None  # placeholder, set after dims inference

        # Cross-attention blocks at each stage
        # Dimensions: tiny:[96, 192, 384, 768] base:[128, 256, 512, 1024] large:[192, 384, 768, 1536]
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536]
        }
        # Infer dimensions from model name
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]  # default to base

        segformer_cfg = SegformerConfig(
            num_labels=num_classes,
            hidden_sizes=dims,
            num_encoder_blocks=4,
            decoder_hidden_size=512,
            classifier_dropout_prob=0.0,
        )
        self.decoder = SegformerDecodeHead(segformer_cfg)
        self.decoder_lr = SegformerDecodeHead(segformer_cfg)

        self.cross_attn_blocks = nn.ModuleList([
            CrossFusionBlock(dim, dim, num_heads=num_head_xa,
                             mlp_ratio=mlp_ratio, drop=0.0, drop_path=drop_path)
            for dim in dims
        ])

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through CASWiT model.

        Args:
            x_hr: HR input images [B, 3, H_hr, W_hr]
            x_lr: LR input images [B, 3, H_lr, W_lr]

        Returns:
            Dictionary with 'logits_hr' and 'logits_lr' segmentation logits
        """
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr, C_hr = x_hr_seq.shape[1], x_hr_seq.shape[2]
        N_lr, C_lr = x_lr_seq.shape[1], x_lr_seq.shape[2]
        H_hr = W_hr = int(math.sqrt(N_hr))
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        # Process through encoder stages with cross-attention fusion
        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            # HR branch blocks
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            # LR branch blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Layer normalization
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            # Reshape to spatial format
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample if stage has it
            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features (for auxiliary supervision)
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}


def _test_segformer_head():
    """Quick, offline test for SegFormer head input/output shapes."""
    # Example dims for Swin-Tiny stages:
    dims = [96, 192, 384, 768]
    cfg = SegformerConfig(
        num_labels=7,
        hidden_sizes=dims,
        num_encoder_blocks=4,
        decoder_hidden_size=512,
        classifier_dropout_prob=0.0,
    )
    head = SegformerDecodeHead(cfg)

    B = 2
    # Stage resolutions typically differ by /2 each time; here we mimic that.
    f1 = torch.randn(B, dims[0], 128, 128)
    f2 = torch.randn(B, dims[1], 64, 64)
    f3 = torch.randn(B, dims[2], 32, 32)
    f4 = torch.randn(B, dims[3], 16, 16)

    logits = head((f1, f2, f3, f4))
    assert logits.shape == (B, cfg.num_labels, 128, 128), f"Unexpected logits shape: {logits.shape}"
    return logits.shape


if __name__ == "__main__":
    print("SegFormer head test logits shape:", _test_segformer_head())
model/CASWiT_ssl.py ADDED
@@ -0,0 +1,292 @@
"""
CASWiT Self-Supervised Learning (SSL) Module

Implements SimMIM-based self-supervised pre-training for CASWiT using
masked image modeling with dual-branch HR/LR processing.
"""

import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


def random_masking_with_tokens(x: torch.Tensor, mask_ratio: float = 0.75,
                               mask_token: Optional[torch.Tensor] = None):
    """
    Random masking at token level with learned mask token.

    Args:
        x: Input tokens [B, N, C]
        mask_ratio: Ratio of tokens to mask
        mask_token: Learnable mask token

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N] where 0=visible, 1=masked
        ids_restore: Indices to restore original order
    """
    B, N, C = x.shape
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]

    x_keep = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))

    if mask_token is None:
        mask_token = torch.zeros((1, C), device=x.device)
    m_tok = mask_token.view(1, 1, C).expand(B, N - len_keep, C)

    x_cat = torch.cat([x_keep, m_tok], dim=1)
    x_masked = torch.gather(x_cat, 1, ids_restore.unsqueeze(-1).expand(-1, -1, C))

    mask = torch.ones(B, N, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore


def center_masking_with_tokens(x: torch.Tensor, mask_token: Optional[torch.Tensor] = None,
                               mask_ratio: float = 0.5):
    """
    Deterministic centered square mask.

    Args:
        x: Input tokens [B, N, C]
        mask_token: Learnable mask token
        mask_ratio: Ratio of tokens to mask

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N]
        ids_restore: Indices to restore original order
    """
    B, N, C = x.shape
    H = W = int(N**0.5)
    assert H * W == N, "N must be a perfect square"
    L = int(round(H * (mask_ratio ** 0.5)))
    start = (H - L) // 2
    end = start + L

    mask_2d = torch.zeros(H, W, device=x.device, dtype=torch.bool)
    mask_2d[start:end, start:end] = True
    mask = mask_2d.view(1, -1).expand(B, -1)  # (B, N)

    if mask_token is None:
        mask_token = torch.zeros(C, device=x.device)
    mask_token = mask_token.view(-1)

    x_masked = x * (~mask).unsqueeze(-1) + mask.unsqueeze(-1) * mask_token.view(1, 1, C)
    ids_restore = torch.arange(N, device=x.device).unsqueeze(0).expand(B, N)
    return x_masked, mask.to(x_masked.dtype), ids_restore


class CrossAttentionBlock(nn.Module):
    """Simplified cross-attention block for SSL."""
    def __init__(self, C_hr, C_lr, num_heads=8, dropout=0.0):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=dropout, batch_first=True
        )
        self.norm = nn.LayerNorm(C_hr)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, C_hr * 4),
            nn.GELU(),
            nn.Linear(C_hr * 4, C_hr),
        )

    def forward(self, x_hr, x_lr):
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape
        q = x_hr.flatten(2).transpose(1, 2)   # (B, N_hr, C_hr)
        kv = x_lr.flatten(2).transpose(1, 2)  # (B, N_lr, C_lr)
        attn_out, _ = self.cross_attn(q, kv, kv)
        y = self.norm(q + attn_out)
        y = y + self.mlp(y)
        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)


class CASWiT_SSL(nn.Module):
    """
    CASWiT Self-Supervised Learning model using SimMIM.

    Encoder: Dual Swin backbones with cross-attention blocks
    Decoder: Conv1x1 + PixelShuffle for reconstruction
    Masking: HR random masking, LR center masking

    Args:
        model_name: HuggingFace model identifier
        mask_ratio_hr: Masking ratio for HR branch
        mask_ratio_lr: Masking ratio for LR branch
        patch_size: Patch size for masking
        encoder_stride: Encoder stride for decoder
        xa_heads: Number of cross-attention heads per stage
    """
    def __init__(self, model_name: str = "openmmlab/upernet-swin-base",
                 mask_ratio_hr: float = 0.75, mask_ratio_lr: float = 0.5,
                 patch_size: int = 4, encoder_stride: int = 32,
                 xa_heads: Tuple[int, int, int, int] = (8, 8, 8, 8)):
        super().__init__()
        self.mask_ratio_hr = mask_ratio_hr
        self.mask_ratio_lr = mask_ratio_lr
        self.patch_size = patch_size
        self.encoder_stride = encoder_stride

        # Load two UPerNet (Swin) backbones
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, ignore_mismatched_sizes=True
        )

        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms

        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms

        # Cross-attention blocks with explicit Swin-Base dims
        dims = [128, 256, 512, 1024]
        self.cross_attn_blocks = nn.ModuleList([
            CrossAttentionBlock(d, d, num_heads=h) for d, h in zip(dims, xa_heads)
        ])

        # Learnable mask tokens
        self.mask_token_hr = nn.Parameter(torch.zeros(1, dims[0]))
        self.mask_token_lr = nn.Parameter(torch.zeros(1, dims[0]))

        # SimMIM decoder: Conv1×1 → PixelShuffle(stride)
        self.decoder_conv = None  # lazy init after we know C_last
        self.decoder_shuffle = nn.PixelShuffle(self.encoder_stride)

        # Store masks for visualization
        self.last_mask_hr = None
        self.last_mask_lr = None

    def _encode(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """Encode with masking and return reconstruction targets."""
        B, C, H, W = x_hr.shape
        target_img = x_hr
        target_lr = x_lr

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)  # (B, N_hr, C1)
        x_lr_seq, _ = self.embeddings_lr(x_lr)  # (B, N_lr, C1)

        # Masking
        x_hr_seq, mask_hr, _ = random_masking_with_tokens(
            x_hr_seq, self.mask_ratio_hr, self.mask_token_hr
        )
        x_lr_seq, mask_lr, _ = center_masking_with_tokens(
            x_lr_seq, self.mask_token_lr, mask_ratio=self.mask_ratio_lr
        )

        # Initial spatial dims
        H_hr = W_hr = int(math.sqrt(x_hr_seq.shape[1]))
        H_lr = W_lr = int(math.sqrt(x_lr_seq.shape[1]))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        # Walk encoder stages with cross attention at each stage
        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            # HR blocks
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]
            # LR blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Norms
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            # Maps
            B_, N_hr_, C_hr_ = x_hr_seq.shape
            B_, N_lr_, C_lr_ = x_lr_seq.shape
            Hh, Wh = dims_hr
            Hl, Wl = dims_lr
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B_, C_hr_, Hh, Wh)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B_, C_lr_, Hl, Wl)

            # Cross-fuse HR <- LR
            fused_hr = ca(feat_hr, feat_lr)
            x_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample to next stage
            if stage_hr.downsample is not None:
                x_hr_seq = stage_hr.downsample(x_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

        # Last-stage feature map z (B, C_last, H/stride, W/stride)
        Hs, Ws = dims_hr
        C_last = x_hr_seq.shape[-1]
        z = x_hr_seq.transpose(1, 2).contiguous().view(B, C_last, Hs, Ws)

        # Lazy init decoder conv
        if self.decoder_conv is None:
            self.decoder_conv = nn.Conv2d(
                C_last, (self.encoder_stride ** 2) * 3, kernel_size=1
            ).to(z.device)

        # Reconstruction
        x_rec = self.decoder_shuffle(self.decoder_conv(z))  # (B, 3, H, W)

        # Convert patch masks to pixel masks
        Mh = int(math.sqrt(mask_hr.shape[1]))
        mask_patch_hr = mask_hr.view(B, Mh, Mh)
        mask_pix_hr = mask_patch_hr.repeat_interleave(
            self.patch_size, 1
        ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()

        Ml = int(math.sqrt(mask_lr.shape[1]))
        mask_patch_lr = mask_lr.view(B, Ml, Ml)
        mask_pix_lr = mask_patch_lr.repeat_interleave(
            self.patch_size, 1
        ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()

        self.last_mask_hr = mask_patch_hr
        self.last_mask_lr = mask_patch_lr

        return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for SSL training.

        Returns reconstruction loss on masked pixels only.
        """
        x_rec, target_img, mask_pix, _, _ = self._encode(x_hr, x_lr)
        loss_recon = F.l1_loss(target_img, x_rec, reduction='none')
        loss = (loss_recon * mask_pix).sum() / (mask_pix.sum() + 1e-6) / target_img.shape[1]
        return loss

    @torch.no_grad()
    def forward_outputs(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """Forward pass returning all outputs for visualization."""
        x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr = self._encode(x_hr, x_lr)
        return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr
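A hedged sketch of how this SSL objective might be driven in a pre-training loop (optimizer choice, learning rate, and input sizes are illustrative assumptions; one warm-up forward is done first so the lazily created decoder_conv is registered before the optimizer snapshots the parameters):

import torch
from model.CASWiT_ssl import CASWiT_SSL

model = CASWiT_SSL(model_name="openmmlab/upernet-swin-base", mask_ratio_hr=0.75, mask_ratio_lr=0.5)
x_hr = torch.randn(2, 3, 224, 224)
x_lr = torch.randn(2, 3, 224, 224)

# Warm-up forward: _encode lazily creates decoder_conv on first use, so build the optimizer afterwards.
with torch.no_grad():
    model(x_hr, x_lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model.train()
loss = model(x_hr, x_lr)  # masked-pixel L1 reconstruction loss
loss.backward()
optimizer.step()
optimizer.zero_grad()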
@@ -0,0 +1,250 @@
 
"""
CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation

This module implements the main CASWiT model architecture with dual-branch
high-resolution and low-resolution processing with cross-attention fusion.
"""

import math
from typing import Dict
import torch
import torch.nn as nn
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or (not self.training):
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that enables HR features to attend to LR features.

    Implements pre-norm cross-attention (Q=HR, K/V=LR).

    Args:
        C_hr: Channel dimension of HR features
        C_lr: Channel dimension of LR features
        num_heads: Number of attention heads
        mlp_ratio: MLP expansion ratio
        drop: Dropout rate
        drop_path: Drop path rate
    """
    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through cross-attention fusion block.

        Args:
            x_hr: HR features [B, C_hr, H_hr, W_hr]
            x_lr: LR features [B, C_lr, H_lr, W_lr]

        Returns:
            Fused HR features [B, C_hr, H_hr, W_hr]
        """
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape

        # Flatten to sequences
        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        # Pre-norm
        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        # Residual connection + MLP
        y = q + attn_out
        y = y + self.mlp(y)

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)


class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Dual-branch architecture with:
    - HR branch: Processes high-resolution crops
    - LR branch: Processes low-resolution context
    - Cross-attention fusion at each encoder stage

    Args:
        num_head_xa: Number of cross-attention heads
        num_classes: Number of segmentation classes
        model_name: HuggingFace model identifier for UPerNet-Swin
        mlp_ratio: MLP expansion ratio in fusion blocks
        drop_path: Drop path rate
    """
    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()
        # Load two UPerNet backbones (HR and LR branches)
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # Extract HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
        self.decoder = model_hr.decode_head

        # Extract LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
        self.decoder_lr = model_lr.decode_head

        # Cross-attention blocks at each stage
        # Dimensions: tiny:[96, 192, 384, 768] base:[128, 256, 512, 1024] large:[192, 384, 768, 1536]
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536]
        }
        # Infer dimensions from model name
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]  # default to base

        self.cross_attn_blocks = nn.ModuleList([
            CrossFusionBlock(dim, dim, num_heads=num_head_xa,
                             mlp_ratio=mlp_ratio, drop=0.0, drop_path=drop_path)
            for dim in dims
        ])

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through CASWiT model.

        Args:
            x_hr: HR input images [B, 3, H_hr, W_hr]
            x_lr: LR input images [B, 3, H_lr, W_lr]

        Returns:
            Dictionary with 'logits_hr' and 'logits_lr' segmentation logits
        """
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr, C_hr = x_hr_seq.shape[1], x_hr_seq.shape[2]
        N_lr, C_lr = x_lr_seq.shape[1], x_lr_seq.shape[2]
        H_hr = W_hr = int(math.sqrt(N_hr))
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        # Process through encoder stages with cross-attention fusion
        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            # HR branch blocks
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            # LR branch blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Layer normalization
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            # Reshape to spatial format
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample if stage has it
            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features (for auxiliary supervision)
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}
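A hedged end-to-end sketch of the intended call pattern for this baseline variant (input sizes, upsampling of the logits to the label resolution, and the 0.4 auxiliary-loss weight are illustrative assumptions, not defined by this commit):

import torch
import torch.nn.functional as F
from model.CASWiT_upernet import CASWiT

model = CASWiT(num_head_xa=1, num_classes=12, model_name="openmmlab/upernet-swin-tiny")
x_hr = torch.randn(2, 3, 512, 512)   # high-resolution crop
x_lr = torch.randn(2, 3, 512, 512)   # low-resolution context window
labels = torch.randint(0, 12, (2, 512, 512))

out = model(x_hr, x_lr)
# Heads emit logits at reduced resolution; interpolate to the label size before the loss.
logits_hr = F.interpolate(out["logits_hr"], size=labels.shape[-2:], mode="bilinear", align_corners=False)
logits_lr = F.interpolate(out["logits_lr"], size=labels.shape[-2:], mode="bilinear", align_corners=False)
loss = F.cross_entropy(logits_hr, labels) + 0.4 * F.cross_entropy(logits_lr, labels)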
model/build_model.py ADDED
@@ -0,0 +1,38 @@
from __future__ import annotations

from typing import Any

from model.CASWiT_upernet import CASWiT as CASWiT_UperNet
from model.CASWiT_segformer import CASWiT as CASWiT_SegFormer
from model.CASWiT_m2f import CASWiT as CASWiT_Mask2Former
from model.CASWiT_fusion_last_stage_add import CASWiT as CASWiT_FusionLastStageAdd
from model.CASWiT_ssl import CASWiT_SSL


def _get(cfg: Any, name: str, default: Any = None) -> Any:
    return getattr(cfg, name, default)


def build_model(cfg: Any):
    head = _get(cfg, "model_head", None) or _get(cfg, "head", None) or "upernet"
    head = str(head).lower()

    common = dict(
        num_head_xa=int(_get(cfg, "cross_attention_heads")),
        num_classes=int(_get(cfg, "num_classes")),
        model_name=str(_get(cfg, "model_name")),
        mlp_ratio=float(_get(cfg, "fusion_mlp_ratio")),
        drop_path=float(_get(cfg, "fusion_drop_path")),
    )

    if head in ("upernet", "caswit", "default"):
        return CASWiT_UperNet(**common)
    if head in ("segformer",):
        return CASWiT_SegFormer(**common)
    if head in ("mask2former", "m2f"):
        return CASWiT_Mask2Former(**common)
    if head in ("fusion_last_stage_add", "last_stage_add"):
        return CASWiT_FusionLastStageAdd(**common)
    if head in ("ssl", "caswit_ssl"):
        return CASWiT_SSL(model_name=str(_get(cfg, "model_name")))
    raise ValueError(f"Unknown model head: {head}. Available: upernet, segformer, mask2former, fusion_last_stage_add, ssl")