antoine.carreaud67 committed
Commit d43c376 · 1 Parent(s): 9367521
Update with new experiments
Browse files
- model/CASWiT_fusion_last_stage_add.py  +250 -0
- model/CASWiT_m2f.py                    +354 -0
- model/CASWiT_segformer.py              +290 -0
- model/CASWiT_ssl.py                    +292 -0
- model/CASWiT_upernet.py                +250 -0
- model/build_model.py                   +38 -0
model/CASWiT_fusion_last_stage_add.py
ADDED
@@ -0,0 +1,250 @@
"""
CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation

This module implements the main CASWiT model architecture with dual-branch
high-resolution and low-resolution processing. In this variant ("last stage
add"), cross-attention fusion is disabled and the two branches are fused only
at the last encoder stage, by element-wise addition.
"""

import math
from typing import Dict

import torch
import torch.nn as nn
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or (not self.training):
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that enables HR features to attend to LR features.

    Implements pre-norm cross-attention (Q=HR, K/V=LR).

    Args:
        C_hr: Channel dimension of HR features
        C_lr: Channel dimension of LR features
        num_heads: Number of attention heads
        mlp_ratio: MLP expansion ratio
        drop: Dropout rate
        drop_path: Drop path rate
    """

    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the cross-attention fusion block.

        Args:
            x_hr: HR features [B, C_hr, H_hr, W_hr]
            x_lr: LR features [B, C_lr, H_lr, W_lr]

        Returns:
            Fused HR features [B, C_hr, H_hr, W_hr]
        """
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape

        # Flatten to sequences
        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        # Pre-norm
        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        # Residual connection + MLP
        y = q + attn_out
        y = y + self.mlp(y)

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)

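# --- Illustrative note (not part of the original file) -----------------------
# In this "last stage add" variant, CrossFusionBlock is defined for reference
# but never instantiated: CASWiT.forward below fuses the branches by simple
# addition at stage 4 (see the commented-out `ca(...)` call there). A minimal,
# hypothetical shape check of the block on its own:
#
#   blk = CrossFusionBlock(C_hr=96, C_lr=96, num_heads=8)
#   hr = torch.randn(2, 96, 64, 64)   # HR stage-1 features
#   lr = torch.randn(2, 96, 16, 16)   # LR stage-1 features (coarser grid)
#   assert blk(hr, lr).shape == hr.shape
# ------------------------------------------------------------------------------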
class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Dual-branch architecture with:
    - HR branch: Processes high-resolution crops
    - LR branch: Processes low-resolution context
    - Branch fusion (this variant: element-wise addition at the last stage only;
      cross-attention fusion is disabled)

    Args:
        num_head_xa: Number of cross-attention heads (unused in this variant)
        num_classes: Number of segmentation classes
        model_name: HuggingFace model identifier for UPerNet-Swin
        mlp_ratio: MLP expansion ratio in fusion blocks (unused in this variant)
        drop_path: Drop path rate (unused in this variant)
    """

    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()
        # Load two UPerNet backbones (HR and LR branches)
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # Extract HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
        self.decoder = model_hr.decode_head

        # Extract LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
        self.decoder_lr = model_lr.decode_head

        # Stage dimensions per backbone size:
        # tiny: [96, 192, 384, 768]  base: [128, 256, 512, 1024]  large: [192, 384, 768, 1536]
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536],
        }
        # Infer dimensions from the model name (kept for reference; this variant
        # builds no cross-attention blocks, so `dims` is not used further)
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]  # default to base

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the CASWiT model.

        Args:
            x_hr: HR input images [B, 3, H_hr, W_hr]
            x_lr: LR input images [B, 3, H_lr, W_lr]

        Returns:
            Dictionary with 'logits_hr' and 'logits_lr' segmentation logits
        """
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr, C_hr = x_hr_seq.shape[1], x_hr_seq.shape[2]
        N_lr, C_lr = x_lr_seq.shape[1], x_lr_seq.shape[2]
        H_hr = W_hr = int(math.sqrt(N_hr))
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        # Process through encoder stages
        for idx, (stage_hr, stage_lr) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr
        )):
            # HR branch blocks
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            # LR branch blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Layer normalization
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            # Reshape to spatial format
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            # Fuse branches: element-wise addition at the last stage only
            # (requires matching HR/LR spatial dims at stage 4)
            if idx == 3:
                fused_hr = feat_hr + feat_lr
            else:
                fused_hr = feat_hr
            # Cross-attention fusion is disabled in this variant:
            # fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample if the stage has it
            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features (for auxiliary supervision)
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}
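A minimal smoke test for this variant (a sketch, not part of the commit; the import path and the 512×512 sizes are assumptions, and the stage-4 addition requires the HR crop and LR context to be given at the same pixel resolution):

import torch
from model.CASWiT_fusion_last_stage_add import CASWiT  # assumed repo-relative import

model = CASWiT(num_classes=12, model_name="openmmlab/upernet-swin-tiny").eval()
x_hr = torch.randn(1, 3, 512, 512)  # high-resolution crop
x_lr = torch.randn(1, 3, 512, 512)  # low-resolution context, resized to the crop size
with torch.no_grad():
    out = model(x_hr, x_lr)
print(out["logits_hr"].shape, out["logits_lr"].shape)  # expected at 1/4 of the input resolution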
model/CASWiT_m2f.py
ADDED
@@ -0,0 +1,354 @@
"""
CASWiT with Mask2Former heads (HuggingFace).

This file is identical to the original CASWiT implementation except that:
- self.decoder and self.decoder_lr are replaced by a Mask2Former semantic head
  implemented using HuggingFace's Mask2Former pixel decoder + transformer module.

The rest of the model (embeddings, Swin encoder stages, cross-attention fusion) is unchanged.
"""

import math
from typing import Dict, Tuple, List
import torch
import torch.nn as nn

from transformers import UperNetForSemanticSegmentation, Mask2FormerConfig
from transformers.models.mask2former.modeling_mask2former import (
    Mask2FormerPixelDecoder,
    Mask2FormerTransformerModule,
)
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or (not self.training):
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that enables HR features to attend to LR features.
    Implements pre-norm cross-attention (Q=HR, K/V=LR).
    """

    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        B, C_hr, H_hr, W_hr = x_hr.shape

        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        y = q + attn_out
        y = y + self.mlp(y)

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)


class Mask2FormerSemanticHead(nn.Module):
    """
    A minimal Mask2Former "semantic segmentation head" that consumes multi-scale backbone features
    and outputs per-class per-pixel scores.

    Input:
        features: tuple/list of 4 feature maps (stage1..stage4), each [B, C_i, H_i, W_i].
        The spatial strides should typically be [4, 8, 16, 32] relative to the input image.

    Output:
        semantic_scores: [B, num_classes, H_out, W_out], where H_out/W_out match the mask_features
        resolution produced by the Mask2Former pixel decoder (typically stride 4).

    Notes:
        Mask2Former natively predicts:
        - class_queries_logits: [B, Q, num_classes+1] (includes "no object")
        - masks_queries_logits: [B, Q, H_out, W_out]
        For semantic segmentation, a common aggregation is:
            semantic_probs = sum_q softmax(class_logits_q)[c] * sigmoid(mask_logits_q)[h,w]
        Here we return these aggregated per-class *scores* (in [0,1]) as "logits" for compatibility
        with the original CASWiT API. If you need true logits, apply logit() carefully (numerical stability).
    """

    def __init__(
        self,
        feature_channels: List[int],
        num_classes: int,
        num_queries: int = 100,
        feature_size: int = 256,
        mask_feature_size: int = 256,
        common_stride: int = 4,
    ):
        super().__init__()

        cfg = Mask2FormerConfig(
            num_labels=num_classes,
            num_queries=num_queries,
            feature_size=feature_size,
            mask_feature_size=mask_feature_size,
            common_stride=common_stride,
            feature_strides=[4, 8, 16, 32],
            encoder_layers=1,
            decoder_layers=1,
            num_attention_heads=8,
            dim_feedforward=1024,
            output_auxiliary_logits=False,
            # keep defaults for transformer, heads, etc.
        )

        self.config = cfg
        self.num_classes = num_classes
        self.num_queries = num_queries

        # Pixel decoder consumes backbone channels and produces:
        # - multi_scale_features (3 levels: 1/8, 1/16, 1/32)
        # - mask_features (typically 1/4)
        self.pixel_decoder = Mask2FormerPixelDecoder(cfg, feature_channels=feature_channels)

        # Transformer module consumes:
        # - multi_scale_features (list of 3 tensors)
        # - mask_features (tensor at stride 4)
        # and returns masks_queries_logits for each decoder layer + intermediate states
        self.transformer_module = Mask2FormerTransformerModule(in_features=cfg.feature_size, config=cfg)

        # Class predictor (same idea as HF Mask2FormerForUniversalSegmentation)
        self.class_predictor = nn.Linear(cfg.hidden_dim, num_classes + 1)

    def forward(self, features: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        if not isinstance(features, (tuple, list)) or len(features) != 4:
            raise ValueError("Mask2FormerSemanticHead expects a tuple/list of 4 feature maps: "
                             "(stage1, stage2, stage3, stage4).")

        # Expected order: [stage1, stage2, stage3, stage4] (increasing stride).
        # The pixel decoder internally reverses and uses the last 3 feature maps for deformable attention.
        pixel_out = self.pixel_decoder(list(features), return_dict=True)
        multi_scale = list(pixel_out.multi_scale_features)  # 3 levels
        mask_features = pixel_out.mask_features             # stride 4

        dec_out = self.transformer_module(
            multi_scale_features=multi_scale,
            mask_features=mask_features,
            output_hidden_states=True,
            output_attentions=False,
        )

        # Use the last decoder layer's predictions
        masks_queries_logits = dec_out.masks_queries_logits[-1]  # [B, Q, H, W]

        # The last hidden state can be shaped [B, Q, D] OR [Q, B, D] depending on HF internals
        # (for intermediate_hidden_states, HF uses [Q, B, D] and transposes in its heads).
        # Robustly support both:
        hidden = dec_out.last_hidden_state
        if hidden.dim() != 3:
            raise RuntimeError(f"Unexpected last_hidden_state shape: {tuple(hidden.shape)}")

        if hidden.shape[0] == self.num_queries and hidden.shape[1] == masks_queries_logits.shape[0]:
            # [Q, B, D] -> [B, Q, D]
            hidden_bqd = hidden.transpose(0, 1)
        else:
            # assume [B, Q, D]
            hidden_bqd = hidden

        class_queries_logits = self.class_predictor(hidden_bqd)  # [B, Q, C+1]

        # Aggregate to semantic per-class scores at mask resolution:
        # softmax over classes (including no-object), then drop the no-object channel
        class_probs = class_queries_logits.softmax(dim=-1)[..., :-1]  # [B, Q, C]
        mask_probs = masks_queries_logits.sigmoid()                   # [B, Q, H, W]

        # semantic_scores[b,c,h,w] = sum_q class_probs[b,q,c] * mask_probs[b,q,h,w]
        semantic_scores = torch.einsum("bqc,bqhw->bchw", class_probs, mask_probs)

        return semantic_scores

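# --- Illustrative note (not part of the original file) -----------------------
# The einsum in forward() above is equivalent to this explicit loop over queries:
#
#   B, Q, C = class_probs.shape
#   _, _, H, W = mask_probs.shape
#   semantic_scores = torch.zeros(B, C, H, W, device=class_probs.device)
#   for qi in range(Q):
#       semantic_scores += (class_probs[:, qi, :, None, None]   # [B, C, 1, 1]
#                           * mask_probs[:, qi, None, :, :])    # [B, 1, H, W]
#
# i.e. each query votes for every class in proportion to its class probability,
# weighted by its soft mask at each pixel.
# ------------------------------------------------------------------------------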
class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Only change vs the original: replace self.decoder and self.decoder_lr with Mask2FormerSemanticHead.
    """

    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()

        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # Extract HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms

        # Extract LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms

        # Infer Swin stage dims from the model name (same as the original)
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536],
        }
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]

        # >>> ONLY MODIFIED PART: decoder / decoder_lr <<<
        self.decoder = Mask2FormerSemanticHead(feature_channels=dims, num_classes=num_classes)
        self.decoder_lr = Mask2FormerSemanticHead(feature_channels=dims, num_classes=num_classes)

        # Cross-attention blocks at each stage
        self.cross_attn_blocks = nn.ModuleList([
            CrossFusionBlock(dim, dim, num_heads=num_head_xa,
                             mlp_ratio=mlp_ratio, drop=0.0, drop_path=drop_path)
            for dim in dims
        ])

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr = x_hr_seq.shape[1]
        N_lr = x_lr_seq.shape[1]
        H_hr = W_hr = int(math.sqrt(N_hr))
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}


def _test_mask2former_head():
    """
    Minimal sanity test: validates that the Mask2FormerSemanticHead consumes
    a (stage1..stage4) feature tuple and returns [B, C, H1, W1] scores.
    """
    torch.manual_seed(0)
    B = 1
    num_classes = 12
    dims = [96, 192, 384, 768]
    H1, W1 = 8, 8

    feats = (
        torch.randn(B, dims[0], H1, W1),
        torch.randn(B, dims[1], H1 // 2, W1 // 2),
        torch.randn(B, dims[2], H1 // 4, W1 // 4),
        torch.randn(B, dims[3], H1 // 8, W1 // 8),
    )

    head = Mask2FormerSemanticHead(feature_channels=dims, num_classes=num_classes, num_queries=50)
    with torch.no_grad():
        out = head(feats)

    assert out.shape == (B, num_classes, H1, W1), f"Unexpected output shape: {out.shape}"
    assert torch.isfinite(out).all(), "NaN/Inf in output"
    return out.shape


if __name__ == "__main__":
    # Run the head test
    print("Mask2Former head test output shape:", _test_mask2former_head())
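Because Mask2FormerSemanticHead returns aggregated per-class probabilities rather than raw logits, plain F.cross_entropy (which applies its own log-softmax) is a questionable fit. One possible workaround, sketched here as an assumption rather than the commit's actual training code: take a clamped log and use the NLL loss.

import torch
import torch.nn.functional as F

scores = torch.rand(2, 12, 64, 64)          # stand-in for Mask2FormerSemanticHead output, in [0, 1]
target = torch.randint(0, 12, (2, 64, 64))  # dense ground-truth labels

log_scores = torch.log(scores.clamp_min(1e-6))  # treat the scores as unnormalized likelihoods
loss = F.nll_loss(log_scores, target)
pred = scores.argmax(dim=1)                     # per-pixel semantic prediction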
model/CASWiT_segformer.py
ADDED
@@ -0,0 +1,290 @@
"""
CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation

This module implements the main CASWiT model architecture with dual-branch
high-resolution and low-resolution processing with cross-attention fusion.
This variant replaces the UPerNet decode heads with SegFormer decode heads.
"""

import math
from typing import Dict

import torch
import torch.nn as nn
from transformers import UperNetForSemanticSegmentation, SegformerConfig
from transformers.models.segformer.modeling_segformer import SegformerDecodeHead
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or (not self.training):
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that enables HR features to attend to LR features.

    Implements pre-norm cross-attention (Q=HR, K/V=LR).

    Args:
        C_hr: Channel dimension of HR features
        C_lr: Channel dimension of LR features
        num_heads: Number of attention heads
        mlp_ratio: MLP expansion ratio
        drop: Dropout rate
        drop_path: Drop path rate
    """

    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the cross-attention fusion block.

        Args:
            x_hr: HR features [B, C_hr, H_hr, W_hr]
            x_lr: LR features [B, C_lr, H_lr, W_lr]

        Returns:
            Fused HR features [B, C_hr, H_hr, W_hr]
        """
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape

        # Flatten to sequences
        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        # Pre-norm
        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        # Residual connection + MLP
        y = q + attn_out
        y = y + self.mlp(y)

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)


class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Dual-branch architecture with:
    - HR branch: Processes high-resolution crops
    - LR branch: Processes low-resolution context
    - Cross-attention fusion at each encoder stage

    Args:
        num_head_xa: Number of cross-attention heads
        num_classes: Number of segmentation classes
        model_name: HuggingFace model identifier for UPerNet-Swin
        mlp_ratio: MLP expansion ratio in fusion blocks
        drop_path: Drop path rate
    """

    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()
        # Load two UPerNet backbones (HR and LR branches)
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # Extract HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
        self.decoder = None  # placeholder, set after dims inference

        # Extract LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
        self.decoder_lr = None  # placeholder, set after dims inference

        # Cross-attention blocks at each stage
        # Dimensions: tiny: [96, 192, 384, 768]  base: [128, 256, 512, 1024]  large: [192, 384, 768, 1536]
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536],
        }
        # Infer dimensions from the model name
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]  # default to base

        # SegFormer decode heads over the Swin stage features
        segformer_cfg = SegformerConfig(
            num_labels=num_classes,
            hidden_sizes=dims,
            num_encoder_blocks=4,
            decoder_hidden_size=512,
            classifier_dropout_prob=0.0,
        )
        self.decoder = SegformerDecodeHead(segformer_cfg)
        self.decoder_lr = SegformerDecodeHead(segformer_cfg)

        self.cross_attn_blocks = nn.ModuleList([
            CrossFusionBlock(dim, dim, num_heads=num_head_xa,
                             mlp_ratio=mlp_ratio, drop=0.0, drop_path=drop_path)
            for dim in dims
        ])

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the CASWiT model.

        Args:
            x_hr: HR input images [B, 3, H_hr, W_hr]
            x_lr: LR input images [B, 3, H_lr, W_lr]

        Returns:
            Dictionary with 'logits_hr' and 'logits_lr' segmentation logits
        """
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr, C_hr = x_hr_seq.shape[1], x_hr_seq.shape[2]
        N_lr, C_lr = x_lr_seq.shape[1], x_lr_seq.shape[2]
        H_hr = W_hr = int(math.sqrt(N_hr))
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        # Process through encoder stages with cross-attention fusion
        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            # HR branch blocks
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            # LR branch blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Layer normalization
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            # Reshape to spatial format
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample if the stage has it
            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features (for auxiliary supervision)
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}


def _test_segformer_head():
    """Quick, offline test for SegFormer head input/output shapes."""
    # Example dims for Swin-Tiny stages:
    dims = [96, 192, 384, 768]
    cfg = SegformerConfig(
        num_labels=7,
        hidden_sizes=dims,
        num_encoder_blocks=4,
        decoder_hidden_size=512,
        classifier_dropout_prob=0.0,
    )
    head = SegformerDecodeHead(cfg)

    B = 2
    # Stage resolutions typically differ by /2 each time; here we mimic that.
    f1 = torch.randn(B, dims[0], 128, 128)
    f2 = torch.randn(B, dims[1], 64, 64)
    f3 = torch.randn(B, dims[2], 32, 32)
    f4 = torch.randn(B, dims[3], 16, 16)

    logits = head((f1, f2, f3, f4))
    assert logits.shape == (B, cfg.num_labels, 128, 128), f"Unexpected logits shape: {logits.shape}"
    return logits.shape


if __name__ == "__main__":
    print("SegFormer head test logits shape:", _test_segformer_head())
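Both decode heads in this file emit logits at the stage-1 feature resolution (1/4 of the input, as the shape assertion in _test_segformer_head shows), so a training step needs to upsample before the loss. A sketch under that assumption — model, x_hr, x_lr, labels_hr, and labels_lr are assumed to exist, and the 0.4 auxiliary weight is illustrative, not taken from the commit:

import torch.nn.functional as F

out = model(x_hr, x_lr)  # CASWiT from this file
logits_hr = F.interpolate(out["logits_hr"], size=x_hr.shape[-2:], mode="bilinear", align_corners=False)
logits_lr = F.interpolate(out["logits_lr"], size=x_lr.shape[-2:], mode="bilinear", align_corners=False)
loss = F.cross_entropy(logits_hr, labels_hr) + 0.4 * F.cross_entropy(logits_lr, labels_lr)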
model/CASWiT_ssl.py
ADDED
@@ -0,0 +1,292 @@
"""
CASWiT Self-Supervised Learning (SSL) Module

Implements SimMIM-based self-supervised pre-training for CASWiT using
masked image modeling with dual-branch HR/LR processing.
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


def random_masking_with_tokens(x: torch.Tensor, mask_ratio: float = 0.75,
                               mask_token: Optional[torch.Tensor] = None):
    """
    Random masking at token level with a learned mask token.

    Args:
        x: Input tokens [B, N, C]
        mask_ratio: Ratio of tokens to mask
        mask_token: Learnable mask token

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N] where 0=visible, 1=masked
        ids_restore: Indices to restore the original order
    """
    B, N, C = x.shape
    len_keep = int(N * (1 - mask_ratio))

    # Shuffle token indices with per-sample random noise (MAE-style masking)
    noise = torch.rand(B, N, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    ids_keep = ids_shuffle[:, :len_keep]

    x_keep = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))

    if mask_token is None:
        mask_token = torch.zeros((1, C), device=x.device)
    m_tok = mask_token.view(1, 1, C).expand(B, N - len_keep, C)

    # Concatenate kept tokens and mask tokens, then unshuffle back into place
    x_cat = torch.cat([x_keep, m_tok], dim=1)
    x_masked = torch.gather(x_cat, 1, ids_restore.unsqueeze(-1).expand(-1, -1, C))

    mask = torch.ones(B, N, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore


def center_masking_with_tokens(x: torch.Tensor, mask_token: Optional[torch.Tensor] = None,
                               mask_ratio: float = 0.5):
    """
    Deterministic centered square mask.

    Args:
        x: Input tokens [B, N, C]
        mask_token: Learnable mask token
        mask_ratio: Ratio of tokens to mask

    Returns:
        x_masked: Masked tokens [B, N, C]
        mask: Binary mask [B, N]
        ids_restore: Indices to restore the original order
    """
    B, N, C = x.shape
    H = W = int(N ** 0.5)
    assert H * W == N, "N must be a perfect square"
    # Side length of the masked square covering ~mask_ratio of the tokens
    L = int(round(H * (mask_ratio ** 0.5)))
    start = (H - L) // 2
    end = start + L

    mask_2d = torch.zeros(H, W, device=x.device, dtype=torch.bool)
    mask_2d[start:end, start:end] = True
    mask = mask_2d.view(1, -1).expand(B, -1)  # (B, N)

    if mask_token is None:
        mask_token = torch.zeros(C, device=x.device)
    mask_token = mask_token.view(-1)

    x_masked = x * (~mask).unsqueeze(-1) + mask.unsqueeze(-1) * mask_token.view(1, 1, C)
    ids_restore = torch.arange(N, device=x.device).unsqueeze(0).expand(B, N)
    return x_masked, mask.to(x_masked.dtype), ids_restore

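# --- Illustrative note (not part of the original file) -----------------------
# Both maskers keep the token count fixed and substitute masked positions with
# the (learned) mask token. For a 4x4 token grid (N = 16):
#
#   x = torch.randn(2, 16, 128)
#   xc, mc, _ = center_masking_with_tokens(x, mask_ratio=0.25)
#   # mc.view(2, 4, 4)[0] is a centered 2x2 square of ones (the masked tokens)
#   xr, mr, _ = random_masking_with_tokens(x, mask_ratio=0.75)
#   # mr.sum(dim=1) == 12 for every sample (75% of 16 tokens masked)
# ------------------------------------------------------------------------------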
| 94 |
+
class CrossAttentionBlock(nn.Module):
|
| 95 |
+
"""Simplified cross-attention block for SSL."""
|
| 96 |
+
def __init__(self, C_hr, C_lr, num_heads=8, dropout=0.0):
|
| 97 |
+
super().__init__()
|
| 98 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 99 |
+
embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
|
| 100 |
+
dropout=dropout, batch_first=True
|
| 101 |
+
)
|
| 102 |
+
self.norm = nn.LayerNorm(C_hr)
|
| 103 |
+
self.mlp = nn.Sequential(
|
| 104 |
+
nn.LayerNorm(C_hr),
|
| 105 |
+
nn.Linear(C_hr, C_hr * 4),
|
| 106 |
+
nn.GELU(),
|
| 107 |
+
nn.Linear(C_hr * 4, C_hr),
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def forward(self, x_hr, x_lr):
|
| 111 |
+
B, C_hr, H_hr, W_hr = x_hr.shape
|
| 112 |
+
_, C_lr, H_lr, W_lr = x_lr.shape
|
| 113 |
+
q = x_hr.flatten(2).transpose(1, 2) # (B,N_hr,C_hr)
|
| 114 |
+
kv = x_lr.flatten(2).transpose(1, 2) # (B,N_lr,C_lr)
|
| 115 |
+
attn_out, _ = self.cross_attn(q, kv, kv)
|
| 116 |
+
y = self.norm(q + attn_out)
|
| 117 |
+
y = y + self.mlp(y)
|
| 118 |
+
return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class CASWiT_SSL(nn.Module):
|
| 122 |
+
"""
|
| 123 |
+
CASWiT Self-Supervised Learning model using SimMIM.
|
| 124 |
+
|
| 125 |
+
Encoder: Dual Swin backbones with cross-attention blocks
|
| 126 |
+
Decoder: Conv1x1 + PixelShuffle for reconstruction
|
| 127 |
+
Masking: HR random masking, LR center masking
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
model_name: HuggingFace model identifier
|
| 131 |
+
mask_ratio_hr: Masking ratio for HR branch
|
| 132 |
+
mask_ratio_lr: Masking ratio for LR branch
|
| 133 |
+
patch_size: Patch size for masking
|
| 134 |
+
encoder_stride: Encoder stride for decoder
|
| 135 |
+
xa_heads: Number of cross-attention heads per stage
|
| 136 |
+
"""
|
| 137 |
+
def __init__(self, model_name: str = "openmmlab/upernet-swin-base",
|
| 138 |
+
mask_ratio_hr: float = 0.75, mask_ratio_lr: float = 0.5,
|
| 139 |
+
patch_size: int = 4, encoder_stride: int = 32,
|
| 140 |
+
xa_heads: Tuple[int, int, int, int] = (8, 8, 8, 8)):
|
| 141 |
+
super().__init__()
|
| 142 |
+
self.mask_ratio_hr = mask_ratio_hr
|
| 143 |
+
self.mask_ratio_lr = mask_ratio_lr
|
| 144 |
+
self.patch_size = patch_size
|
| 145 |
+
self.encoder_stride = encoder_stride
|
| 146 |
+
|
| 147 |
+
# Load two UPerNet (Swin) backbones
|
| 148 |
+
model_hr = UperNetForSemanticSegmentation.from_pretrained(
|
| 149 |
+
model_name, ignore_mismatched_sizes=True
|
| 150 |
+
)
|
| 151 |
+
model_lr = UperNetForSemanticSegmentation.from_pretrained(
|
| 152 |
+
model_name, ignore_mismatched_sizes=True
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
self.embeddings_hr = model_hr.backbone.embeddings
|
| 156 |
+
self.encoder_layers_hr = model_hr.backbone.encoder.layers
|
| 157 |
+
self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
|
| 158 |
+
|
| 159 |
+
self.embeddings_lr = model_lr.backbone.embeddings
|
| 160 |
+
self.encoder_layers_lr = model_lr.backbone.encoder.layers
|
| 161 |
+
self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
|
| 162 |
+
|
| 163 |
+
# Cross-attention blocks with explicit Swin-Base dims
|
| 164 |
+
dims = [128, 256, 512, 1024]
|
| 165 |
+
self.cross_attn_blocks = nn.ModuleList([
|
| 166 |
+
CrossAttentionBlock(d, d, num_heads=h) for d, h in zip(dims, xa_heads)
|
| 167 |
+
])
|
| 168 |
+
|
| 169 |
+
# Learnable mask tokens
|
| 170 |
+
self.mask_token_hr = nn.Parameter(torch.zeros(1, dims[0]))
|
| 171 |
+
self.mask_token_lr = nn.Parameter(torch.zeros(1, dims[0]))
|
| 172 |
+
|
| 173 |
+
# SimMIM decoder: Conv1×1 → PixelShuffle(stride)
|
| 174 |
+
self.decoder_conv = None # lazy init after we know C_last
|
| 175 |
+
self.decoder_shuffle = nn.PixelShuffle(self.encoder_stride)
|
| 176 |
+
|
| 177 |
+
# Store masks for visualization
|
| 178 |
+
self.last_mask_hr = None
|
| 179 |
+
self.last_mask_lr = None
|
| 180 |
+
|
| 181 |
+
def _encode(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
|
| 182 |
+
"""Encode with masking and return reconstruction targets."""
|
| 183 |
+
B, C, H, W = x_hr.shape
|
| 184 |
+
target_img = x_hr
|
| 185 |
+
target_lr = x_lr
|
| 186 |
+
|
| 187 |
+
# Patch embeddings
|
| 188 |
+
x_hr_seq, _ = self.embeddings_hr(x_hr) # (B, N_hr, C1)
|
| 189 |
+
x_lr_seq, _ = self.embeddings_lr(x_lr) # (B, N_lr, C1)
|
| 190 |
+
|
| 191 |
+
# Masking
|
| 192 |
+
x_hr_seq, mask_hr, _ = random_masking_with_tokens(
|
| 193 |
+
x_hr_seq, self.mask_ratio_hr, self.mask_token_hr
|
| 194 |
+
)
|
| 195 |
+
x_lr_seq, mask_lr, _ = center_masking_with_tokens(
|
| 196 |
+
x_lr_seq, self.mask_token_lr, mask_ratio=self.mask_ratio_lr
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Initial spatial dims
|
| 200 |
+
H_hr = W_hr = int(math.sqrt(x_hr_seq.shape[1]))
|
| 201 |
+
H_lr = W_lr = int(math.sqrt(x_lr_seq.shape[1]))
|
| 202 |
+
dims_hr = (H_hr, W_hr)
|
| 203 |
+
dims_lr = (H_lr, W_lr)
|
| 204 |
+
|
| 205 |
+
# Walk encoder stages with cross attention at each stage
|
| 206 |
+
for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
|
| 207 |
+
self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
|
| 208 |
+
)):
|
| 209 |
+
# HR blocks
|
| 210 |
+
for block in stage_hr.blocks:
|
| 211 |
+
x_hr_seq = block(x_hr_seq, dims_hr)
|
| 212 |
+
if isinstance(x_hr_seq, tuple):
|
                    x_hr_seq = x_hr_seq[0]

            # LR branch blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Per-stage layer norms
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            # Reshape sequences to spatial feature maps
            B_, N_hr_, C_hr_ = x_hr_seq.shape
            B_, N_lr_, C_lr_ = x_lr_seq.shape
            Hh, Wh = dims_hr
            Hl, Wl = dims_lr
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B_, C_hr_, Hh, Wh)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B_, C_lr_, Hl, Wl)

            # Cross-fuse HR <- LR
            fused_hr = ca(feat_hr, feat_lr)
            x_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample to next stage
            if stage_hr.downsample is not None:
                x_hr_seq = stage_hr.downsample(x_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

        # Last-stage feature map z (B, C_last, H/stride, W/stride)
        Hs, Ws = dims_hr
        C_last = x_hr_seq.shape[-1]
        z = x_hr_seq.transpose(1, 2).contiguous().view(B, C_last, Hs, Ws)

        # Lazily initialize the decoder conv once C_last is known
        if self.decoder_conv is None:
            self.decoder_conv = nn.Conv2d(
                C_last, (self.encoder_stride ** 2) * 3, kernel_size=1
            ).to(z.device)

        # Reconstruction via pixel shuffle back to image space
        x_rec = self.decoder_shuffle(self.decoder_conv(z))  # (B, 3, H, W)

        # Convert patch masks to pixel masks
        Mh = int(math.sqrt(mask_hr.shape[1]))
        mask_patch_hr = mask_hr.view(B, Mh, Mh)
        mask_pix_hr = mask_patch_hr.repeat_interleave(
            self.patch_size, 1
        ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()

        Ml = int(math.sqrt(mask_lr.shape[1]))
        mask_patch_lr = mask_lr.view(B, Ml, Ml)
        mask_pix_lr = mask_patch_lr.repeat_interleave(
            self.patch_size, 1
        ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()

        self.last_mask_hr = mask_patch_hr
        self.last_mask_lr = mask_patch_lr

        return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for SSL training.

        Returns the L1 reconstruction loss computed on masked pixels only.
        """
        x_rec, target_img, mask_pix, _, _ = self._encode(x_hr, x_lr)
        loss_recon = F.l1_loss(target_img, x_rec, reduction='none')
        # Sum over masked pixels, normalize by the masked area (eps for
        # stability) and by the number of channels
        loss = (loss_recon * mask_pix).sum() / (mask_pix.sum() + 1e-6) / target_img.shape[1]
        return loss

    @torch.no_grad()
    def forward_outputs(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
        """Forward pass returning all outputs for visualization."""
        x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr = self._encode(x_hr, x_lr)
        return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr
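The objective above is a masked-image-modeling loss: per-pixel L1 between reconstruction and target, kept only where the upsampled patch mask is 1, normalized by the masked area (plus 1e-6 for stability) and the channel count. A minimal training-step sketch follows; the batch size, 224x224 inputs, and AdamW settings are illustrative assumptions, and CASWiT_SSL is constructed the same way build_model does below:

import torch
from model.CASWiT_ssl import CASWiT_SSL

model = CASWiT_SSL(model_name="openmmlab/upernet-swin-tiny")
x_hr = torch.randn(2, 3, 224, 224)  # high-resolution crops
x_lr = torch.randn(2, 3, 224, 224)  # low-resolution context
loss = model(x_hr, x_lr)            # scalar loss over masked pixels only
# decoder_conv is created lazily on this first forward, so build the
# optimizer afterwards or its parameters will be missed
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss.backward()
opt.step()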
model/CASWiT_upernet.py
ADDED
@@ -0,0 +1,250 @@
"""
CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation

This module implements the main CASWiT model architecture with dual-branch
high-resolution and low-resolution processing with cross-attention fusion.
"""

import math
from typing import Dict
import torch
import torch.nn as nn
from transformers import UperNetForSemanticSegmentation
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()


class DropPath(nn.Module):
    """Drop path (stochastic depth) regularization module."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
        return x * mask


class CrossFusionBlock(nn.Module):
    """
    Cross-attention fusion block that lets HR features attend to LR features.

    Implements pre-norm cross-attention (Q=HR, K/V=LR).

    Args:
        C_hr: Channel dimension of HR features
        C_lr: Channel dimension of LR features
        num_heads: Number of attention heads
        mlp_ratio: MLP expansion ratio
        drop: Dropout rate
        drop_path: Drop path rate
    """
    def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
                 mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
        super().__init__()

        self.norm_q = nn.LayerNorm(C_hr)
        self.norm_kv = nn.LayerNorm(C_lr)
        self.attn = nn.MultiheadAttention(
            embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
            dropout=drop, batch_first=True
        )
        # Stochastic depth on both residual branches (the drop_path argument
        # was previously accepted but never applied, leaving DropPath unused)
        self.drop_path = DropPath(drop_path)

        hidden = int(C_hr * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.LayerNorm(C_hr),
            nn.Linear(C_hr, hidden),
            nn.GELU(),
            nn.Linear(hidden, C_hr),
        )

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the cross-attention fusion block.

        Args:
            x_hr: HR features [B, C_hr, H_hr, W_hr]
            x_lr: LR features [B, C_lr, H_lr, W_lr]

        Returns:
            Fused HR features [B, C_hr, H_hr, W_hr]
        """
        B, C_hr, H_hr, W_hr = x_hr.shape
        _, C_lr, H_lr, W_lr = x_lr.shape

        # Flatten spatial maps to token sequences
        q = x_hr.flatten(2).transpose(1, 2)   # [B, N_hr, C_hr]
        kv = x_lr.flatten(2).transpose(1, 2)  # [B, N_lr, C_lr]

        # Pre-norm
        qn = self.norm_q(q)
        kvn = self.norm_kv(kv)

        attn_out, _ = self.attn(qn, kvn, kvn)  # [B, N_hr, C_hr]

        # Residual connections with stochastic depth
        y = q + self.drop_path(attn_out)
        y = y + self.drop_path(self.mlp(y))

        return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)

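
# Shape sanity check for CrossFusionBlock (a minimal sketch; tensor sizes are
# illustrative, and the LR grid may be coarser than the HR grid since both
# maps are flattened to token sequences independently):
#
#   block = CrossFusionBlock(C_hr=96, C_lr=96, num_heads=8)
#   feat_hr = torch.randn(2, 96, 56, 56)  # HR stage-1 map
#   feat_lr = torch.randn(2, 96, 28, 28)  # coarser LR context map
#   out = block(feat_hr, feat_lr)         # -> (2, 96, 56, 56), HR geometry kept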

class CASWiT(nn.Module):
    """
    CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.

    Dual-branch architecture with:
      - HR branch: processes high-resolution crops
      - LR branch: processes low-resolution context
      - Cross-attention fusion at each encoder stage

    Args:
        num_head_xa: Number of cross-attention heads
        num_classes: Number of segmentation classes
        model_name: HuggingFace model identifier for UPerNet-Swin
        mlp_ratio: MLP expansion ratio in fusion blocks
        drop_path: Drop path rate
    """
    def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
                 model_name: str = "openmmlab/upernet-swin-tiny",
                 mlp_ratio: float = 4.0, drop_path: float = 0.1):
        super().__init__()
        # Load two pretrained UPerNet backbones (HR and LR branches)
        model_hr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )
        model_lr = UperNetForSemanticSegmentation.from_pretrained(
            model_name, num_labels=num_classes, ignore_mismatched_sizes=True
        )

        # HR branch components
        self.embeddings_hr = model_hr.backbone.embeddings
        self.encoder_layers_hr = model_hr.backbone.encoder.layers
        self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
        self.decoder = model_hr.decode_head

        # LR branch components
        self.embeddings_lr = model_lr.backbone.embeddings
        self.encoder_layers_lr = model_lr.backbone.encoder.layers
        self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
        self.decoder_lr = model_lr.decode_head

        # Cross-attention blocks at each stage.
        # Stage dims: tiny [96, 192, 384, 768], base [128, 256, 512, 1024],
        # large [192, 384, 768, 1536]
        dims_map = {
            "tiny": [96, 192, 384, 768],
            "base": [128, 256, 512, 1024],
            "large": [192, 384, 768, 1536],
        }
        # Infer stage dimensions from the model name
        if "tiny" in model_name.lower():
            dims = dims_map["tiny"]
        elif "large" in model_name.lower():
            dims = dims_map["large"]
        else:
            dims = dims_map["base"]  # default to base

        self.cross_attn_blocks = nn.ModuleList([
            CrossFusionBlock(dim, dim, num_heads=num_head_xa,
                             mlp_ratio=mlp_ratio, drop=0.0, drop_path=drop_path)
            for dim in dims
        ])

    def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the CASWiT model.

        Args:
            x_hr: HR input images [B, 3, H_hr, W_hr]
            x_lr: LR input images [B, 3, H_lr, W_lr]

        Returns:
            Dictionary with 'logits_hr' and 'logits_lr' segmentation logits
        """
        B = x_hr.size(0)

        # Patch embeddings
        x_hr_seq, _ = self.embeddings_hr(x_hr)
        x_lr_seq, _ = self.embeddings_lr(x_lr)

        N_hr, C_hr = x_hr_seq.shape[1], x_hr_seq.shape[2]
        N_lr, C_lr = x_lr_seq.shape[1], x_lr_seq.shape[2]
        H_hr = W_hr = int(math.sqrt(N_hr))  # assumes square inputs
        H_lr = W_lr = int(math.sqrt(N_lr))
        dims_hr = (H_hr, W_hr)
        dims_lr = (H_lr, W_lr)

        features_hr: Dict[str, torch.Tensor] = {}
        features_lr: Dict[str, torch.Tensor] = {}

        # Process encoder stages with cross-attention fusion
        for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
            self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
        )):
            # HR branch blocks
            for block in stage_hr.blocks:
                x_hr_seq = block(x_hr_seq, dims_hr)
                if isinstance(x_hr_seq, tuple):
                    x_hr_seq = x_hr_seq[0]

            # LR branch blocks
            for block in stage_lr.blocks:
                x_lr_seq = block(x_lr_seq, dims_lr)
                if isinstance(x_lr_seq, tuple):
                    x_lr_seq = x_lr_seq[0]

            # Per-stage layer normalization
            x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
            x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)

            H_hr, W_hr = dims_hr
            H_lr, W_lr = dims_lr
            C_hr = x_hr_seq.shape[-1]
            C_lr = x_lr_seq.shape[-1]

            # Reshape to spatial format
            feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
            feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)

            fused_hr = ca(feat_hr, feat_lr)
            fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()

            # Downsample if the stage has a patch-merging layer
            if stage_hr.downsample is not None:
                fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
                dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
            if stage_lr.downsample is not None:
                x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
                dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)

            features_hr[f"stage{idx+1}"] = fused_hr
            features_lr[f"stage{idx+1}"] = feat_lr
            x_hr_seq = fused_hr_seq

        # Decode fused HR features
        features_tuple = (
            features_hr["stage1"],
            features_hr["stage2"],
            features_hr["stage3"],
            features_hr["stage4"],
        )
        logits = self.decoder(features_tuple)

        # Decode LR features (auxiliary supervision)
        features_tuple_lr = (
            features_lr["stage1"],
            features_lr["stage2"],
            features_lr["stage3"],
            features_lr["stage4"],
        )
        logits_lr = self.decoder_lr(features_tuple_lr)

        return {"logits_hr": logits, "logits_lr": logits_lr}

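A minimal inference sketch for this UPerNet variant (input sizes are illustrative; square inputs are assumed, matching the sqrt-based grid recovery above). Note the decode heads are called directly, so the logits come out at the backbone's reduced feature resolution and typically need bilinear upsampling to the label size:

import torch
from model.CASWiT_upernet import CASWiT

model = CASWiT(num_head_xa=1, num_classes=12).eval()
x_hr = torch.randn(1, 3, 224, 224)  # high-resolution crop
x_lr = torch.randn(1, 3, 224, 224)  # downsampled context window
with torch.no_grad():
    out = model(x_hr, x_lr)
print(out["logits_hr"].shape)  # (1, num_classes, h, w) at feature resolution
print(out["logits_lr"].shape)  # auxiliary LR-branch logits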
model/build_model.py
ADDED
@@ -0,0 +1,38 @@
from __future__ import annotations

from typing import Any

from model.CASWiT_upernet import CASWiT as CASWiT_UperNet
from model.CASWiT_segformer import CASWiT as CASWiT_SegFormer
from model.CASWiT_m2f import CASWiT as CASWiT_Mask2Former
from model.CASWiT_fusion_last_stage_add import CASWiT as CASWiT_FusionLastStageAdd
from model.CASWiT_ssl import CASWiT_SSL


def _get(cfg: Any, name: str, default: Any = None) -> Any:
    return getattr(cfg, name, default)


def build_model(cfg: Any):
    head = _get(cfg, "model_head", None) or _get(cfg, "head", None) or "upernet"
    head = str(head).lower()

    # Defaults mirror the CASWiT constructor so missing config fields fall
    # back gracefully instead of failing on int(None)/float(None)
    common = dict(
        num_head_xa=int(_get(cfg, "cross_attention_heads", 1)),
        num_classes=int(_get(cfg, "num_classes", 12)),
        model_name=str(_get(cfg, "model_name", "openmmlab/upernet-swin-tiny")),
        mlp_ratio=float(_get(cfg, "fusion_mlp_ratio", 4.0)),
        drop_path=float(_get(cfg, "fusion_drop_path", 0.1)),
    )

    if head in ("upernet", "caswit", "default"):
        return CASWiT_UperNet(**common)
    if head in ("segformer",):
        return CASWiT_SegFormer(**common)
    if head in ("mask2former", "m2f"):
        return CASWiT_Mask2Former(**common)
    if head in ("fusion_last_stage_add", "last_stage_add"):
        return CASWiT_FusionLastStageAdd(**common)
    if head in ("ssl", "caswit_ssl"):
        return CASWiT_SSL(model_name=common["model_name"])
    raise ValueError(
        f"Unknown model head: {head}. "
        "Available: upernet, segformer, mask2former, fusion_last_stage_add, ssl"
    )
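build_model only needs an attribute-bearing config object; a minimal sketch with types.SimpleNamespace (field values mirror the constructor defaults and are illustrative, not tuned settings):

from types import SimpleNamespace
from model.build_model import build_model

cfg = SimpleNamespace(
    model_head="upernet",
    cross_attention_heads=1,
    num_classes=12,
    model_name="openmmlab/upernet-swin-tiny",
    fusion_mlp_ratio=4.0,
    fusion_drop_path=0.1,
)
model = build_model(cfg)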