antoine.carreaud67 committed on
Commit
36b4539
·
1 Parent(s): e6eaf2a

clean release

README.md CHANGED
@@ -1,3 +1,303 @@
1
- ---
2
- license: mit
3
- ---
1
+ # CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation
2
+
3
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
4
+
5
+ Official implementation of **CASWiT**, a dual-branch architecture for ultra-high resolution semantic segmentation that leverages cross-attention fusion between high-resolution and low-resolution branches.
6
+
7
+ ## 📋 Table of Contents
8
+
9
+ - [Overview](#overview)
10
+ - [Architecture](#architecture)
11
+ - [Installation](#installation)
12
+ - [Dataset Preparation](#dataset-preparation)
13
+ - [Usage](#usage)
14
+ - [Configuration](#configuration)
15
+ - [Results](#results)
16
+ - [Citation](#citation)
17
+ - [License](#license)
18
+
19
+ ## 🎯 Overview
20
+
21
+ CASWiT addresses the challenge of semantic segmentation on ultra-high resolution images by introducing a dual-branch architecture:
22
+
23
+ - **HR Branch**: Processes high-resolution crops (512×512) for fine-grained detail
24
+ - **LR Branch**: Processes the full scene (downsampled by 2×) for global context
25
+ - **Cross-Attention Fusion**: Enables HR features to attend to LR context at each encoder stage
26
+
27
+ This design allows the model to capture both local details and global context, leading to improved segmentation performance on large-scale datasets.
28
+ In particular, CASWiT achieves **65.35 mIoU** on the **FLAIR-HUB** RGB-only UHR benchmark and **49.1 mIoU** on **URUR**, outperforming prior RGB/UHR baselines while remaining memory-efficient.
29
+
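+ To make the dual-branch input concrete, here is a minimal sketch of how the two views are derived from a single tile, mirroring `SemanticSegmentationDatasetFusion` in `dataset/definition_dataset.py` (the random array is only a stand-in for a real image):
+
+ ```python
+ import numpy as np
+
+ img = np.random.rand(1024, 1024, 3).astype(np.float32)  # stand-in UHR tile
+ img_hr = img[256:768, 256:768]  # 512x512 HR crop for fine detail
+ img_lr = img[::2, ::2]          # whole scene downsampled 2x for context
+ ```
+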
30
+ ## 🏗️ Architecture
31
+
32
+ ![CASWiT architecture](model/architecture.png)
33
+
34
+ Key components:
35
+ - **Dual Swin Transformer Backbones**: Two UPerNet-Swin encoders process HR and LR streams
36
+ - **Cross-Attention Fusion Blocks**: Multi-head cross-attention at each encoder stage
37
+ - **Auxiliary LR Supervision**: Additional supervision on LR branch for better training
38
+
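+ At its core, each fusion block is multi-head cross-attention with HR tokens as queries and LR tokens as keys/values. Below is a minimal sketch with illustrative shapes (the full pre-norm block with MLP is `CrossFusionBlock` in `model/CASWiT.py`):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ attn = nn.MultiheadAttention(embed_dim=128, num_heads=1, kdim=128, vdim=128, batch_first=True)
+ q = torch.randn(2, 16 * 16, 128)   # flattened HR tokens [B, N_hr, C_hr]
+ kv = torch.randn(2, 16 * 16, 128)  # flattened LR tokens [B, N_lr, C_lr]
+ fused, _ = attn(q, kv, kv)         # HR queries attend to LR context
+ ```
+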
39
+ ## 📦 Installation
40
+
41
+ ### Requirements
42
+
43
+ - Python 3.8+
44
+ - PyTorch 2.0+
45
+ - CUDA 11.8+ (for GPU training)
46
+
47
+ ### Setup
48
+
49
+ 1. Clone the repository:
50
+ ```bash
51
+ git clone https://github.com/yourusername/CASWiT.git
52
+ cd CASWiT
+ ```
53
+
54
+ 2. Install dependencies:
55
+ ```bash
56
+ pip install -r requirements.txt
57
+ ```
58
+
59
+ ## 📊 Dataset Preparation
60
+
61
+ ### FLAIR-HUB
62
+
63
+ FLAIR-HUB is a large-scale ultra-high-resolution aerial semantic segmentation dataset. To prepare it:
64
+
65
+ 1. Download the FLAIR-HUB dataset
66
+ 2. Run the preparation script to merge tiles:
67
+ ```bash
68
+ python dataset/prepareFlairHub.py
69
+ ```
70
+
71
+ The script will merge GeoTIFF tiles into larger mosaics suitable for training.
72
+
73
+ ### URUR
74
+
75
+ The URUR dataset should be organized as follows:
76
+ ```
77
+ URUR/
78
+ ├── train/
79
+ │ ├── image/
80
+ │ └── label/
81
+ ├── val/
82
+ │ ├── image/
83
+ │ └── label/
84
+ └── test/
85
+ ├── image/
86
+ └── label/
87
+ ```
88
+
89
+ ### SWISSIMAGE
90
+
91
+ For the SWISSIMAGE dataset:
92
+ 1. Download images using the provided CSV file:
93
+ ```bash
94
+ python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv
95
+ ```
96
+
97
+ ## 🚀 Usage
98
+
99
+ ### Training
100
+
101
+ Train CASWiT on FlairHub:
102
+ ```bash
103
+ python train/train.py configs/FlairHub.yaml
104
+ ```
105
+
106
+ For distributed training (multi-GPU):
107
+ ```bash
108
+ torchrun --nproc_per_node=4 train/train.py configs/FlairHub.yaml
109
+ ```
110
+
111
+ ### Evaluation
112
+
113
+ Evaluate a trained model:
114
+ ```bash
115
+ python train/eval.py configs/FlairHub.yaml weights/checkpoint.pth test
116
+ ```
117
+
118
+ ### Inference
119
+
120
+ Run inference on a single image:
121
+ ```bash
122
+ python train/inference.py configs/FlairHub.yaml weights/checkpoint.pth image.tif output.png
123
+ ```
124
+
125
+ ### Using Main Entry Point
126
+
127
+ Alternatively, use the unified main script:
128
+ ```bash
129
+ # Training
130
+ python main.py train --config configs/FlairHub.yaml
131
+
132
+ # Evaluation
133
+ python main.py eval --config configs/FlairHub.yaml --checkpoint weights/checkpoint.pth
134
+
135
+ # Inference on a single image
136
+ python main.py inference --config configs/FlairHub.yaml --checkpoint weights/checkpoint.pth --image image.tif --output pred.png
137
+ ```
138
+
139
+ ## ⚙️ Configuration
140
+
141
+ Configuration files are in YAML format. Example structure:
142
+
143
+ ```yaml
144
+ paths:
145
+ data_path: "/path/to/dataset"
146
+ dataset_name: ""
147
+ train_img_subdir: "train/img"
148
+ train_msk_subdir: "train/msk"
149
+ val_img_subdir: "val/img"
150
+ val_msk_subdir: "val/msk"
151
+ test_img_subdir: "test/img"
152
+ test_msk_subdir: "test/msk"
153
+ save_dir: "weights"
154
+ pretrained_path: ""
155
+
156
+ model:
157
+ model_name: "openmmlab/upernet-swin-base" # or swin-tiny, swin-large
158
+ num_classes: 15
159
+ cross_attention_heads: 1
160
+ fusion_mlp_ratio: 4.0
161
+ fusion_drop_path: 0.1
162
+ lr_supervision_weight: 0.5
163
+
164
+ training:
165
+ batch_size: 4
166
+ num_workers: 8
167
+ num_epochs: 20
168
+ learning_rate: 0.00006
169
+ amp: true
170
+ seed: 1337
171
+ eta_min: 0.000001
172
+
173
+ wandb:
174
+ use_wandb: true
175
+ project: "Fusion_HRLR"
176
+ entity: "your_entity"
177
+ run_name: "caswit_experiment"
178
+ ```
179
+
180
+ ### Key Parameters
181
+
182
+ - `model_name`: Swin variant (`openmmlab/upernet-swin-tiny`, `openmmlab/upernet-swin-base`, `openmmlab/upernet-swin-large`)
183
+ - `cross_attention_heads`: Number of attention heads in cross-attention blocks
184
+ - `lr_supervision_weight`: Weight for LR branch auxiliary supervision
185
+
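+ Configs can be inspected with PyYAML before launching a run (a minimal sketch; `train/train.py` provides its own `load_config`):
+
+ ```python
+ import yaml
+
+ with open("configs/FlairHub.yaml") as f:
+     cfg = yaml.safe_load(f)
+ print(cfg["model"]["model_name"], cfg["training"]["learning_rate"])
+ ```
+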
186
+ ## 📈 Results
187
+
188
+ ### FLAIR-HUB (RGB-only UHR protocol)
189
+
190
+ We first evaluate CASWiT on the FLAIR-HUB ultra-high-resolution aerial benchmark under the RGB-only UHR protocol.
191
+
192
+ | Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ |
193
+ |----------------------------------|-----------:|----------:|------------:|
194
+ | *RGB Baselines (official FLAIR-HUB)* ||||
195
+ | Swin-T + UPerNet | 62.01 | 75.27 | – |
196
+ | Swin-S + UPerNet | 61.87 | 75.11 | – |
197
+ | Swin-B + UPerNet | 64.05 | 76.88 | – |
198
+ | Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 |
199
+ | Swin-L + UPerNet | 63.36 | 76.35 | – |
200
+ | *Ours (RGB-only UHR protocol)* ||||
201
+ | **CASWiT-Base** | 65.11 | 77.71 | 35.87 |
202
+ | **CASWiT-Base-SSL** |**65.35** |**77.87** | **35.99** |
203
+
204
+ CASWiT-Base already improves over the retrained Swin-B + UPerNet baseline, and CASWiT-Base-SSL further pushes performance to **65.35 mIoU** and **77.87 mF1**.
205
+ On mean boundary IoU, CASWiT-Base-SSL reaches **35.99 mBIoU**, which is a **+3.42 mBIoU** gain over the retrained Swin-B baseline (32.57).
206
+
207
+ ---
208
+
209
+ ### URUR
210
+
211
+ We also evaluate CASWiT on the URUR ultra-high-resolution benchmark, comparing to both generic and UHR-specific segmentation models.
212
+
213
+ | Model | mIoU (%) ↑ | Mem (MB) ↓ |
214
+ |----------------------------------------|-----------:|-----------:|
215
+ | *Generic Models* |||
216
+ | PSPNet | 32.0 | 5482 |
217
+ | ResNet18 + DeepLabv3+ | 33.1 | 5508 |
218
+ | STDC | 42.0 | 7617 |
219
+ | *UHR Models* |||
220
+ | GLNet | 41.2 | 3063 |
221
+ | FCLt | 43.1 | 4508 |
222
+ | ISDNet | 45.8 | 4920 |
223
+ | WSDNet | 46.9 | 4510 |
224
+ | Boosting Dual-branch | 48.2 | 3682 |
225
+ | **CASWiT-Base** |**48.7** | 3530 |
226
+ | **CASWiT-Base-SSL** |**49.1** | 3530 |
227
+
228
+ On URUR, CASWiT-Base already matches and slightly surpasses prior UHR-specific methods, and CASWiT-Base-SSL achieves **49.1 mIoU**, i.e. **+2.2 mIoU** over WSDNet and **+0.9 mIoU** over Boosting Dual-branch (UHRS), while remaining competitive in memory usage.
229
+
230
+
231
+ ## 🔬 Self-Supervised Learning
232
+
233
+ CASWiT also supports self-supervised pre-training via SimMIM-style masked image modeling:
234
+
235
+ ```python
236
+ from model.CASWiT_ssl import CASWiT_SSL
237
+
238
+ model_ssl = CASWiT_SSL(
239
+ model_name="openmmlab/upernet-swin-base",
240
+ mask_ratio_hr=0.75,
241
+ mask_ratio_lr=0.5
242
+ )
243
+ ```
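+
+ Here `mask_ratio_hr` and `mask_ratio_lr` set the fraction of patch tokens replaced by a mask token in the HR and LR branches (see `random_masking_with_tokens` in `model/CASWiT_ssl.py`), so this example masks 75% of HR tokens and 50% of LR tokens.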
244
+
245
+ ## 🛠️ Project Structure
246
+
247
+ ```
248
+ CASWiT/
249
+ ├── model/
250
+ │ ├── CASWiT.py # Main model architecture
251
+ │ └── CASWiT_ssl.py # SSL variant
252
+ ├── dataset/
253
+ │ ├── definition_dataset.py
254
+ │ ├── download_swissimage.py
255
+ │ └── prepareFlairHub.py
256
+ ├── configs/
257
+ │ ├── FlairHub.yaml
258
+ │ ├── URUR.yaml
259
+ │ └── SWISSIMAGE.yaml
260
+ ├── utils/
261
+ │ ├── metrics.py
262
+ │ ├── logging.py
263
+ │ └── attention_viz.py
264
+ ├── train/
265
+ │ ├── train.py
266
+ │ ├── eval.py
267
+ │ └── inference.py
268
+ ├── weights/ # Model checkpoints
269
+ ├── main.py
270
+ ├── requirements.txt
271
+ └── README.md
272
+ ```
273
+
274
+ ## 📝 Citation
275
+
276
+ If you use CASWiT in your research, please cite:
277
+
278
+ ```bibtex
279
+ @article{caswit2025,
280
+ title={CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation},
281
+ author={Masked for instance},
282
+ journal={},
283
+ year={2026}
284
+ }
285
+ ```
286
+
287
+ ## 📄 License
288
+
289
+ This project is licensed under the MIT License - see the LICENSE file for details.
290
+
291
+ ## 🙏 Acknowledgments
292
+
293
+ - [UPerNet](https://github.com/open-mmlab/mmsegmentation) for the base segmentation architecture
294
+ - [Swin Transformer](https://github.com/microsoft/Swin-Transformer) for the backbone
295
+ - [FlairHub](https://github.com/IGNF/FlairHub) for the dataset
296
+ - [URUR](https://github.com/jankyee/URUR) for the dataset
297
+
298
+ ## 📧 Contact
299
+
300
+ For questions and issues, please open an issue on GitHub.
301
+
302
+ ---
303
+
configs/config_FlairHub.yaml ADDED
@@ -0,0 +1,34 @@
1
+ paths:
2
+ data_path: "/mnt/Data/FlairHUB/data_flairhub/output/FLAIR1024_optimal"
3
+ dataset_name: "FLAIRHUB"
4
+ train_img_subdir: "train/img"
5
+ train_msk_subdir: "train/msk"
6
+ val_img_subdir: "valid/img"
7
+ val_msk_subdir: "valid/msk"
8
+ test_img_subdir: "test/img"
9
+ test_msk_subdir: "test/msk"
10
+ save_dir: "weights"
11
+ pretrained_path: "weights/CASWiT-Base-SSL_FLAIRHUB_15classes.pth" #
12
+ model:
13
+ model_name: "openmmlab/upernet-swin-base"
14
+ num_classes: 15
15
+ cross_attention_heads: 1
16
+ ignore_index: 255
17
+ # Cross-fusion options
18
+ fusion_mlp_ratio: 4.0
19
+ fusion_drop_path: 0.1
20
+ lr_supervision_weight: 0.5
21
+ training:
22
+ batch_size: 4
23
+ num_workers: 8
24
+ num_epochs: 20
25
+ learning_rate: 0.00006
26
+ amp: true
27
+ seed: 42
28
+ eta_min: 0.000001
29
+ wandb:
30
+ use_wandb: true
31
+ project: "CASWiT-Base"
32
+ entity: "soloo"
33
+ run_name: "CASWiT-Base_FLAIRHUB_1epoch"
34
+ print_device: true
configs/config_SWISSIMAGE.yaml ADDED
@@ -0,0 +1,34 @@
1
+ paths:
2
+ data_path: "/path/to/swissimage"
3
+ dataset_name: "SWISSIMAGE"
4
+ train_img_subdir: "train/img"
5
+ train_msk_subdir: "train/msk"
6
+ val_img_subdir: "val/img"
7
+ val_msk_subdir: "val/msk"
8
+ test_img_subdir: "test/img"
9
+ test_msk_subdir: "test/msk"
10
+ save_dir: "weights"
11
+ pretrained_path: ""
12
+ model:
13
+ model_name: "openmmlab/upernet-swin-base"
14
+ num_classes: 15
15
+ cross_attention_heads: 1
16
+ ignore_index: 255
17
+ fusion_mlp_ratio: 4.0
18
+ fusion_drop_path: 0.1
19
+ lr_supervision_weight: 0.5
20
+ training:
21
+ batch_size: 4
22
+ num_workers: 8
23
+ num_epochs: 20
24
+ learning_rate: 0.00006
25
+ amp: true
26
+ seed: 1337
27
+ eta_min: 0.000001
28
+ wandb:
29
+ use_wandb: true
30
+ project: "Fusion_HRLR"
31
+ entity: "your_entity"
32
+ run_name: "swissimage_swin_base_swin_base_fusion_end_lr_supervised_1heads"
33
+ print_device: true
34
+
configs/config_URUR.yaml ADDED
@@ -0,0 +1,35 @@
1
+ paths:
2
+ data_path: "/mnt/Data/URUR"
3
+ dataset_name: "URUR"
4
+ train_img_subdir: "train/image"
5
+ train_msk_subdir: "train/label"
6
+ val_img_subdir: "val/image"
7
+ val_msk_subdir: "val/label"
8
+ test_img_subdir: "test/image"
9
+ test_msk_subdir: "test/label"
10
+ save_dir: "weights"
11
+ #pretrained_path: "weights/URUR_Last_8classes_alltrain_fusionend_hrlr_swinbase_swinbase_lrsupervised_pos_bias_none_12_epoch_15_head1.pth" # weights/fusion_hrlr_swintiny_swintiny_pos_bias_none_4_epoch_14_head1.pth
12
+ pretrained_path: "weights/from_mae_URUR_Lastv2_8classes_alltrain_fusionall_hrlr_swinbase_swinbase_lrsupervised_pos_bias_none_5_epoch_15_head1.pth"
13
+ model:
14
+ model_name: "openmmlab/upernet-swin-base"
15
+ num_classes: 8
16
+ cross_attention_heads: 1
17
+ ignore_index: 255
18
+ # Cross-fusion options
19
+ fusion_mlp_ratio: 4.0
20
+ fusion_drop_path: 0.1
21
+ lr_supervision_weight: 0.5
22
+ training:
23
+ batch_size: 5
24
+ num_workers: 8
25
+ num_epochs: 20
26
+ learning_rate: 0.00006
27
+ amp: true
28
+ seed: 1337
29
+ eta_min: 0.000001
30
+ wandb:
31
+ use_wandb: true
32
+ project: "Fusion_HRLR"
33
+ entity: "soloo"
34
+ run_name: "URUR_swin_base_swin_base_fusion_end_lr_supervised_1heads"
35
+ print_device: true
dataset/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """
2
+ Dataset loaders and utilities.
3
+ """
4
+
5
+ from dataset.definition_dataset import (
6
+ SemanticSegmentationDatasetFusion,
7
+ SemanticSegmentationDatasetHR,
8
+ build_transforms
9
+ )
10
+
11
+ __all__ = [
12
+ 'SemanticSegmentationDatasetFusion',
13
+ 'SemanticSegmentationDatasetHR',
14
+ 'build_transforms',
15
+ ]
16
+
dataset/definition_dataset.py ADDED
@@ -0,0 +1,402 @@
1
+ """
2
+ Dataset definitions for CASWiT training and evaluation.
3
+
4
+ This module provides dataset classes for semantic segmentation with
5
+ HR/LR dual-branch processing.
6
+ """
7
+
8
+ import os
9
+ import math
10
+ from pathlib import Path
11
+ from typing import Optional, Union, Tuple, List, Dict
12
+ import numpy as np
13
+ import torch
14
+ from torch import Tensor
15
+ from torch.utils.data import Dataset
16
+ from PIL import Image
17
+ try:
+     from tifffile import imread as tiff_imread
+ except ImportError:
+     # tifffile is optional; the loaders below fall back to PIL
+     tiff_imread = None
18
+ from torchvision import transforms
19
+
20
+
21
+ class SemanticSegmentationDatasetFusion(Dataset):
22
+ """
23
+ Dataset for HR/LR fusion training on FLAIRHub.
24
+
25
+ Returns (image_hr, mask_hr, image_lr, mask_lr):
26
+ - image_hr: 512x512 crop starting at (256, 256)
27
+ - image_lr: full image downsampled by factor 2
28
+ - mask >=15 replaced by 255 (ignore)
29
+ - transforms applied to images (ToTensor + Normalize) and mask -> LongTensor
30
+ """
31
+ def __init__(self, image_dir: Path, mask_dir: Path, transform: Optional[transforms.Compose] = None):
32
+ self.image_dir = Path(image_dir)
33
+ self.mask_dir = Path(mask_dir)
34
+ self.image_filenames = sorted(os.listdir(self.image_dir))
35
+ self.mask_filenames = sorted(os.listdir(self.mask_dir))
36
+ assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"
37
+ self.transform = transform
38
+
39
+ def __len__(self):
40
+ return len(self.image_filenames)
41
+
42
+ def __getitem__(self, idx):
43
+ image_path = self.image_dir / self.image_filenames[idx]
44
+ mask_path = self.mask_dir / self.mask_filenames[idx]
45
+
46
+ image = load_image(image_path)
47
+ mask = load_mask(mask_path)
48
+ mask[mask >= 15] = 255
49
+
50
+ # Crop HR at 512x512
51
+ hr_crop_size = 512
52
+ crop_x, crop_y = 256, 256
53
+
54
+ image_hr = image[crop_x:crop_x + hr_crop_size, crop_y:crop_y + hr_crop_size]
55
+ mask_hr = mask[crop_x:crop_x + hr_crop_size, crop_y:crop_y + hr_crop_size]
56
+
57
+ # Downsample LR
58
+ image_lr = image[::2,::2,:]
59
+ mask_lr = mask[::2,::2]
60
+
61
+ if self.transform:
62
+ image_hr = self.transform(to_pil_uint8(image_hr))
63
+ image_lr = self.transform(to_pil_uint8(image_lr))
64
+ else:
65
+ image_hr = to_tensor_img(image_hr)
66
+ image_lr = to_tensor_img(image_lr)
67
+ mask_hr = torch.tensor(mask_hr, dtype=torch.long)
68
+ mask_lr = torch.tensor(mask_lr, dtype=torch.long)
69
+
70
+ return image_hr, mask_hr, image_lr, mask_lr
71
+
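+ # Illustrative usage (paths are placeholders):
+ #   ds = SemanticSegmentationDatasetFusion("train/img", "train/msk", transform=build_transforms())
+ #   image_hr, mask_hr, image_lr, mask_lr = ds[0]  # 512x512 HR crop + 2x-downsampled context
+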
72
+
73
+ class SemanticSegmentationDatasetHR(Dataset):
74
+ """
75
+ Dataset for HR-only training (single branch, no LR).
76
+
77
+ Returns (image_hr, mask_hr):
78
+ - image_hr: 512x512 crop starting at (256, 256)
79
+ - mask >=15 replaced by 255 (ignore)
80
+ """
81
+ def __init__(self, image_dir: Path, mask_dir: Path, transform: Optional[transforms.Compose] = None):
82
+ self.image_dir = Path(image_dir)
83
+ self.mask_dir = Path(mask_dir)
84
+ self.image_filenames = sorted(os.listdir(self.image_dir))
85
+ self.mask_filenames = sorted(os.listdir(self.mask_dir))
86
+ assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"
87
+ self.transform = transform
88
+
89
+ def __len__(self):
90
+ return len(self.image_filenames)
91
+
92
+ def __getitem__(self, idx):
93
+ image_path = self.image_dir / self.image_filenames[idx]
94
+ mask_path = self.mask_dir / self.mask_filenames[idx]
95
+
96
+ image = load_image(image_path)
97
+ mask = load_mask(mask_path)
98
+ mask[mask >= 15] = 255
99
+
100
+ crop_x, crop_y = 256, 256
101
+ image_hr = image[crop_x:crop_x + 512, crop_y:crop_y + 512]
102
+ mask_hr = mask[crop_x:crop_x + 512, crop_y:crop_y + 512]
103
+
104
+ if self.transform:
105
+ image_hr = self.transform(to_pil_uint8(image_hr))
106
+ else:
107
+ image_hr = to_tensor_img(image_hr)
108
+ mask_hr = torch.tensor(mask_hr, dtype=torch.long)
109
+ return image_hr, mask_hr
110
+
111
+
112
+ # ----------------------------
113
+ # Image loading functions for sliding windows without overlap on URUR, DeepGlobe and INRIA
114
+ # ----------------------------
115
+
116
+ def load_image(path: Union[str, Path]) -> np.ndarray:
117
+ """
118
+ Load an image as HxWx3 (RGB), float32 [0,1].
119
+ Handles both TIFF and PNG files gracefully.
120
+ """
121
+ p = str(path)
122
+ arr = None
123
+
124
+ # 1) Try TIFF first if available
125
+ if tiff_imread is not None:
126
+ try:
127
+ arr = tiff_imread(p)
128
+ except Exception:
129
+ arr = None
130
+
131
+ # 2) Fallback to PIL
132
+ if arr is None:
133
+ with Image.open(p) as im:
134
+ arr = np.array(im.convert("RGB")) # HWC uint8
135
+
136
+ # Ensure HWC format
137
+ if arr.ndim == 2:
138
+ arr = np.stack((arr, arr, arr), axis=-1) # HWC
139
+ elif arr.ndim == 3 and arr.shape[0] in (3, 4) and arr.shape[-1] not in (3, 4):
140
+ arr = np.moveaxis(arr, 0, -1) # CHW -> HWC
141
+
142
+ # Keep 3 channels
143
+ c = arr.shape[-1]
144
+ if c == 4:
145
+ arr = arr[..., :3]
146
+ elif c == 1:
147
+ arr = np.repeat(arr, 3, axis=-1)
148
+
149
+ # Normalize -> float32 [0,1]
150
+ if arr.dtype == np.uint8:
151
+ arr = arr.astype(np.float32) / 255.0
152
+ else:
153
+ arr = arr.astype(np.float32, copy=False)
154
+ m = arr.max()
155
+ if m > 1.0:
156
+ arr = arr / m
157
+
158
+ return arr # float32 HWC in [0,1]
159
+
160
+
161
+ def load_mask(path: Union[str, Path]) -> np.ndarray:
162
+ """Load a mask as HxW int64 (labels). Handles both TIFF and PNG files."""
163
+ p = str(path)
164
+ m = None
165
+ if tiff_imread is not None:
166
+ try:
167
+ m = tiff_imread(p)
168
+ except Exception:
169
+ m = None
170
+ if m is None:
171
+ with Image.open(p) as im:
172
+ m = np.array(im)
173
+
174
+ # Force 2D
175
+ if m.ndim == 3:
176
+ m = m[..., 0]
177
+ return m.astype(np.int64, copy=False)
178
+
179
+
180
+ # ----------------------------
181
+ # Helper Functions
182
+ # ----------------------------
183
+
184
+ def crop_with_pad(img: np.ndarray, y0: int, x0: int, h: int, w: int, pad_val=0) -> np.ndarray:
185
+ """Extract a crop HxW with padding if necessary (img HxW[ xC])."""
186
+ H, W = img.shape[:2]
187
+ y1, x1 = y0 + h, x0 + w
188
+
189
+ pad_top = max(0, -y0); ys = max(0, y0)
190
+ pad_left = max(0, -x0); xs = max(0, x0)
191
+ pad_bot = max(0, y1 - H); ye = min(H, y1)
192
+ pad_right = max(0, x1 - W); xe = min(W, x1)
193
+
194
+ sl = img[ys:ye, xs:xe]
195
+ pad_cfg = ((pad_top, pad_bot), (pad_left, pad_right)) + (((0, 0),) if img.ndim == 3 else ())
196
+ return np.pad(sl, pad_cfg, mode="constant", constant_values=pad_val)
197
+
198
+
199
+ def resize_np_img(img_hwc_float01: np.ndarray, size_hw: Tuple[int, int]) -> np.ndarray:
200
+ """Resize HWC float32[0,1] -> HWC float32[0,1] using bilinear interpolation."""
201
+ Ht, Wt = size_hw
202
+ im = Image.fromarray((np.clip(img_hwc_float01, 0.0, 1.0) * 255.0).astype(np.uint8))
203
+ im = im.resize((Wt, Ht), resample=Image.BILINEAR)
204
+ out = np.asarray(im, dtype=np.uint8).astype(np.float32) / 255.0
205
+ if out.ndim == 2:
206
+ out = np.stack((out, out, out), axis=-1)
207
+ return out
208
+
209
+
210
+ def resize_np_mask(mask_hw_int: np.ndarray, size_hw: Tuple[int, int]) -> np.ndarray:
211
+ """Resize mask HW using nearest neighbor via PIL, output int64."""
212
+ Ht, Wt = size_hw
213
+ mask = np.ascontiguousarray(mask_hw_int)
214
+ if mask.ndim != 2:
215
+ raise ValueError(f"resize_np_mask expects 2D mask HW, received shape={mask.shape}")
216
+
217
+ dt = mask.dtype
218
+ if dt in (np.int64, np.int32, np.int16, np.int8, np.uint16):
219
+ pil_arr, pil_mode = mask.astype(np.int32, copy=False), "I"
220
+ elif dt == np.uint8:
221
+ pil_arr, pil_mode = mask, "L"
222
+ else:
223
+ pil_arr, pil_mode = mask.astype(np.int32, copy=False), "I"
224
+
225
+ im = Image.fromarray(pil_arr, mode=pil_mode).resize((Wt, Ht), resample=Image.NEAREST)
226
+ return np.asarray(im).astype(np.int64, copy=False)
227
+
228
+
229
+ def to_tensor_img(x: np.ndarray) -> Tensor:
230
+ """Convert HWC float32[0,1] -> CHW float32[0,1]."""
231
+ return torch.from_numpy(np.transpose(x, (2, 0, 1)).copy())
232
+
233
+
234
+ def to_pil_uint8(img_float01_hwc: np.ndarray) -> Image.Image:
235
+ """Convert HWC float32[0,1] -> PIL RGB uint8."""
236
+ arr = (np.clip(img_float01_hwc, 0.0, 1.0) * 255.0).round().astype(np.uint8)
237
+ return Image.fromarray(arr, mode="RGB")
238
+
239
+
240
+ # ----------------------------
241
+ # Dataset Classes
242
+ # ----------------------------
243
+
244
+ class URURHRLRDataset(Dataset):
245
+ """
246
+ URUR dataset with HR/LR dual-branch processing and tiling support.
247
+
248
+ In val/test mode, each worker caches the current image+mask in RAM for all tiles
249
+ of the same image to avoid re-reading for each tile.
250
+ """
251
+ def __init__(
252
+ self,
253
+ image_dir: Union[str, Path],
254
+ mask_dir: Union[str, Path],
255
+ num_classes: int,
256
+ mode: str = "train",
257
+ ignore_index: int = 255,
258
+ hr_size: int = 1024,
259
+ lr_side: int = 2048,
260
+ transform: Optional[transforms.Compose] = None,
261
+ limit: Optional[int] = None
262
+ ) -> None:
263
+ assert mode in {"train", "val", "test"}
264
+ self.image_dir = Path(image_dir)
265
+ self.mask_dir = Path(mask_dir)
266
+ self.mode = mode
267
+ self.num_classes = int(num_classes)
268
+ self.ignore_index = int(ignore_index)
269
+ self.HR = int(hr_size)
270
+ self.LR_WIN = int(lr_side)
271
+ self.transform = transform
272
+
273
+ imgs = sorted([p for p in self.image_dir.iterdir() if p.is_file()])
274
+ msks = sorted([p for p in self.mask_dir.iterdir() if p.is_file()])
275
+ if limit is not None:
276
+ imgs, msks = imgs[:limit], msks[:limit]
277
+ assert len(imgs) == len(msks) and len(imgs) > 0, "Images/Masks missing or misaligned"
278
+
279
+ self.images: List[Path] = imgs
280
+ self.masks: List[Path] = msks
281
+
282
+ # Tile index + quick sizes (without loading full image)
283
+ self._test_index: List[Tuple[int, int, int]] = [] # (img_id, y0, x0)
284
+ self._sizes: List[Tuple[int, int]] = [] # (H, W) per image
285
+
286
+ if self.mode == "test" or self.mode == "val":
287
+ for img_id, ip in enumerate(self.images):
288
+ with Image.open(ip) as im:
289
+ W, H = im.size
290
+ self._sizes.append((H, W))
291
+
292
+ n_ty = math.ceil(H / self.HR)
293
+ n_tx = math.ceil(W / self.HR)
294
+ for iy in range(n_ty):
295
+ for ix in range(n_tx):
296
+ self._test_index.append((img_id, iy * self.HR, ix * self.HR))
297
+
298
+ # Per-worker cache (used in the val/test tiled modes)
299
+ self._cache_img_id: Optional[int] = None
300
+ self._cache_img: Optional[np.ndarray] = None # float32 HWC [0,1]
301
+ self._cache_msk: Optional[np.ndarray] = None # int64 HW
302
+
303
+ def __len__(self) -> int:
304
+ # Val/test iterate over precomputed tiles; train iterates over whole images
+ return len(self._test_index) if self.mode in ("test", "val") else len(self.images)
305
+
306
+ def _extract_pair_np(
307
+ self, img: np.ndarray, msk: np.ndarray, y0: int, x0: int
308
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
309
+ """Extract HR/LR pair in numpy (images HWC float32[0,1], masks HW int64)."""
310
+ # HR
311
+ img_hr = crop_with_pad(img, y0, x0, self.HR, self.HR, pad_val=0)
312
+ msk_hr = crop_with_pad(msk, y0, x0, self.HR, self.HR, pad_val=self.ignore_index)
313
+ # Resize HR to half size
314
+ img_hr = resize_np_img(img_hr, (self.HR//2, self.HR//2))
315
+
316
+ # LR centered on HR
317
+ cy, cx = y0 + self.HR // 2, x0 + self.HR // 2
318
+ half = self.LR_WIN // 2
319
+ img_lr = crop_with_pad(img, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=0)
320
+ msk_lr = crop_with_pad(msk, cy - half, cx - half, self.LR_WIN, self.LR_WIN, pad_val=self.ignore_index)
321
+
322
+ # Downsample LR -> HR
323
+ img_lr_512 = resize_np_img(img_lr, (self.HR//2, self.HR//2))
324
+ msk_lr_512 = resize_np_mask(msk_lr, (self.HR, self.HR))
325
+
326
+ # Clamp out of range -> ignore_index
327
+ msk_hr = msk_hr.astype(np.int64, copy=False)
328
+ msk_lr_512 = msk_lr_512.astype(np.int64, copy=False)
329
+ msk_hr[msk_hr >= self.num_classes] = self.ignore_index
330
+ msk_lr_512[msk_lr_512 >= self.num_classes] = self.ignore_index
331
+
332
+ return img_hr, msk_hr, img_lr_512, msk_lr_512
333
+
334
+ def __getitem__(self, idx: int):
335
+ if self.mode == "test" or self.mode == "val":
336
+ # Tile index
337
+ img_id, y0, x0 = self._test_index[idx]
338
+ ip, mp = self.images[img_id], self.masks[img_id]
339
+ H, W = self._sizes[img_id]
340
+
341
+ # Cache per worker: read/convert only once per image
342
+ if self._cache_img_id != img_id:
343
+ self._cache_img = load_image(ip) # float32 HWC [0,1]
344
+ self._cache_msk = load_mask(mp) # int64 HW
345
+ self._cache_img_id = img_id
346
+
347
+ img = self._cache_img
348
+ msk = self._cache_msk
349
+
350
+ img_hr_np, msk_hr_np, img_lr_np, msk_lr_np = self._extract_pair_np(img, msk, y0, x0)
351
+
352
+ if self.transform:
353
+ image_hr = self.transform(to_pil_uint8(img_hr_np))
354
+ image_lr = self.transform(to_pil_uint8(img_lr_np))
355
+ else:
356
+ image_hr = to_tensor_img(img_hr_np)
357
+ image_lr = to_tensor_img(img_lr_np)
358
+
359
+ mask_hr = torch.as_tensor(msk_hr_np, dtype=torch.long)
360
+ mask_lr = torch.as_tensor(msk_lr_np, dtype=torch.long)
361
+
362
+ meta: Dict[str, object] = {
363
+ "img_path": str(ip),
364
+ "mask_path": str(mp),
365
+ "tile": (int(y0), int(x0), self.HR, self.HR),
366
+ "img_hw": (int(H), int(W)),
367
+ "tile_index": int(idx),
368
+ }
369
+ return image_hr, mask_hr, image_lr, mask_lr, meta
370
+
371
+ # train mode
372
+ ip, mp = self.images[idx], self.masks[idx]
373
+ img = load_image(ip)
374
+ msk = load_mask(mp)
375
+ H, W = img.shape[:2]
376
+
377
+ y0 = 0 if H <= self.HR else np.random.randint(0, H - self.HR + 1)
378
+ x0 = 0 if W <= self.HR else np.random.randint(0, W - self.HR + 1)
379
+
380
+ img_hr_np, msk_hr_np, img_lr_np, msk_lr_np = self._extract_pair_np(img, msk, y0, x0)
381
+
382
+ if self.transform:
383
+ image_hr = self.transform(to_pil_uint8(img_hr_np))
384
+ image_lr = self.transform(to_pil_uint8(img_lr_np))
385
+ else:
386
+ image_hr = to_tensor_img(img_hr_np)
387
+ image_lr = to_tensor_img(img_lr_np)
388
+
389
+ mask_hr = torch.as_tensor(msk_hr_np, dtype=torch.long)
390
+ mask_lr = torch.as_tensor(msk_lr_np, dtype=torch.long)
391
+
392
+ meta: Dict[str, object] = {"img_path": str(ip), "mask_path": str(mp)}
393
+ return image_hr, mask_hr, image_lr, mask_lr, meta
394
+
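+ # Illustrative usage (directory layout as in the README):
+ #   ds = URURHRLRDataset("URUR/val/image", "URUR/val/label", num_classes=8,
+ #                        mode="val", transform=build_transforms())
+ #   image_hr, mask_hr, image_lr, mask_lr, meta = ds[0]  # one HR tile + its LR context window
+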
395
+
396
+ def build_transforms():
397
+ """Build standard transforms with normalization (mean=std=0.5)."""
398
+ return transforms.Compose([
399
+ transforms.ToTensor(),
400
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
401
+ ])
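+ # Note: Normalize(mean=std=0.5) maps ToTensor's [0, 1] output into [-1, 1].
+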
402
+
dataset/download_swissimage.py ADDED
@@ -0,0 +1,67 @@
1
+ """
2
+ SwissImage dataset downloader.
3
+
4
+ Downloads SwissImage dataset from URLs provided in a CSV file.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import pandas as pd
10
+ import requests
11
+
12
+
13
+ def download_data(url: str, destination: str, count: int) -> int:
14
+ """
15
+ Download a single image from URL.
16
+
17
+ Args:
18
+ url: Image URL to download
19
+ destination: Local file path to save
20
+ count: Current download count
21
+
22
+ Returns:
23
+ Updated download count
24
+ """
25
+ response = requests.get(url, timeout=60)  # timeout avoids hanging on stalled connections
26
+ if response.status_code == 200:
27
+ with open(destination, 'wb') as f:
28
+ f.write(response.content)
29
+ count += 1
30
+ else:
31
+ print(f"Failed to download: {url}. Status: {response.status_code}")
32
+ with open('failed_images.txt', 'a') as f:
33
+ f.write(url + '\n')
34
+ return count
35
+
36
+
37
+ def main():
38
+ """Main function to download SwissImage dataset."""
39
+ if len(sys.argv) < 2:
40
+ print("Usage: python download_swissimage.py <csv_file>")
41
+ sys.exit(1)
42
+
43
+ csv_file = sys.argv[1]
44
+ df = pd.read_csv(csv_file, header=None)
45
+ count_download = 0
46
+ count = 0
47
+ total = len(df)
48
+ print(f'Downloading {total} images.')
49
+
50
+ if not os.path.exists('data'):
51
+ os.mkdir('data')
52
+
53
+ for row in df.itertuples():
54
+ download_link = row[1]
55
+ count += 1
56
+ if count % 10 == 0:
57
+ print(f'Progress: {count/total*100:.1f}%')
58
+ fn = download_link.split('/')[-1]
59
+ fn_local = os.path.join('data', fn)
60
+ count_download = download_data(download_link, fn_local, count_download)
61
+
62
+ print(f'Process finished with {count_download} images downloaded out of {total} planned')
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
67
+
dataset/prepareFlairHub.py ADDED
@@ -0,0 +1,270 @@
1
+ """
2
+ FlairHub dataset preparation script.
3
+
4
+ Merges GeoTIFF tiles from FlairHub into larger mosaics for training.
5
+ This script processes the hierarchical folder structure and merges
6
+ neighboring tiles into 2x2 mosaics.
7
+ """
8
+
9
+ import rasterio
10
+ from rasterio.windows import Window
11
+ import os
12
+ from affine import Affine
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+ from concurrent.futures import ProcessPoolExecutor
16
+
17
+
18
+ def get_raster_info(filepath):
19
+ """
20
+ Get raster information from a GeoTIFF file.
21
+
22
+ Args:
23
+ filepath: Path to the GeoTIFF file.
24
+
25
+ Returns:
26
+ tuple: (data, transform, profile) where:
27
+ - data: numpy array of raster data
28
+ - transform: Affine transform object
29
+ - profile: Dictionary of raster metadata
30
+ """
31
+ with rasterio.open(filepath) as src:
32
+ data = src.read()
33
+ transform = src.transform
34
+ profile = src.profile
35
+ return data, transform, profile
36
+
37
+
38
+ def is_near(value1, value2, tolerance=0.5):
39
+ """Check if two values are within a tolerance range."""
40
+ return abs(value1 - value2) <= tolerance
41
+
42
+
43
+ def find_neighboring_files(reference_file, corner_dict):
44
+ """
45
+ Find neighboring GeoTIFF files based on corner coordinates.
46
+
47
+ Args:
48
+ reference_file: Path to the reference GeoTIFF file.
49
+ corner_dict: Dictionary mapping filenames to corner coordinates.
50
+
51
+ Returns:
52
+ dict: Dictionary with keys ['right', 'bottom_right', 'bottom', 'bottom_left',
53
+ 'left', 'top_left', 'top', 'top_right'] containing paths to neighboring files.
54
+ """
55
+ neighbors = {
56
+ 'right': None, 'bottom_right': None, 'bottom': None, 'bottom_left': None,
57
+ 'left': None, 'top_left': None, 'top': None, 'top_right': None
58
+ }
59
+ reference_basename = os.path.basename(reference_file)
60
+ if reference_basename not in corner_dict:
61
+ return neighbors
62
+
63
+ reference_min_x, reference_min_y, reference_max_x, reference_max_y = corner_dict[reference_basename]
64
+ height = reference_max_y - reference_min_y
65
+ width = reference_max_x - reference_min_x
66
+
67
+ for filename, corners in corner_dict.items():
68
+ if filename == reference_basename:
69
+ continue
70
+
71
+ min_x, min_y, max_x, max_y = corners
72
+
73
+ # Check for right neighbor
74
+ if is_near(min_x, reference_max_x) and is_near(min_y, reference_min_y):
75
+ neighbors['right'] = os.path.join(os.path.dirname(reference_file), filename)
76
+
77
+ # Check for bottom_right neighbor
78
+ if is_near(min_x, reference_max_x) and is_near(min_y, (reference_min_y - height)):
79
+ neighbors['bottom_right'] = os.path.join(os.path.dirname(reference_file), filename)
80
+
81
+ # Check for bottom neighbor
82
+ if is_near(min_x, reference_min_x) and is_near(min_y, (reference_min_y - height)):
83
+ neighbors['bottom'] = os.path.join(os.path.dirname(reference_file), filename)
84
+
85
+ # Check for bottom_left neighbor
86
+ if is_near(min_x, reference_min_x - width) and is_near(min_y, (reference_min_y - height)):
87
+ neighbors['bottom_left'] = os.path.join(os.path.dirname(reference_file), filename)
88
+
89
+ # Check for left neighbor
90
+ if is_near(min_x, reference_min_x - width) and is_near(min_y, reference_min_y):
91
+ neighbors['left'] = os.path.join(os.path.dirname(reference_file), filename)
92
+
93
+ # Check for top_left neighbor
94
+ if is_near(min_x, reference_min_x - width) and is_near(min_y, reference_max_y):
95
+ neighbors['top_left'] = os.path.join(os.path.dirname(reference_file), filename)
96
+
97
+ # Check for top neighbor
98
+ if is_near(min_x, reference_min_x) and is_near(min_y, reference_max_y):
99
+ neighbors['top'] = os.path.join(os.path.dirname(reference_file), filename)
100
+
101
+ # Check for top_right neighbor
102
+ if is_near(min_x, reference_max_x) and is_near(min_y, reference_max_y):
103
+ neighbors['top_right'] = os.path.join(os.path.dirname(reference_file), filename)
104
+
105
+ return neighbors
106
+
107
+
108
+ def create_black_tile_like(reference_data, height, width):
109
+ """Create a black tile with the same properties as the reference data."""
110
+ count = reference_data.shape[0]
111
+ return np.zeros((count, height, width), dtype=reference_data.dtype)
112
+
113
+
114
+ def merge_geotiffs(reference_filepath, neighbors, output_filepath):
115
+ """Merge a reference GeoTIFF with its neighboring tiles into a larger mosaic."""
116
+ reference_data, _, profile = get_raster_info(reference_filepath)
117
+ _, reference_height, reference_width = reference_data.shape
118
+
119
+ output_width = 2 * reference_width
120
+ output_height = 2 * reference_height
121
+
122
+ _, top_transform, _ = get_raster_info(neighbors.get('top') or reference_filepath)
123
+
124
+ new_origin_x = top_transform.c - int(reference_width / 2)
125
+ new_origin_y = top_transform.f + int(reference_height / 2)
126
+
127
+ new_transform = Affine(
128
+ top_transform.a, top_transform.b, new_origin_x,
129
+ top_transform.d, top_transform.e, new_origin_y
130
+ )
131
+
132
+ with rasterio.open(
133
+ output_filepath, 'w', driver=profile['driver'], height=output_height, width=output_width,
134
+ count=profile['count'], dtype=profile['dtype'], transform=new_transform
135
+ ) as dst:
136
+ ref_window_offset_col = reference_width // 2
137
+ ref_window_offset_row = reference_height // 2
138
+ dst.write(reference_data, window=Window(
139
+ col_off=ref_window_offset_col,
140
+ row_off=ref_window_offset_row,
141
+ width=reference_width,
142
+ height=reference_height
143
+ ))
144
+
145
+ tile_layout = {
146
+ 'top_left': (0, 0, (reference_height//2, reference_width//2),
147
+ (slice(reference_height//2, None), slice(reference_width//2, None))),
148
+ 'top': (reference_width//2, 0, (reference_height//2, reference_width),
149
+ (slice(reference_height//2, None), slice(None))),
150
+ 'top_right': (3*reference_width//2, 0, (reference_height//2, reference_width//2),
151
+ (slice(reference_height//2, None), slice(0, reference_width//2))),
152
+ 'left': (0, reference_height//2, (reference_height, reference_width//2),
153
+ (slice(None), slice(reference_width//2, None))),
154
+ 'right': (3*reference_width//2, reference_height//2, (reference_height, reference_width//2),
155
+ (slice(None), slice(0, reference_width//2))),
156
+ 'bottom_left': (0, 3*reference_height//2, (reference_height//2, reference_width//2),
157
+ (slice(0, reference_height//2), slice(reference_width//2, None))),
158
+ 'bottom': (reference_width//2, 3*reference_height//2, (reference_height//2, reference_width),
159
+ (slice(0, reference_height//2), slice(None))),
160
+ 'bottom_right': (3*reference_width//2, 3*reference_height//2, (reference_height//2, reference_width//2),
161
+ (slice(0, reference_height//2), slice(0, reference_width//2))),
162
+ }
163
+
164
+ for direction, (offset_col, offset_row, (h, w), slicing) in tile_layout.items():
165
+ if neighbors[direction]:
166
+ neighbor_data, _, _ = get_raster_info(neighbors[direction])
167
+ neighbor_crop = neighbor_data[:, slicing[0], slicing[1]]
168
+ else:
169
+ neighbor_crop = create_black_tile_like(reference_data, h, w)
170
+
171
+ dst.write(neighbor_crop, window=Window(offset_col, offset_row, w, h))
172
+
173
+
174
+ def get_corner_coordinates(filepath):
175
+ """Get the corner coordinates of a GeoTIFF file."""
176
+ with rasterio.open(filepath) as src:
177
+ transform = src.transform
178
+ res_x, _, min_x, _, res_y, max_y, _, _, _ = transform
179
+ width, height = src.width, src.height
180
+ min_y = max_y + res_y * height
181
+ max_x = min_x + res_x * width
182
+ return min_x, min_y, max_x, max_y
183
+
184
+
185
+ def save_corner_coordinates(filepaths):
186
+ """Save corner coordinates for a list of GeoTIFF files."""
187
+ corners = {}
188
+ for filepath in filepaths:
189
+ basename = os.path.basename(filepath)
190
+ min_x, min_y, max_x, max_y = get_corner_coordinates(filepath)
191
+ corners[basename] = (min_x, min_y, max_x, max_y)
192
+ return corners
193
+
194
+
195
+ def find_lower_east_file(filepaths, corner_dict):
196
+ """Find the file with the lowest Y and easternmost X coordinates."""
197
+ lower_east_file = None
198
+ lowest_y = float('inf')
199
+ lowest_x = float('inf')
200
+
201
+ for filepath in filepaths:
202
+ basename = os.path.basename(filepath)
203
+ min_x, min_y, _, _ = corner_dict[basename]
204
+ if min_y < lowest_y or (min_y == lowest_y and min_x < lowest_x):
205
+ lowest_y = min_y
206
+ lowest_x = min_x
207
+ lower_east_file = filepath
208
+
209
+ return lower_east_file
210
+
211
+
212
+ def merge_files_from_folder(folder_path_in, folder_path_out):
213
+ """Merge all GeoTIFF files in a folder into larger mosaics."""
214
+ path_parts = folder_path_in.split(os.sep)
215
+ d_folder = next((part for part in path_parts if part.startswith('D')), '')
216
+ z_folder = next((part for part in path_parts if part.startswith('Z')), '')
217
+
218
+ filepaths = [os.path.join(folder_path_in, f) for f in os.listdir(folder_path_in)
219
+ if f.lower().endswith(".tif")]
220
+
221
+ corner_dict = save_corner_coordinates(filepaths)
222
+
223
+ while filepaths:
224
+ reference_filepath = find_lower_east_file(filepaths, corner_dict)
225
+ filepaths.remove(reference_filepath)
226
+
227
+ neighbors = find_neighboring_files(reference_filepath, corner_dict)
228
+ new_basename = f"{d_folder}_{z_folder}_{os.path.basename(reference_filepath)}"
229
+ output_filepath = os.path.join(folder_path_out, new_basename)
230
+ merge_geotiffs(reference_filepath, neighbors, output_filepath)
231
+
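+ # Illustrative call (paths are placeholders):
+ #   merge_files_from_folder("/data/train/img/D001/Z1_AA", "/out/train/img")
+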
232
+
233
+ def list_folders(path):
234
+ """List all folders in a directory."""
235
+ return [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))]
236
+
237
+
238
+ def process_d_folder(args):
239
+ """Process a single D folder."""
240
+ split, modality, d_folder, data_dir, merged_dir = args
241
+ results = []
242
+ for z_folder in list_folders(os.path.join(data_dir, split, modality, d_folder)):
243
+ folder_in = os.path.join(data_dir, split, modality, d_folder, z_folder)
244
+ folder_out = os.path.join(merged_dir, split, modality)
245
+ os.makedirs(folder_out, exist_ok=True)
246
+ merge_files_from_folder(folder_in, folder_out)
247
+ results.append((folder_in, folder_out))
248
+ return results
249
+
250
+
251
+ if __name__ == "__main__":
252
+ data_dir = "/mnt/Data/FlairHUB/data_flairhub/output/"
253
+ merged_dir = "/mnt/Data/FlairHUB/data_flairhub/output/FLAIR1024_optimal"
254
+ os.makedirs(merged_dir, exist_ok=True)
255
+
256
+ for split in ["train", "valid", "test"]:
257
+ print(f"📂 Split: {split}")
258
+ for modality in ["img", "msk"]:
259
+ print(f"🏞️ Modality: {modality}")
260
+ d_folders = [f for f in list_folders(os.path.join(data_dir, split, modality))
261
+ if f.startswith("D")]
262
+
263
+ args_list = [(split, modality, d, data_dir, merged_dir) for d in d_folders]
264
+
265
+ with ProcessPoolExecutor(max_workers=16) as executor:
266
+ list(tqdm(executor.map(process_d_folder, args_list), total=len(args_list),
267
+ desc="Processing D-folders"))
268
+
269
+ print("📝 Done!")
270
+
list_all_swiss_image_sept2025.csv ADDED
The diff for this file is too large to render.
 
main.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Main entry point for CASWiT training and evaluation.
3
+
4
+ This script provides a unified interface for training, evaluation, and inference.
5
+ """
6
+
7
+ import argparse
8
+ import sys
10
+ from pathlib import Path
11
+
12
+ # Add project root to path
13
+ sys.path.insert(0, str(Path(__file__).parent))
14
+
15
+ from train.train import main as train_main, load_config
16
+ from train.eval import evaluate_model
17
+ from train.inference import inference_single_image
18
+ from model.CASWiT import CASWiT
19
+ from dataset.definition_dataset import build_transforms
20
+
21
+
22
+ def main():
23
+ """Main entry point."""
24
+ parser = argparse.ArgumentParser(description="CASWiT: Context-Aware Swin Transformer")
25
+ parser.add_argument("mode", choices=["train", "eval", "inference"],
26
+ help="Mode: train, eval, or inference")
27
+ parser.add_argument("--config", type=str, required=True,
28
+ help="Path to config YAML file")
29
+ parser.add_argument("--checkpoint", type=str, default="",
30
+ help="Path to model checkpoint (for eval/inference)")
31
+ parser.add_argument("--image", type=str, default="",
32
+ help="Path to input image (for inference)")
33
+ parser.add_argument("--output", type=str, default="prediction.png",
34
+ help="Path to save output (for inference)")
35
+ parser.add_argument("--split", type=str, default="test", choices=["test", "val"],
36
+ help="Dataset split for evaluation")
37
+
38
+ args = parser.parse_args()
39
+
40
+ if args.mode == "train":
41
+ train_main(args.config)
42
+ elif args.mode == "eval":
43
+ if not args.checkpoint:
44
+ print("Error: --checkpoint required for evaluation")
45
+ sys.exit(1)
46
+ cfg = load_config(args.config)
47
+ evaluate_model(cfg, args.checkpoint, args.split)
48
+ elif args.mode == "inference":
49
+ if not args.checkpoint or not args.image:
50
+ print("Error: --checkpoint and --image required for inference")
51
+ sys.exit(1)
52
+ import torch
53
+
54
+ cfg = load_config(args.config)
55
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+
57
+ # Validate checkpoint path
58
+ checkpoint_path_obj = Path(args.checkpoint)
59
+ if not checkpoint_path_obj.exists() or not checkpoint_path_obj.is_file():
60
+ print(f"Error: Checkpoint file not found: {args.checkpoint}")
61
+ sys.exit(1)
62
+
63
+ model = CASWiT(
64
+ num_head_xa=cfg.cross_attention_heads,
65
+ num_classes=cfg.num_classes,
66
+ model_name=cfg.model_name,
67
+ mlp_ratio=cfg.fusion_mlp_ratio,
68
+ drop_path=cfg.fusion_drop_path
69
+ ).to(device)
70
+
71
+ print(f"Loading checkpoint from: {args.checkpoint}")
72
+ state_dict = torch.load(args.checkpoint, map_location=device)
73
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
74
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
75
+ print(f"Successfully loaded checkpoint from: {args.checkpoint}")
76
+ if len(missing) > 0:
77
+ print(f" Missing keys: {len(missing)}")
78
+ if len(unexpected) > 0:
79
+ print(f" Unexpected keys: {len(unexpected)}")
80
+ if len(missing) == 0 and len(unexpected) == 0:
81
+ print(f" Perfect match! All weights loaded successfully.")
82
+
83
+ transform = build_transforms()
84
+ inference_single_image(model, args.image, device, transform, args.output)
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()
89
+
model/CASWiT.py ADDED
@@ -0,0 +1,246 @@
1
+ """
2
+ CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation
3
+
4
+ This module implements the main CASWiT model architecture with dual-branch
5
+ high-resolution and low-resolution processing with cross-attention fusion.
6
+ """
7
+
8
+ import math
9
+ from typing import Dict
10
+ import torch
11
+ import torch.nn as nn
12
+ from transformers import UperNetForSemanticSegmentation
13
+
14
+
15
+ class DropPath(nn.Module):
16
+ """Drop path (stochastic depth) regularization module."""
17
+ def __init__(self, drop_prob: float = 0.0):
18
+ super().__init__()
19
+ self.drop_prob = float(drop_prob)
20
+
21
+ def forward(self, x):
22
+ if self.drop_prob == 0.0 or (not self.training):
23
+ return x
24
+ keep = 1.0 - self.drop_prob
25
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
26
+ mask = x.new_empty(shape).bernoulli_(keep).div_(keep)
27
+ return x * mask
28
+
29
+
30
+ class CrossFusionBlock(nn.Module):
31
+ """
32
+ Cross-attention fusion block that enables HR features to attend to LR features.
33
+
34
+ Implements pre-norm cross-attention (Q=HR, K/V=LR).
35
+
36
+ Args:
37
+ C_hr: Channel dimension of HR features
38
+ C_lr: Channel dimension of LR features
39
+ num_heads: Number of attention heads
40
+ mlp_ratio: MLP expansion ratio
41
+ drop: Dropout rate
42
+ drop_path: Drop path rate
43
+ """
44
+ def __init__(self, C_hr: int, C_lr: int, num_heads: int = 8,
45
+ mlp_ratio: float = 4.0, drop: float = 0.0, drop_path: float = 0.1):
46
+ super().__init__()
47
+
48
+ self.norm_q = nn.LayerNorm(C_hr)
49
+ self.norm_kv = nn.LayerNorm(C_lr)
50
+ self.attn = nn.MultiheadAttention(
51
+ embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
52
+ dropout=drop, batch_first=True
53
+ )
54
+
55
+ hidden = int(C_hr * mlp_ratio)
56
+ self.mlp = nn.Sequential(
57
+ nn.LayerNorm(C_hr),
58
+ nn.Linear(C_hr, hidden),
59
+ nn.GELU(),
60
+ nn.Linear(hidden, C_hr),
61
+ )
+ # Stochastic depth applied to both residual connections below (uses DropPath above)
+ self.drop_path = DropPath(drop_path)
62
+
63
+ def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
64
+ """
65
+ Forward pass through cross-attention fusion block.
66
+
67
+ Args:
68
+ x_hr: HR features [B, C_hr, H_hr, W_hr]
69
+ x_lr: LR features [B, C_lr, H_lr, W_lr]
70
+
71
+ Returns:
72
+ Fused HR features [B, C_hr, H_hr, W_hr]
73
+ """
74
+ B, C_hr, H_hr, W_hr = x_hr.shape
75
+ _, C_lr, H_lr, W_lr = x_lr.shape
76
+
77
+ # Flatten to sequences
78
+ q = x_hr.flatten(2).transpose(1, 2) # [B, N_hr, C_hr]
79
+ kv = x_lr.flatten(2).transpose(1, 2) # [B, N_lr, C_lr]
80
+
81
+ # Pre-norm
82
+ qn = self.norm_q(q)
83
+ kvn = self.norm_kv(kv)
84
+
85
+ attn_out, _ = self.attn(qn, kvn, kvn) # [B, N_hr, C_hr]
86
+
87
+ # Residual connections (with stochastic depth) + MLP
+ y = q + self.drop_path(attn_out)
+ y = y + self.drop_path(self.mlp(y))
90
+
91
+ return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)
92
+
93
+
94
+ class CASWiT(nn.Module):
95
+ """
96
+ CASWiT: Context-Aware Swin Transformer for Ultra-High Resolution Semantic Segmentation.
97
+
98
+ Dual-branch architecture with:
99
+ - HR branch: Processes high-resolution crops
100
+ - LR branch: Processes low-resolution context
101
+ - Cross-attention fusion at each encoder stage
102
+
103
+ Args:
104
+ num_head_xa: Number of cross-attention heads
105
+ num_classes: Number of segmentation classes
106
+ model_name: HuggingFace model identifier for UPerNet-Swin
107
+ mlp_ratio: MLP expansion ratio in fusion blocks
108
+ drop_path: Drop path rate
109
+ """
110
+ def __init__(self, num_head_xa: int = 1, num_classes: int = 12,
111
+ model_name: str = "openmmlab/upernet-swin-tiny",
112
+ mlp_ratio: float = 4.0, drop_path: float = 0.1):
113
+ super().__init__()
114
+ # Load two UPerNet backbones (HR and LR branches)
115
+ model_hr = UperNetForSemanticSegmentation.from_pretrained(
116
+ model_name, num_labels=num_classes, ignore_mismatched_sizes=True
117
+ )
118
+ model_lr = UperNetForSemanticSegmentation.from_pretrained(
119
+ model_name, num_labels=num_classes, ignore_mismatched_sizes=True
120
+ )
121
+
122
+ # Extract HR branch components
123
+ self.embeddings_hr = model_hr.backbone.embeddings
124
+ self.encoder_layers_hr = model_hr.backbone.encoder.layers
125
+ self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
126
+ self.decoder = model_hr.decode_head
127
+
128
+ # Extract LR branch components
129
+ self.embeddings_lr = model_lr.backbone.embeddings
130
+ self.encoder_layers_lr = model_lr.backbone.encoder.layers
131
+ self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
132
+ self.decoder_lr = model_lr.decode_head
133
+
134
+ # Cross-attention blocks at each stage
135
+ # Dimensions: tiny:[96, 192, 384, 768] base:[128, 256, 512, 1024] large:[192, 384, 768, 1536]
136
+ dims_map = {
137
+ "tiny": [96, 192, 384, 768],
138
+ "base": [128, 256, 512, 1024],
139
+ "large": [192, 384, 768, 1536]
140
+ }
141
+ # Infer dimensions from model name
142
+ if "tiny" in model_name.lower():
143
+ dims = dims_map["tiny"]
144
+ elif "large" in model_name.lower():
145
+ dims = dims_map["large"]
146
+ else:
147
+ dims = dims_map["base"] # default to base
148
+
149
+ self.cross_attn_blocks = nn.ModuleList([
150
+ CrossFusionBlock(dim, dim, num_heads=num_head_xa,
151
+ mlp_ratio=mlp_ratio, drop=0.0, drop_path=drop_path)
152
+ for dim in dims
153
+ ])
154
+
155
+ def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> Dict[str, torch.Tensor]:
156
+ """
157
+ Forward pass through CASWiT model.
158
+
159
+ Args:
160
+ x_hr: HR input images [B, 3, H_hr, W_hr]
161
+ x_lr: LR input images [B, 3, H_lr, W_lr]
162
+
163
+ Returns:
164
+ Dictionary with 'logits_hr' and 'logits_lr' segmentation logits
165
+ """
166
+ B = x_hr.size(0)
167
+
168
+ # Patch embeddings
169
+ x_hr_seq, _ = self.embeddings_hr(x_hr)
170
+ x_lr_seq, _ = self.embeddings_lr(x_lr)
171
+
172
+ N_hr, C_hr = x_hr_seq.shape[1], x_hr_seq.shape[2]
173
+ N_lr, C_lr = x_lr_seq.shape[1], x_lr_seq.shape[2]
174
+ H_hr = W_hr = int(math.sqrt(N_hr))
175
+ H_lr = W_lr = int(math.sqrt(N_lr))
176
+ dims_hr = (H_hr, W_hr)
177
+ dims_lr = (H_lr, W_lr)
178
+
179
+ features_hr: Dict[str, torch.Tensor] = {}
180
+ features_lr: Dict[str, torch.Tensor] = {}
181
+
182
+ # Process through encoder stages with cross-attention fusion
183
+ for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
184
+ self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
185
+ )):
186
+ # HR branch blocks
187
+ for block in stage_hr.blocks:
188
+ x_hr_seq = block(x_hr_seq, dims_hr)
189
+ if isinstance(x_hr_seq, tuple):
190
+ x_hr_seq = x_hr_seq[0]
191
+
192
+ # LR branch blocks
193
+ for block in stage_lr.blocks:
194
+ x_lr_seq = block(x_lr_seq, dims_lr)
195
+ if isinstance(x_lr_seq, tuple):
196
+ x_lr_seq = x_lr_seq[0]
197
+
198
+ # Layer normalization
199
+ x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
200
+ x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)
201
+
202
+ H_hr, W_hr = dims_hr
203
+ H_lr, W_lr = dims_lr
204
+ C_hr = x_hr_seq.shape[-1]
205
+ C_lr = x_lr_seq.shape[-1]
206
+
207
+ # Reshape to spatial format
208
+ feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B, C_hr, H_hr, W_hr)
209
+ feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B, C_lr, H_lr, W_lr)
210
+
211
+ # Cross-attend HR to LR
212
+ fused_hr = ca(feat_hr, feat_lr)
213
+ fused_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()
214
+
215
+ # Downsample if stage has it
216
+ if stage_hr.downsample is not None:
217
+ fused_hr_seq = stage_hr.downsample(fused_hr_seq, dims_hr)
218
+ dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
219
+ if stage_lr.downsample is not None:
220
+ x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
221
+ dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)
222
+
223
+ features_hr[f"stage{idx+1}"] = fused_hr
224
+ features_lr[f"stage{idx+1}"] = feat_lr
225
+ x_hr_seq = fused_hr_seq
226
+
227
+ # Decode HR features
228
+ features_tuple = (
229
+ features_hr["stage1"],
230
+ features_hr["stage2"],
231
+ features_hr["stage3"],
232
+ features_hr["stage4"],
233
+ )
234
+ logits = self.decoder(features_tuple)
235
+
236
+ # Decode LR features (for auxiliary supervision)
237
+ features_tuple_lr = (
238
+ features_lr["stage1"],
239
+ features_lr["stage2"],
240
+ features_lr["stage3"],
241
+ features_lr["stage4"],
242
+ )
243
+ logits_lr = self.decoder_lr(features_tuple_lr)
244
+
245
+ return {"logits_hr": logits, "logits_lr": logits_lr}
246
+
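+ # Illustrative forward pass (hedged sketch; backbone weights are fetched from the HuggingFace Hub):
+ #   model = CASWiT(num_head_xa=1, num_classes=15, model_name="openmmlab/upernet-swin-base")
+ #   out = model(torch.randn(1, 3, 512, 512), torch.randn(1, 3, 512, 512))
+ #   out["logits_hr"], out["logits_lr"]  # HR prediction + auxiliary LR prediction
+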
model/CASWiT_ssl.py ADDED
@@ -0,0 +1,287 @@
1
+ """
2
+ CASWiT Self-Supervised Learning (SSL) Module
3
+
4
+ Implements SimMIM-based self-supervised pre-training for CASWiT using
5
+ masked image modeling with dual-branch HR/LR processing.
6
+ """
7
+
8
+ import math
9
+ from typing import Optional, Tuple
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from transformers import UperNetForSemanticSegmentation
14
+
15
+
16
+ def random_masking_with_tokens(x: torch.Tensor, mask_ratio: float = 0.75,
17
+ mask_token: Optional[torch.Tensor] = None):
18
+ """
19
+ Random masking at token level with learned mask token.
20
+
21
+ Args:
22
+ x: Input tokens [B, N, C]
23
+ mask_ratio: Ratio of tokens to mask
24
+ mask_token: Learnable mask token
25
+
26
+ Returns:
27
+ x_masked: Masked tokens [B, N, C]
28
+ mask: Binary mask [B, N] where 0=visible, 1=masked
29
+ ids_restore: Indices to restore original order
30
+ """
31
+ B, N, C = x.shape
32
+ len_keep = int(N * (1 - mask_ratio))
33
+
34
+ noise = torch.rand(B, N, device=x.device)
35
+ ids_shuffle = torch.argsort(noise, dim=1)
36
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
37
+ ids_keep = ids_shuffle[:, :len_keep]
38
+
39
+ x_keep = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, C))
40
+
41
+ if mask_token is None:
42
+ mask_token = torch.zeros((1, C), device=x.device)
43
+ m_tok = mask_token.view(1, 1, C).expand(B, N - len_keep, C)
44
+
45
+ x_cat = torch.cat([x_keep, m_tok], dim=1)
46
+ x_masked = torch.gather(x_cat, 1, ids_restore.unsqueeze(-1).expand(-1, -1, C))
47
+
48
+ mask = torch.ones(B, N, device=x.device)
49
+ mask[:, :len_keep] = 0
50
+ mask = torch.gather(mask, 1, ids_restore)
51
+ return x_masked, mask, ids_restore
52
+
53
+
54
+ def center_masking_with_tokens(x: torch.Tensor, mask_token: Optional[torch.Tensor] = None,
55
+ mask_ratio: float = 0.5):
56
+ """
57
+ Deterministic centered square mask.
58
+
59
+ Args:
60
+ x: Input tokens [B, N, C]
61
+ mask_token: Learnable mask token
62
+ mask_ratio: Ratio of tokens to mask
63
+
64
+ Returns:
65
+ x_masked: Masked tokens [B, N, C]
66
+ mask: Binary mask [B, N]
67
+ ids_restore: Indices to restore original order
68
+ """
69
+ B, N, C = x.shape
70
+ H = W = int(N**0.5)
71
+ assert H * W == N, "N must be a perfect square"
72
+ L = int(round(H * (mask_ratio ** 0.5)))
73
+ start = (H - L) // 2
74
+ end = start + L
75
+
76
+ mask_2d = torch.zeros(H, W, device=x.device, dtype=torch.bool)
77
+ mask_2d[start:end, start:end] = True
78
+ mask = mask_2d.view(1, -1).expand(B, -1) # (B,N)
79
+
80
+ if mask_token is None:
81
+ mask_token = torch.zeros(C, device=x.device)
82
+ mask_token = mask_token.view(-1)
83
+
84
+ x_masked = x * (~mask).unsqueeze(-1) + mask.unsqueeze(-1) * mask_token.view(1, 1, C)
85
+ ids_restore = torch.arange(N, device=x.device).unsqueeze(0).expand(B, N)
86
+ return x_masked, mask.to(x_masked.dtype), ids_restore
87
+
88
+
89
+ class CrossAttentionBlock(nn.Module):
90
+ """Simplified cross-attention block for SSL."""
91
+ def __init__(self, C_hr, C_lr, num_heads=8, dropout=0.0):
92
+ super().__init__()
93
+ self.cross_attn = nn.MultiheadAttention(
94
+ embed_dim=C_hr, num_heads=num_heads, kdim=C_lr, vdim=C_lr,
95
+ dropout=dropout, batch_first=True
96
+ )
97
+ self.norm = nn.LayerNorm(C_hr)
98
+ self.mlp = nn.Sequential(
99
+ nn.LayerNorm(C_hr),
100
+ nn.Linear(C_hr, C_hr * 4),
101
+ nn.GELU(),
102
+ nn.Linear(C_hr * 4, C_hr),
103
+ )
104
+
105
+ def forward(self, x_hr, x_lr):
106
+ B, C_hr, H_hr, W_hr = x_hr.shape
107
+ _, C_lr, H_lr, W_lr = x_lr.shape
108
+ q = x_hr.flatten(2).transpose(1, 2) # (B,N_hr,C_hr)
109
+ kv = x_lr.flatten(2).transpose(1, 2) # (B,N_lr,C_lr)
110
+ attn_out, _ = self.cross_attn(q, kv, kv)
111
+ y = self.norm(q + attn_out)
112
+ y = y + self.mlp(y)
113
+ return y.transpose(1, 2).view(B, C_hr, H_hr, W_hr)
114
+
115
+
116
+ class CASWiT_SSL(nn.Module):
117
+ """
118
+ CASWiT Self-Supervised Learning model using SimMIM.
119
+
120
+ Encoder: Dual Swin backbones with cross-attention blocks
121
+ Decoder: Conv1x1 + PixelShuffle for reconstruction
122
+ Masking: HR random masking, LR center masking
123
+
124
+ Args:
125
+ model_name: HuggingFace model identifier
126
+ mask_ratio_hr: Masking ratio for HR branch
127
+ mask_ratio_lr: Masking ratio for LR branch
128
+ patch_size: Patch size for masking
129
+ encoder_stride: Encoder stride for decoder
130
+ xa_heads: Number of cross-attention heads per stage
131
+ """
132
+ def __init__(self, model_name: str = "openmmlab/upernet-swin-base",
133
+ mask_ratio_hr: float = 0.75, mask_ratio_lr: float = 0.5,
134
+ patch_size: int = 4, encoder_stride: int = 32,
135
+ xa_heads: Tuple[int, int, int, int] = (8, 8, 8, 8)):
136
+ super().__init__()
137
+ self.mask_ratio_hr = mask_ratio_hr
138
+ self.mask_ratio_lr = mask_ratio_lr
139
+ self.patch_size = patch_size
140
+ self.encoder_stride = encoder_stride
141
+
142
+ # Load two UPerNet (Swin) backbones
143
+ model_hr = UperNetForSemanticSegmentation.from_pretrained(
144
+ model_name, ignore_mismatched_sizes=True
145
+ )
146
+ model_lr = UperNetForSemanticSegmentation.from_pretrained(
147
+ model_name, ignore_mismatched_sizes=True
148
+ )
149
+
150
+ self.embeddings_hr = model_hr.backbone.embeddings
151
+ self.encoder_layers_hr = model_hr.backbone.encoder.layers
152
+ self.hidden_states_norms_hr = model_hr.backbone.hidden_states_norms
153
+
154
+ self.embeddings_lr = model_lr.backbone.embeddings
155
+ self.encoder_layers_lr = model_lr.backbone.encoder.layers
156
+ self.hidden_states_norms_lr = model_lr.backbone.hidden_states_norms
157
+
158
+ # Cross-attention blocks with explicit Swin-Base dims
159
+ dims = [128, 256, 512, 1024]
160
+ self.cross_attn_blocks = nn.ModuleList([
161
+ CrossAttentionBlock(d, d, num_heads=h) for d, h in zip(dims, xa_heads)
162
+ ])
163
+
164
+ # Learnable mask tokens
165
+ self.mask_token_hr = nn.Parameter(torch.zeros(1, dims[0]))
166
+ self.mask_token_lr = nn.Parameter(torch.zeros(1, dims[0]))
167
+
168
+ # SimMIM decoder: Conv1×1 → PixelShuffle(stride)
169
+ self.decoder_conv = None # lazy init after we know C_last
170
+ self.decoder_shuffle = nn.PixelShuffle(self.encoder_stride)
171
+
172
+ # Store masks for visualization
173
+ self.last_mask_hr = None
174
+ self.last_mask_lr = None
175
+
176
+ def _encode(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
177
+ """Encode with masking and return reconstruction targets."""
178
+ B, C, H, W = x_hr.shape
179
+ target_img = x_hr
180
+ target_lr = x_lr
181
+
182
+ # Patch embeddings
183
+ x_hr_seq, _ = self.embeddings_hr(x_hr) # (B, N_hr, C1)
184
+ x_lr_seq, _ = self.embeddings_lr(x_lr) # (B, N_lr, C1)
185
+
186
+ # Masking
187
+ x_hr_seq, mask_hr, _ = random_masking_with_tokens(
188
+ x_hr_seq, self.mask_ratio_hr, self.mask_token_hr
189
+ )
190
+ x_lr_seq, mask_lr, _ = center_masking_with_tokens(
191
+ x_lr_seq, self.mask_token_lr, mask_ratio=self.mask_ratio_lr
192
+ )
193
+
194
+ # Initial spatial dims
195
+ H_hr = W_hr = int(math.sqrt(x_hr_seq.shape[1]))
196
+ H_lr = W_lr = int(math.sqrt(x_lr_seq.shape[1]))
197
+ dims_hr = (H_hr, W_hr)
198
+ dims_lr = (H_lr, W_lr)
199
+
200
+ # Walk encoder stages with cross attention at each stage
201
+ for idx, (stage_hr, stage_lr, ca) in enumerate(zip(
202
+ self.encoder_layers_hr, self.encoder_layers_lr, self.cross_attn_blocks
203
+ )):
204
+ # HR blocks
205
+ for block in stage_hr.blocks:
206
+ x_hr_seq = block(x_hr_seq, dims_hr)
207
+ if isinstance(x_hr_seq, tuple):
208
+ x_hr_seq = x_hr_seq[0]
209
+ # LR blocks
210
+ for block in stage_lr.blocks:
211
+ x_lr_seq = block(x_lr_seq, dims_lr)
212
+ if isinstance(x_lr_seq, tuple):
213
+ x_lr_seq = x_lr_seq[0]
214
+
215
+ # Norms
216
+ x_hr_seq = self.hidden_states_norms_hr[f"stage{idx+1}"](x_hr_seq)
217
+ x_lr_seq = self.hidden_states_norms_lr[f"stage{idx+1}"](x_lr_seq)
218
+
219
+ # Maps
220
+ B_, N_hr_, C_hr_ = x_hr_seq.shape
221
+ B_, N_lr_, C_lr_ = x_lr_seq.shape
222
+ Hh, Wh = dims_hr
223
+ Hl, Wl = dims_lr
224
+ feat_hr = x_hr_seq.transpose(1, 2).contiguous().view(B_, C_hr_, Hh, Wh)
225
+ feat_lr = x_lr_seq.transpose(1, 2).contiguous().view(B_, C_lr_, Hl, Wl)
226
+
227
+ # Cross-fuse HR <- LR
228
+ fused_hr = ca(feat_hr, feat_lr)
229
+ x_hr_seq = fused_hr.flatten(2).transpose(1, 2).contiguous()
230
+
231
+ # Downsample to next stage
232
+ if stage_hr.downsample is not None:
233
+ x_hr_seq = stage_hr.downsample(x_hr_seq, dims_hr)
234
+ dims_hr = (dims_hr[0] // 2, dims_hr[1] // 2)
235
+ if stage_lr.downsample is not None:
236
+ x_lr_seq = stage_lr.downsample(x_lr_seq, dims_lr)
237
+ dims_lr = (dims_lr[0] // 2, dims_lr[1] // 2)
238
+
239
+ # Last-stage feature map z (B, C_last, H/stride, W/stride)
240
+ Hs, Ws = dims_hr
241
+ C_last = x_hr_seq.shape[-1]
242
+ z = x_hr_seq.transpose(1, 2).contiguous().view(B, C_last, Hs, Ws)
243
+
244
+ # Lazy init decoder conv
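+ # caution: parameters are created on the first forward pass, so build the optimizer after a dry forward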
245
+ if self.decoder_conv is None:
246
+ self.decoder_conv = nn.Conv2d(
247
+ C_last, (self.encoder_stride ** 2) * 3, kernel_size=1
248
+ ).to(z.device)
249
+
250
+ # Reconstruction
251
+ x_rec = self.decoder_shuffle(self.decoder_conv(z)) # (B,3,H,W)
252
+
253
+ # Convert patch masks to pixel masks
254
+ Mh = int(math.sqrt(mask_hr.shape[1]))
255
+ mask_patch_hr = mask_hr.view(B, Mh, Mh)
256
+ mask_pix_hr = mask_patch_hr.repeat_interleave(
257
+ self.patch_size, 1
258
+ ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
259
+
260
+ Ml = int(math.sqrt(mask_lr.shape[1]))
261
+ mask_patch_lr = mask_lr.view(B, Ml, Ml)
262
+ mask_pix_lr = mask_patch_lr.repeat_interleave(
263
+ self.patch_size, 1
264
+ ).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
265
+
266
+ self.last_mask_hr = mask_patch_hr
267
+ self.last_mask_lr = mask_patch_lr
268
+
269
+ return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr
270
+
271
+ def forward(self, x_hr: torch.Tensor, x_lr: torch.Tensor) -> torch.Tensor:
272
+ """
273
+ Forward pass for SSL training.
274
+
275
+ Returns reconstruction loss on masked pixels only.
276
+ """
277
+ x_rec, target_img, mask_pix, _, _ = self._encode(x_hr, x_lr)
278
+ loss_recon = F.l1_loss(target_img, x_rec, reduction='none')
279
+ loss = (loss_recon * mask_pix).sum() / (mask_pix.sum() + 1e-6) / target_img.shape[1]
280
+ return loss
281
+
282
+ @torch.no_grad()
283
+ def forward_outputs(self, x_hr: torch.Tensor, x_lr: torch.Tensor):
284
+ """Forward pass returning all outputs for visualization."""
285
+ x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr = self._encode(x_hr, x_lr)
286
+ return x_rec, target_img, mask_pix_hr, target_lr, mask_pix_lr
287
+
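The SSL objective implemented in `forward` is the SimMIM masked-L1 reconstruction, averaged over masked pixels and normalized by the channel count $C$:

$$
\mathcal{L} \;=\; \frac{\sum \, M \odot \lvert x - \hat{x} \rvert}{\left(\sum M + 10^{-6}\right)\, C}
$$

The two masking strategies can be sanity-checked in isolation; a minimal sketch using only the functions defined in this file:

```python
import torch
from model.CASWiT_ssl import random_masking_with_tokens, center_masking_with_tokens

B, N, C = 2, 16 * 16, 128        # token grid must be a perfect square
x = torch.randn(B, N, C)
tok = torch.zeros(1, C)          # stand-in for the learnable mask token

# HR-style random masking: exactly 75% of the 256 tokens are replaced.
x_m, mask, _ = random_masking_with_tokens(x, mask_ratio=0.75, mask_token=tok)
print(mask.sum(dim=1))           # tensor([192., 192.])

# LR-style center masking: a deterministic 11x11 square (~47% of tokens).
x_c, mask_c, _ = center_masking_with_tokens(x, mask_token=tok, mask_ratio=0.5)
print(mask_c.view(B, 16, 16)[0])
```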
model/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ CASWiT model implementations.
3
+ """
4
+
5
+ from model.CASWiT import CASWiT, CrossFusionBlock
6
+ from model.CASWiT_ssl import CASWiT_SSL
7
+
8
+ __all__ = ['CASWiT', 'CrossFusionBlock', 'CASWiT_SSL']
9
+
requirements.txt ADDED
@@ -0,0 +1,27 @@
1
+ # Core dependencies
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+ transformers>=4.30.0
5
+ timm>=0.9.0
6
+
7
+ # Data processing
8
+ numpy>=1.21.0
9
+ pillow>=9.0.0
10
+ tifffile>=2023.0.0
11
+ opencv-python>=4.5.0
12
+ pandas>=1.3.0
13
+ rasterio>=1.3.0
14
+ affine>=2.3.0
15
+
16
+ # Training utilities
17
+ tqdm>=4.64.0
18
+ PyYAML>=6.0
19
+ scikit-learn>=1.0.0
20
+
21
+ # Logging and visualization
22
+ wandb>=0.15.0
23
+ matplotlib>=3.5.0
24
+
25
+ # Optional: for distributed training
26
+ accelerate>=0.20.0
27
+
train/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """
2
+ Training, evaluation, and inference scripts.
3
+ """
4
+
5
+ from train.train import TrainConfig, load_config
6
+
7
+ __all__ = ['TrainConfig', 'load_config']
8
+
train/eval.py ADDED
@@ -0,0 +1,153 @@
1
+ """
2
+ Evaluation script for CASWiT model.
3
+
4
+ Evaluates a trained model on test/validation sets and computes metrics.
5
+ """
6
+
7
+ import sys
8
+ import yaml
9
+ import logging
10
+ from pathlib import Path
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import DataLoader
14
+ from torch.cuda.amp import autocast
15
+ from tqdm import tqdm
16
+
17
+ # Add project root to Python path
18
+ project_root = Path(__file__).parent.parent
19
+ sys.path.insert(0, str(project_root))
20
+
21
+ from model.CASWiT import CASWiT
22
+ from dataset.definition_dataset import SemanticSegmentationDatasetFusion, URURHRLRDataset, build_transforms
23
+ from utils.metrics import compute_metrics_from_confusion
24
+ from train.train import load_config, TrainConfig
25
+
26
+
27
+ def evaluate_model(cfg: TrainConfig, checkpoint_path: str, split: str = "test"):
28
+ """
29
+ Evaluate model on specified split.
30
+
31
+ Args:
32
+ cfg: Training configuration
33
+ checkpoint_path: Path to model checkpoint
34
+ split: Name used for logging/progress display only; data always comes from cfg.test_* subdirs
35
+ """
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+
38
+ # Validate checkpoint path
39
+ checkpoint_path_obj = Path(checkpoint_path)
40
+ if not checkpoint_path_obj.exists() or not checkpoint_path_obj.is_file():
41
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
42
+
43
+ # Load model
44
+ model = CASWiT(
45
+ num_head_xa=cfg.cross_attention_heads,
46
+ num_classes=cfg.num_classes,
47
+ model_name=cfg.model_name,
48
+ mlp_ratio=cfg.fusion_mlp_ratio,
49
+ drop_path=cfg.fusion_drop_path
50
+ ).to(device)
51
+
52
+ # Load checkpoint
53
+ print(f"Loading checkpoint from: {checkpoint_path}")
54
+ state_dict = torch.load(checkpoint_path, map_location=device)
55
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
56
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
57
+ print(f"Successfully loaded checkpoint from: {checkpoint_path}")
58
+ if len(missing) > 0:
59
+ print(f" Missing keys: {len(missing)}")
60
+ if len(unexpected) > 0:
61
+ print(f" Unexpected keys: {len(unexpected)}")
62
+ if len(missing) == 0 and len(unexpected) == 0:
63
+ print(f" Perfect match! All weights loaded successfully.")
64
+ model.eval()
65
+ dataset_name = cfg.dataset_name
66
+
67
+ # Build the evaluation dataset: sliding-window loader for URUR, simple center-crop loader for FLAIR-HUB
68
+ t = build_transforms()
69
+ base = Path(cfg.data_path)
70
+ if dataset_name == "URUR":
71
+ # Use sliding window without tiling for URUR (full image coverage)
72
+ ds = URURHRLRDataset(
73
+ image_dir=base / cfg.test_img_subdir,
74
+ mask_dir=base / cfg.test_msk_subdir,
75
+ num_classes=cfg.num_classes,
76
+ mode="test",
77
+ ignore_index=cfg.ignore_index,
78
+ transform=t
79
+ )
80
+ else:
81
+ # Use simple center crop with FLAIRHUB
82
+ ds = SemanticSegmentationDatasetFusion(base / cfg.test_img_subdir, base / cfg.test_msk_subdir, transform=t)
83
+
84
+ dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=True)
85
+
86
+ # Evaluate
87
+ criterion = torch.nn.CrossEntropyLoss(ignore_index=cfg.ignore_index)
88
+ running_loss = 0.0
89
+ full_confmat = torch.zeros((cfg.num_classes, cfg.num_classes), dtype=torch.long, device=device)
90
+
91
+ with torch.inference_mode():
92
+ for batch in tqdm(dl, desc=f"Evaluating {split}"):
93
+ # Handle datasets that return meta dict (URURHRLRDataset returns 5 values)
94
+ if len(batch) == 5:
95
+ images_hr, masks_hr, images_lr, masks_lr, _ = batch
96
+ else:
97
+ images_hr, masks_hr, images_lr, masks_lr = batch
98
+ images_hr = images_hr.to(device, non_blocking=True)
99
+ masks_hr = masks_hr.to(device, non_blocking=True)
100
+ images_lr = images_lr.to(device, non_blocking=True)
101
+ masks_lr = masks_lr.to(device, non_blocking=True)
102
+
103
+ with autocast():
104
+ out = model(images_hr, images_lr)
105
+ logits_hr = out["logits_hr"]
106
+ logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False)
107
+ loss = criterion(logits_hr, masks_hr)
108
+
109
+ running_loss += float(loss.item())
110
+
111
+ preds = torch.argmax(logits_hr, dim=1)
112
+ valid = (masks_hr >= 0) & (masks_hr < cfg.num_classes)
113
+ tgt = masks_hr[valid]  # distinct name so the transform `t` above is not shadowed
114
+ prd = preds[valid]
115
+
116
+ cm = torch.bincount(
117
+ (tgt * cfg.num_classes + prd).view(-1),
118
+ minlength=cfg.num_classes * cfg.num_classes
119
+ ).reshape(cfg.num_classes, cfg.num_classes)
120
+ full_confmat += cm
121
+
122
+ avg_loss = running_loss / max(1, len(dl))
123
+ confmat_np = full_confmat.cpu().numpy()
124
+ metrics = compute_metrics_from_confusion(confmat_np)
125
+
126
+ print(f"\n{split.upper()} Results:")
127
+ print(f" Loss: {avg_loss:.4f}")
128
+ print(f" mIoU: {metrics['mIoU']:.4f}")
129
+ print(f" mF1: {metrics['mF1']:.4f}")
130
+ print(f" Per-class IoU: {metrics['IoUs']}")
131
+
132
+ return metrics
133
+
134
+
135
+ def main():
136
+ """Main evaluation function."""
137
+ import sys
138
+ if len(sys.argv) < 3:
139
+ print("Usage: python eval.py <config_path> <checkpoint_path> [split]")
140
+ sys.exit(1)
141
+
142
+ cfg_path = sys.argv[1]
143
+ checkpoint_path = sys.argv[2]
144
+ split = sys.argv[3] if len(sys.argv) > 3 else "test"
145
+
146
+ logging.basicConfig(level=logging.INFO)
147
+ cfg = load_config(cfg_path)
148
+ evaluate_model(cfg, checkpoint_path, split)
149
+
150
+
151
+ if __name__ == "__main__":
152
+ main()
153
+
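The evaluation loop above accumulates the confusion matrix on the GPU by flattening each (truth, prediction) pair into a single index `t * num_classes + p` and calling `bincount` once. A tiny worked example:

```python
import torch

num_classes = 3
t = torch.tensor([0, 0, 1, 2, 2, 2])  # ground truth (ignore_index already filtered out)
p = torch.tensor([0, 1, 1, 2, 0, 2])  # predictions

cm = torch.bincount(
    (t * num_classes + p).view(-1),
    minlength=num_classes * num_classes,
).reshape(num_classes, num_classes)
print(cm)
# tensor([[1, 1, 0],
#         [0, 1, 0],
#         [1, 0, 2]])   cm[i, j] = pixels of true class i predicted as class j
```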
train/inference.py ADDED
@@ -0,0 +1,118 @@
1
+ """
2
+ Inference script for CASWiT model.
3
+
4
+ Performs inference on images and saves predictions.
5
+ """
6
+
7
+ import sys
8
+ import yaml
9
+ import logging
10
+ from pathlib import Path
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from PIL import Image
15
+ from tifffile import imread
16
+ from torchvision import transforms
17
+
18
+ # Add project root to Python path
19
+ project_root = Path(__file__).parent.parent
20
+ sys.path.insert(0, str(project_root))
21
+
22
+ from model.CASWiT import CASWiT
23
+ from dataset.definition_dataset import build_transforms
24
+ from train.train import load_config
25
+
26
+
27
+ def inference_single_image(model, image_path: str, device, transform, output_path: str = None):
28
+ """
29
+ Run inference on a single image.
30
+
31
+ Args:
32
+ model: Trained CASWiT model
33
+ image_path: Path to input image
34
+ device: Device to run inference on
35
+ transform: Image transform
36
+ output_path: Path to save prediction (optional)
37
+ """
38
+ # Load and preprocess image
39
+ image = imread(str(image_path))[:, :, :3]
40
+
41
+ # Create HR and LR versions
42
+ crop_x, crop_y = 256, 256
43
+ image_hr = image[crop_x:crop_x + 512, crop_y:crop_y + 512]
44
+ image_lr = image[::2, ::2, :]
45
+
46
+ # Transform
47
+ img_hr_tensor = transform(Image.fromarray(image_hr)).unsqueeze(0).to(device)
48
+ img_lr_tensor = transform(Image.fromarray(image_lr)).unsqueeze(0).to(device)
49
+
50
+ # Inference
51
+ model.eval()
52
+ with torch.no_grad():
53
+ out = model(img_hr_tensor, img_lr_tensor)
54
+ logits = out["logits_hr"]
55
+ pred = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()
56
+
57
+ if output_path:
58
+ # Save prediction as image
59
+ pred_img = Image.fromarray(pred.astype(np.uint8))
60
+ pred_img.save(output_path)
61
+ print(f"Prediction saved to {output_path}")
62
+
63
+ return pred
64
+
65
+
66
+ def main():
67
+ """Main inference function."""
68
+ import sys
69
+ if len(sys.argv) < 4:
70
+ print("Usage: python inference.py <config_path> <checkpoint_path> <image_path> [output_path]")
71
+ sys.exit(1)
72
+
73
+ cfg_path = sys.argv[1]
74
+ checkpoint_path = sys.argv[2]
75
+ image_path = sys.argv[3]
76
+ output_path = sys.argv[4] if len(sys.argv) > 4 else "prediction.png"
77
+
78
+ logging.basicConfig(level=logging.INFO)
79
+ cfg = load_config(cfg_path)
80
+
81
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
+
83
+ # Validate checkpoint path
84
+ checkpoint_path_obj = Path(checkpoint_path)
85
+ if not checkpoint_path_obj.exists() or not checkpoint_path_obj.is_file():
86
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
87
+
88
+ # Load model
89
+ model = CASWiT(
90
+ num_head_xa=cfg.cross_attention_heads,
91
+ num_classes=cfg.num_classes,
92
+ model_name=cfg.model_name,
93
+ mlp_ratio=cfg.fusion_mlp_ratio,
94
+ drop_path=cfg.fusion_drop_path
95
+ ).to(device)
96
+
97
+ # Load checkpoint
98
+ print(f"Loading checkpoint from: {checkpoint_path}")
99
+ state_dict = torch.load(checkpoint_path, map_location=device)
100
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
101
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
102
+ print(f"Successfully loaded checkpoint from: {checkpoint_path}")
103
+ if len(missing) > 0:
104
+ print(f" Missing keys: {len(missing)}")
105
+ if len(unexpected) > 0:
106
+ print(f" Unexpected keys: {len(unexpected)}")
107
+ if len(missing) == 0 and len(unexpected) == 0:
108
+ print(f" Perfect match! All weights loaded successfully.")
109
+
110
+ # Run inference
111
+ transform = build_transforms()
112
+ pred = inference_single_image(model, image_path, device, transform, output_path)
113
+ print(f"Inference complete. Prediction shape: {pred.shape}")
114
+
115
+
116
+ if __name__ == "__main__":
117
+ main()
118
+
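For reference, the HR/LR preprocessing contract used by `inference_single_image`, sketched on a stand-in 1024×1024 tile (the tile size is an assumption; anything large enough for the hard-coded crop works):

```python
import numpy as np

image = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)  # stand-in tile
crop_x = crop_y = 256
image_hr = image[crop_x:crop_x + 512, crop_y:crop_y + 512]  # fine detail
image_lr = image[::2, ::2, :]                               # whole tile, 2x downsampled
print(image_hr.shape, image_lr.shape)                       # (512, 512, 3) (512, 512, 3)
```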
train/train.py ADDED
@@ -0,0 +1,534 @@
1
+ """
2
+ Training script for CASWiT model.
3
+
4
+ Supports distributed training with DDP, mixed precision, and WandB logging.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import yaml
10
+ import logging
11
+ from pathlib import Path
12
+ from dataclasses import dataclass
13
+ from typing import Tuple
14
+
15
+ # Add project root to Python path
16
+ project_root = Path(__file__).parent.parent
17
+ sys.path.insert(0, str(project_root))
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.optim as optim
24
+ import torch.distributed as dist
25
+ from torch.utils.data import DataLoader
26
+ from torch.utils.data.distributed import DistributedSampler
27
+ from torch.cuda.amp import GradScaler, autocast
28
+ from torch.optim.lr_scheduler import CosineAnnealingLR
29
+ from tqdm import tqdm
30
+
31
+ from model.CASWiT import CASWiT
32
+ from dataset.definition_dataset import SemanticSegmentationDatasetFusion, URURHRLRDataset, build_transforms
33
+ from utils.metrics import compute_metrics_from_confusion
34
+ from utils.logging import setup_wandb_logging, log_metrics
35
+
36
+ try:
37
+ import wandb
38
+ except ImportError:
39
+ wandb = None
40
+
41
+
42
+ @dataclass
43
+ class TrainConfig:
44
+ """Training configuration dataclass."""
45
+ # Paths
46
+ data_path: str
47
+ dataset_name: str
48
+ train_img_subdir: str
49
+ train_msk_subdir: str
50
+ val_img_subdir: str
51
+ val_msk_subdir: str
52
+ test_img_subdir: str
53
+ test_msk_subdir: str
54
+ save_dir: str = "weights"
55
+ pretrained_path: str = ""
56
+
57
+ # Model
58
+ model_name: str = "openmmlab/upernet-swin-tiny"
59
+ num_classes: int = 15
60
+ cross_attention_heads: int = 1
61
+ ignore_index: int = 255
62
+ fusion_mlp_ratio: float = 4.0
63
+ fusion_drop_path: float = 0.1
64
+ lr_supervision_weight: float = 0.5
65
+
66
+ # Training
67
+ batch_size: int = 4
68
+ num_workers: int = 4
69
+ num_epochs: int = 50
70
+ learning_rate: float = 1e-4
71
+ amp: bool = True
72
+ seed: int = 1337
73
+ eta_min: float = 1e-6
74
+
75
+ # Logging
76
+ use_wandb: bool = True
77
+ wandb_project: str = "your_project"
78
+ wandb_entity: str = "your_entity"
79
+ wandb_run_name: str = "hrlr_fusion"
80
+
81
+ # Misc
82
+ print_device: bool = True
83
+
84
+
85
+ def load_config(cfg_path: str) -> TrainConfig:
86
+ """Load configuration from YAML file."""
87
+ with open(cfg_path, "r") as f:
88
+ raw = yaml.safe_load(f)
89
+
90
+ training = raw.get("training", {})
91
+ model = raw.get("model", {})
92
+ wandb_cfg = raw.get("wandb", {})
93
+ paths = raw.get("paths", {})
94
+
95
+ return TrainConfig(
96
+ data_path=paths.get("data_path", ""),
97
+ dataset_name=paths.get("dataset_name", ""),
98
+ train_img_subdir=paths.get("train_img_subdir", "train/img"),
99
+ train_msk_subdir=paths.get("train_msk_subdir", "train/msk"),
100
+ val_img_subdir=paths.get("val_img_subdir", "val/img"),
101
+ val_msk_subdir=paths.get("val_msk_subdir", "val/msk"),
102
+ test_img_subdir=paths.get("test_img_subdir", "test/img"),
103
+ test_msk_subdir=paths.get("test_msk_subdir", "test/msk"),
104
+ save_dir=paths.get("save_dir", "weights"),
105
+ pretrained_path=paths.get("pretrained_path", ""),
106
+ model_name=model.get("model_name", "openmmlab/upernet-swin-tiny"),
107
+ num_classes=int(model.get("num_classes", 15)),  # keep default in sync with TrainConfig.num_classes
108
+ cross_attention_heads=int(model.get("cross_attention_heads", 1)),
109
+ ignore_index=int(model.get("ignore_index", 255)),
110
+ fusion_mlp_ratio=float(model.get("fusion_mlp_ratio", 4.0)),
111
+ fusion_drop_path=float(model.get("fusion_drop_path", 0.1)),
112
+ lr_supervision_weight=float(training.get("lr_supervision_weight", 0.5)),
113
+ batch_size=int(training.get("batch_size", 8)),
114
+ num_workers=int(training.get("num_workers", 4)),
115
+ num_epochs=int(training.get("num_epochs", 50)),
116
+ learning_rate=float(training.get("learning_rate", 1e-4)),
117
+ amp=bool(training.get("amp", True)),
118
+ seed=int(training.get("seed", 1337)),
119
+ eta_min=float(training.get("eta_min", 1e-6)),
120
+ use_wandb=bool(wandb_cfg.get("use_wandb", True)),
121
+ wandb_project=wandb_cfg.get("project", "your_project"),
122
+ wandb_entity=wandb_cfg.get("entity", "your_entity"),
123
+ wandb_run_name=wandb_cfg.get("run_name", "hrlr_fusion"),
124
+ print_device=bool(raw.get("print_device", True)),
125
+ )
126
+
127
+
128
+ def set_seed(seed: int):
129
+ """Set random seeds for reproducibility."""
130
+ import random
131
+ random.seed(seed)
132
+ np.random.seed(seed)
133
+ torch.manual_seed(seed)
134
+ torch.cuda.manual_seed_all(seed)
135
+ torch.backends.cudnn.deterministic = True
136
+ torch.backends.cudnn.benchmark = False
137
+
138
+
139
+ def is_distributed() -> bool:
140
+ """Check if running in distributed mode."""
141
+ return int(os.environ.get("WORLD_SIZE", "1")) > 1
142
+
143
+
144
+ def get_rank() -> int:
145
+ """Get current process rank."""
146
+ return int(os.environ.get("RANK", "0"))
147
+
148
+
149
+ def get_local_rank() -> int:
150
+ """Get local rank."""
151
+ return int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
152
+
153
+
154
+ def is_main_process() -> bool:
155
+ """Check if this is the main process."""
156
+ return get_rank() == 0
157
+
158
+
159
+ def setup_distributed():
160
+ """Setup distributed training."""
161
+ if is_distributed() and not dist.is_initialized():
162
+ dist.init_process_group(backend="nccl")
163
+ local_rank = get_local_rank()
164
+ torch.cuda.set_device(local_rank)
165
+
166
+
167
+ def cleanup_distributed():
168
+ """Cleanup distributed training."""
169
+ if dist.is_initialized():
170
+ dist.destroy_process_group()
171
+
172
+
173
+ def make_dataloaders(cfg: TrainConfig) -> Tuple[DataLoader, DataLoader, DataLoader,
174
+ DistributedSampler, DistributedSampler, DistributedSampler]:
175
+ """Create data loaders for train, val, and test splits."""
176
+ t = build_transforms()
177
+ base = Path(cfg.data_path)
178
+ dataset_name = cfg.dataset_name
179
+ if dataset_name == "URUR":
180
+ ds_train = URURHRLRDataset(image_dir=base / cfg.train_img_subdir, mask_dir=base / cfg.train_msk_subdir,
181
+ num_classes=cfg.num_classes, mode="train", ignore_index=cfg.ignore_index,
182
+ transform=t)
183
+ ds_val = URURHRLRDataset(image_dir=base / cfg.val_img_subdir, mask_dir=base / cfg.val_msk_subdir,
184
+ num_classes=cfg.num_classes, mode="val", ignore_index=cfg.ignore_index,
185
+ transform=t)
186
+ ds_test = URURHRLRDataset(image_dir=base / cfg.test_img_subdir, mask_dir=base / cfg.test_msk_subdir,
187
+ num_classes=cfg.num_classes, mode="test", ignore_index=cfg.ignore_index,
188
+ transform=t)
189
+ else: # FLAIRHUB
190
+ ds_train = SemanticSegmentationDatasetFusion(base / cfg.train_img_subdir, base / cfg.train_msk_subdir, transform=t)
191
+ ds_val = SemanticSegmentationDatasetFusion(base / cfg.val_img_subdir, base / cfg.val_msk_subdir, transform=t)
192
+ ds_test = SemanticSegmentationDatasetFusion(base / cfg.test_img_subdir, base / cfg.test_msk_subdir, transform=t)
193
+
194
+ if is_distributed():
195
+ train_sampler = DistributedSampler(ds_train, shuffle=True)
196
+ val_sampler = DistributedSampler(ds_val, shuffle=False)
197
+ test_sampler = DistributedSampler(ds_test, shuffle=False)
198
+ shuffle_train = False
199
+ shuffle_eval = False
200
+ else:
201
+ train_sampler = None
202
+ val_sampler = None
203
+ test_sampler = None
204
+ shuffle_train = True
205
+ shuffle_eval = False
206
+
207
+ dl_train = DataLoader(ds_train, batch_size=cfg.batch_size, sampler=train_sampler,
208
+ shuffle=shuffle_train, num_workers=cfg.num_workers, drop_last=True,
209
+ pin_memory=True)
210
+ dl_val = DataLoader(ds_val, batch_size=cfg.batch_size, sampler=val_sampler,
211
+ shuffle=shuffle_eval, num_workers=cfg.num_workers, pin_memory=True)
212
+ dl_test = DataLoader(ds_test, batch_size=cfg.batch_size, sampler=test_sampler,
213
+ shuffle=shuffle_eval, num_workers=cfg.num_workers, pin_memory=True)
214
+
215
+ return dl_train, dl_val, dl_test, train_sampler, val_sampler, test_sampler
216
+
217
+
218
+ def maybe_load_pretrained(model: nn.Module, path: Path):
219
+ """Load pretrained weights if available."""
220
+ # Check if path is empty, None, or not a valid file
221
+ path_str = str(path) if path else ""
222
+ if not path or path_str.strip() == "" or not path.exists() or not path.is_file():
223
+ if path and path_str.strip() != "":
224
+ print(f"WARNING: Pretrained weights not found at {path}. Skipping load.")
225
+ return
226
+
227
+ # Load the checkpoint
228
+ print(f"Loading pretrained weights from: {path}")
229
+ try:
230
+ state_dict = torch.load(str(path), map_location="cpu")
231
+ # Remove DistributedDataParallel/DataParallel prefix if any
232
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
233
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
234
+
235
+ # Log successful loading with details
236
+ print(f"Successfully loaded pretrained weights from: {path}")
237
+ if len(missing) > 0:
238
+ print(f" Missing keys: {len(missing)} (this is normal if architecture changed)")
239
+ if len(unexpected) > 0:
240
+ print(f" Unexpected keys: {len(unexpected)} (these were ignored)")
241
+ if len(missing) == 0 and len(unexpected) == 0:
242
+ print(f" Perfect match! All weights loaded successfully.")
243
+ except Exception as e:
244
+ print(f"ERROR: Failed to load pretrained weights from {path}: {e}")
245
+ raise
246
+
247
+
248
+ def train_one_epoch(model, dl, device, criterion, optimizer, scaler: GradScaler, amp: bool, lr_supervision_weight: float):
249
+ """Train for one epoch."""
250
+ model.train()
251
+ running = 0.0
252
+ with tqdm(dl, desc="Training", unit="batch", disable=not is_main_process()) as bar:
253
+ for batch in bar:
254
+ # Handle datasets that return meta dict (e.g., URURHRLRDataset)
255
+ if len(batch) == 5:
256
+ images_hr, masks_hr, images_lr, masks_lr, _ = batch
257
+ else:
258
+ images_hr, masks_hr, images_lr, masks_lr = batch
259
+ images_hr = images_hr.to(device, non_blocking=True)
260
+ masks_hr = masks_hr.to(device, non_blocking=True)
261
+ images_lr = images_lr.to(device, non_blocking=True)
262
+ masks_lr = masks_lr.to(device, non_blocking=True)
263
+
264
+ optimizer.zero_grad(set_to_none=True)
265
+ if amp:
266
+ with autocast():
267
+ out = model(images_hr, images_lr)
268
+ logits_hr = out["logits_hr"]
269
+ logits_lr = out["logits_lr"]
270
+ logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False)
271
+ logits_lr = F.interpolate(logits_lr, size=masks_lr.shape[-2:], mode="bilinear", align_corners=False)
272
+ loss_hr = criterion(logits_hr, masks_hr)
273
+ loss_lr = criterion(logits_lr, masks_lr)
274
+ loss = loss_hr + lr_supervision_weight * loss_lr
275
+ scaler.scale(loss).backward()
276
+ scaler.step(optimizer)
277
+ scaler.update()
278
+ else:
279
+ out = model(images_hr, images_lr)
280
+ logits_hr = out["logits_hr"]
281
+ logits_lr = out["logits_lr"]
282
+ logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False)
283
+ logits_lr = F.interpolate(logits_lr, size=masks_lr.shape[-2:], mode="bilinear", align_corners=False)
284
+ loss_hr = criterion(logits_hr, masks_hr)
285
+ loss_lr = criterion(logits_lr, masks_lr)
286
+ loss = loss_hr + lr_supervision_weight * loss_lr
287
+ loss.backward()
288
+ optimizer.step()
289
+
290
+ running += float(loss.item())
291
+ bar.set_postfix(loss=float(loss.item()))
292
+ return running / max(1, len(dl))
293
+
294
+
295
+ def evaluate(model, dl, device, criterion, num_classes: int, lr_supervision_weight: float, phase_name: str = "Validation"):
296
+ """Evaluate model on dataset."""
297
+ model.eval()
298
+ running = 0.0
299
+ # Confusion matrix on GPU
300
+ full_confmat = torch.zeros((num_classes, num_classes), dtype=torch.long, device=device)
301
+
302
+ is_main = (not is_distributed()) or is_main_process()
303
+ iterator = tqdm(dl, desc=phase_name, unit="batch", disable=not is_main) if is_main else dl
304
+
305
+ with torch.inference_mode():
306
+ for batch in iterator:
307
+ # Handle datasets that return meta dict (e.g., URURHRLRDataset)
308
+ if len(batch) == 5:
309
+ images_hr, masks_hr, images_lr, masks_lr, _ = batch
310
+ else:
311
+ images_hr, masks_hr, images_lr, masks_lr = batch
312
+ images_hr = images_hr.to(device, non_blocking=True)
313
+ masks_hr = masks_hr.to(device, non_blocking=True)
314
+ images_lr = images_lr.to(device, non_blocking=True)
315
+ masks_lr = masks_lr.to(device, non_blocking=True)
316
+
317
+ with autocast():
318
+ out = model(images_hr, images_lr)
319
+ logits_hr = out["logits_hr"]
320
+ logits_lr = out["logits_lr"]
321
+ logits_hr = F.interpolate(logits_hr, size=masks_hr.shape[-2:], mode="bilinear", align_corners=False)
322
+ logits_lr = F.interpolate(logits_lr, size=masks_lr.shape[-2:], mode="bilinear", align_corners=False)
323
+ loss_hr = criterion(logits_hr, masks_hr)
324
+ loss_lr = criterion(logits_lr, masks_lr)
325
+ loss = loss_hr + lr_supervision_weight * loss_lr
326
+ running += float(loss.item())
327
+
328
+ preds = torch.argmax(logits_hr, dim=1)
329
+
330
+ # Ignore index 255 directly on GPU
331
+ valid = (masks_hr >= 0) & (masks_hr < num_classes)
332
+ t = masks_hr[valid]
333
+ p = preds[valid]
334
+
335
+ # GPU-vectorized confusion matrix
336
+ cm = torch.bincount(
337
+ (t * num_classes + p).view(-1),
338
+ minlength=num_classes * num_classes
339
+ ).reshape(num_classes, num_classes)
340
+ full_confmat += cm
341
+
342
+ if is_main and isinstance(iterator, tqdm):
343
+ iterator.set_postfix(loss=float(loss.item()))
344
+
345
+ # Aggregate across all ranks if DDP
346
+ if is_distributed():
347
+ dist.all_reduce(full_confmat, op=dist.ReduceOp.SUM)
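+ # note: only the confusion matrix is all-reduced; the running loss stays per-rank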
348
+
349
+ avg_loss = running / max(1, len(dl))
350
+
351
+ # Convert to CPU once at the end
352
+ confmat_np = full_confmat.cpu().numpy()
353
+ metrics = compute_metrics_from_confusion(confmat_np)
354
+ return avg_loss, metrics
355
+
356
+
357
+ def main(cfg_path: str = "configs/FlairHub.yaml"):
358
+ """Main training function."""
359
+ logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s - %(message)s")
360
+ cfg = load_config(cfg_path)
361
+
362
+ set_seed(cfg.seed)
363
+
364
+ # DDP setup
365
+ setup_distributed()
366
+ local_rank = get_local_rank()
367
+ device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
368
+
369
+ if cfg.print_device and is_main_process():
370
+ print(f"Distributed: {is_distributed()} | WORLD_SIZE={os.environ.get('WORLD_SIZE', '1')} | RANK={get_rank()} | LOCAL_RANK={local_rank}")
371
+ print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
372
+ print(f"Device used (this process): {device}")
373
+
374
+ # Data
375
+ dl_train, dl_val, dl_test, train_sampler, val_sampler, test_sampler = make_dataloaders(cfg)
376
+
377
+ # Model
378
+ model = CASWiT(
379
+ num_head_xa=cfg.cross_attention_heads,
380
+ num_classes=cfg.num_classes,
381
+ model_name=cfg.model_name,
382
+ mlp_ratio=cfg.fusion_mlp_ratio,
383
+ drop_path=cfg.fusion_drop_path
384
+ ).to(device)
385
+
386
+ # Load pretrained weights
387
+ if cfg.pretrained_path and cfg.pretrained_path.strip():
388
+ if is_main_process():
389
+ print(f"Attempting to load pretrained weights from: {cfg.pretrained_path}")
390
+ maybe_load_pretrained(model, Path(cfg.pretrained_path))
391
+ else:
392
+ if is_main_process():
393
+ print("No pretrained weights specified. Starting training from scratch.")
394
+
395
+ # Wrap with DDP if distributed
396
+ if is_distributed():
397
+ model = torch.nn.parallel.DistributedDataParallel(
398
+ model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False
399
+ )
400
+
401
+ # Optimizer/Scheduler/Loss
402
+ criterion = nn.CrossEntropyLoss(ignore_index=cfg.ignore_index)
403
+ optimizer = optim.AdamW(model.parameters(), lr=cfg.learning_rate)
404
+ scheduler = CosineAnnealingLR(optimizer, T_max=cfg.num_epochs, eta_min=cfg.eta_min)
405
+ scaler = GradScaler(enabled=cfg.amp)
406
+
407
+ # Logging (wandb only on main)
408
+ use_wandb = cfg.use_wandb and (wandb is not None) and is_main_process()
409
+ if use_wandb:
410
+ setup_wandb_logging(
411
+ project=cfg.wandb_project,
412
+ entity=cfg.wandb_entity,
413
+ run_name=cfg.wandb_run_name,
414
+ config={
415
+ "training": {
416
+ "num_epochs": cfg.num_epochs,
417
+ "learning_rate": cfg.learning_rate,
418
+ "batch_size": cfg.batch_size,
419
+ "amp": cfg.amp,
420
+ "eta_min": cfg.eta_min,
421
+ },
422
+ "model": {
423
+ "num_classes": cfg.num_classes,
424
+ "model_name": cfg.model_name,
425
+ "cross_attention_heads": cfg.cross_attention_heads,
426
+ },
427
+ "paths": {
428
+ "data_path": cfg.data_path,
429
+ "save_dir": cfg.save_dir,
430
+ },
431
+ },
432
+ use_wandb=cfg.use_wandb
433
+ )
434
+
435
+ # Train loop with best model tracking
436
+ Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
437
+ best_miou = 0.0
438
+ best_epoch = 0
439
+ best_model_path = None
440
+
441
+ for epoch in range(cfg.num_epochs):
442
+ if is_distributed() and train_sampler is not None:
443
+ train_sampler.set_epoch(epoch)
444
+
445
+ if is_main_process():
446
+ print(f"\nEpoch {epoch + 1}/{cfg.num_epochs}")
447
+
448
+ train_loss = train_one_epoch(model, dl_train, device, criterion, optimizer, scaler, cfg.amp, cfg.lr_supervision_weight)
449
+ val_loss, val_metrics = evaluate(model, dl_val, device, criterion, cfg.num_classes, cfg.lr_supervision_weight, phase_name="Validation")
450
+
451
+ current_lr = optimizer.param_groups[0]['lr']
452
+ current_miou = val_metrics["mIoU"]
453
+
454
+ if is_main_process():
455
+ log_payload = {
456
+ "epoch": epoch + 1,
457
+ "lr": current_lr,
458
+ "train_loss": train_loss,
459
+ "val_loss": val_loss,
460
+ "val_mIoU": current_miou,
461
+ "val_mF1": val_metrics["mF1"],
462
+ }
463
+ print(f"LR={current_lr:.6f} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | mIoU={current_miou:.4f}")
464
+ if use_wandb:
465
+ log_metrics(log_payload)
466
+
467
+ # Save checkpoint for this epoch
468
+ ckpt_name = f"fusion_hrlr_{cfg.model_name.split('/')[-1]}_lrsupervised_{cfg.batch_size}_epoch_{epoch+1}_head{cfg.cross_attention_heads}.pth"
469
+ ckpt_path = Path(cfg.save_dir) / ckpt_name
470
+ torch.save(
471
+ model.module.state_dict() if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.state_dict(),
472
+ str(ckpt_path)
473
+ )
474
+
475
+ # Track best model based on validation mIoU
476
+ if current_miou > best_miou:
477
+ best_miou = current_miou
478
+ best_epoch = epoch + 1
479
+ best_model_path = ckpt_path
480
+ # Save best model with special name
481
+ best_ckpt_name = f"best_model_epoch_{epoch+1}_miou_{current_miou:.4f}.pth"
482
+ torch.save(
483
+ model.module.state_dict() if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model.state_dict(),
484
+ str(Path(cfg.save_dir) / best_ckpt_name)
485
+ )
486
+ print(f"*** New best model! mIoU: {best_miou:.4f} ***")
487
+
488
+ # Step LR scheduler
489
+ scheduler.step()
490
+
491
+ # Final test evaluation with best model
492
+ if is_main_process():
493
+ print(f"\n{'='*80}")
494
+ print(f"Training completed. Best validation mIoU: {best_miou:.4f} at epoch {best_epoch}")
495
+ print(f"Loading best model for final test evaluation...")
496
+ print(f"{'='*80}\n")
497
+
498
+ # Load best model for final test
499
+ if best_model_path and best_model_path.exists():
500
+ state_dict = torch.load(str(best_model_path), map_location=device)
501
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
502
+ model.module.load_state_dict(state_dict, strict=True)
503
+ else:
504
+ model.load_state_dict(state_dict, strict=True)
505
+ if is_main_process():
506
+ print(f"Loaded best model from: {best_model_path}\n")
507
+
508
+ # Run final test evaluation
509
+ test_loss, test_metrics = evaluate(model, dl_test, device, criterion, cfg.num_classes, cfg.lr_supervision_weight, phase_name="Final Test")
510
+ if is_main_process():
511
+ test_payload = {
512
+ "best_epoch": best_epoch,
513
+ "best_val_mIoU": best_miou,
514
+ "FINAL_TEST_loss": test_loss,
515
+ "FINAL_TEST_mIoU": test_metrics["mIoU"],
516
+ "FINAL_TEST_mF1": test_metrics["mF1"],
517
+ }
518
+ print(f"\n{'='*80}")
519
+ print(f"FINAL TEST RESULTS (with best model from epoch {best_epoch}):")
520
+ print(f" Evaluation")
521
+ print(f" Test Loss: {test_loss:.4f}")
522
+ print(f" Test mIoU: {test_metrics['mIoU']:.4f}")
523
+ print(f" Test mF1: {test_metrics['mF1']:.4f}")
524
+ print(f"{'='*80}\n")
525
+ if use_wandb:
526
+ log_metrics(test_payload)
527
+ cleanup_distributed()
528
+
529
+
530
+ if __name__ == "__main__":
531
+ import sys
532
+ cfg_path = sys.argv[1] if len(sys.argv) > 1 else "configs/FlairHub.yaml"
533
+ main(cfg_path)
534
+
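The training objective combines the HR loss with a weighted auxiliary LR loss, `loss = loss_hr + lr_supervision_weight * loss_lr`. A self-contained sketch with dummy logits (shapes are illustrative only):

```python
import torch
import torch.nn.functional as F

num_classes, ignore_index, w_lr = 15, 255, 0.5
logits_hr = torch.randn(2, num_classes, 128, 128)
logits_lr = torch.randn(2, num_classes, 128, 128)
masks_hr = torch.randint(0, num_classes, (2, 128, 128))
masks_lr = torch.randint(0, num_classes, (2, 128, 128))

loss_hr = F.cross_entropy(logits_hr, masks_hr, ignore_index=ignore_index)
loss_lr = F.cross_entropy(logits_lr, masks_lr, ignore_index=ignore_index)
loss = loss_hr + w_lr * loss_lr  # backpropagated exactly as in train_one_epoch
print(float(loss))
```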
utils/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """
2
+ Utilities for CASWiT training and evaluation.
3
+ """
4
+
5
+ from utils.metrics import compute_confusion, compute_metrics_from_confusion
6
+ from utils.logging import setup_wandb_logging, log_metrics
7
+ from utils.attention_viz import viz_cross_attention
8
+
9
+ __all__ = [
10
+ 'compute_confusion',
11
+ 'compute_metrics_from_confusion',
12
+ 'setup_wandb_logging',
13
+ 'log_metrics',
14
+ 'viz_cross_attention',
15
+ ]
16
+
utils/attention_viz.py ADDED
@@ -0,0 +1,225 @@
1
+ """
2
+ Cross-attention visualization for CASWiT model.
3
+
4
+ This module provides utilities to visualize cross-attention maps between
5
+ HR and LR branches at different encoder stages.
6
+ """
7
+
8
+ from typing import Dict, List, Tuple
9
+ import math
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+
15
+
16
+ def _to_numpy(img: torch.Tensor) -> np.ndarray:
17
+ """
18
+ Convert normalized tensor to numpy array for visualization.
19
+
20
+ Args:
21
+ img: (1,3,H,W) tensor in normalized space [-1,1]
22
+
23
+ Returns:
24
+ uint8 HxWx3 array for plotting
25
+ """
26
+ x = img.detach().float().cpu()[0]
27
+ # Undo Normalize(mean=0.5, std=0.5) -> x*0.5 + 0.5
28
+ x = x * 0.5 + 0.5
29
+ x = torch.clamp(x, 0, 1)
30
+ x = (x.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
31
+ return x
32
+
33
+
34
+ def _pixel_to_token(x: int, y: int, W_img: int, H_img: int,
35
+ W_tokens: int, H_tokens: int) -> int:
36
+ """
37
+ Map HR pixel (x,y) to linear token index for a grid (H_tokens, W_tokens).
38
+
39
+ Uses a ratio-based mapping which is correct for uniform patch embeddings.
40
+ """
41
+ tx = min(max(int(math.floor(x * W_tokens / max(W_img, 1))), 0), W_tokens - 1)
42
+ ty = min(max(int(math.floor(y * H_tokens / max(H_img, 1))), 0), H_tokens - 1)
43
+ return ty * W_tokens + tx
44
+
45
+
46
+ class _CrossAttnTap:
47
+ """Hook storage for capturing cross-attention weights."""
48
+ def __init__(self):
49
+ # Per stage we store: attn (B, N_q, N_k) averaged over heads, and grids
50
+ self.attn_by_stage: Dict[int, torch.Tensor] = {}
51
+ self.hr_grid_by_stage: Dict[int, Tuple[int, int]] = {}
52
+ self.lr_grid_by_stage: Dict[int, Tuple[int, int]] = {}
53
+ self._handles: List[torch.utils.hooks.RemovableHandle] = []
54
+
55
+ def register(self, model) -> None:
56
+ """Register hooks on model to capture attention weights."""
57
+ # Locate the list of CrossFusionBlock modules
58
+ blocks = getattr(model, 'cross_attn_blocks', None)
59
+ if blocks is None:
60
+ raise RuntimeError('Model has no attribute cross_attn_blocks')
61
+
62
+ for s, block in enumerate(blocks):
63
+ # Hook on the CrossFusionBlock to get H/W of inputs (x_hr, x_lr)
64
+ def fwd_hook(stage_idx: int):
65
+ def _f(module, inputs, output):
66
+ # inputs: (x_hr, x_lr)
67
+ x_hr, x_lr = inputs[0], inputs[1]
68
+ _, _, Hh, Wh = x_hr.shape
69
+ _, _, Hl, Wl = x_lr.shape
70
+ self.hr_grid_by_stage[stage_idx] = (Hh, Wh)
71
+ self.lr_grid_by_stage[stage_idx] = (Hl, Wl)
72
+ return _f
73
+ self._handles.append(block.register_forward_hook(fwd_hook(s)))
74
+
75
+ # Hook on the internal nn.MultiheadAttention to grab attn weights
76
+ mha = getattr(block, 'attn', None)
77
+ if mha is None:
78
+ raise RuntimeError(f'CrossFusionBlock at stage {s} has no attn module')
79
+
80
+ def attn_hook(stage_idx: int):
81
+ def _f(module, inputs, output):
82
+ # output is a tuple: (attn_out, attn_weights)
83
+ if isinstance(output, tuple) and len(output) == 2:
84
+ attn_w = output[1] # shape: (B, N_q, N_k) averaged over heads
85
+ self.attn_by_stage[stage_idx] = attn_w.detach()
86
+ return _f
87
+ self._handles.append(mha.register_forward_hook(attn_hook(s)))
88
+
89
+ def clear(self):
90
+ """Clear stored attention weights."""
91
+ self.attn_by_stage.clear()
92
+ self.hr_grid_by_stage.clear()
93
+ self.lr_grid_by_stage.clear()
94
+
95
+ def remove(self):
96
+ """Remove all registered hooks."""
97
+ for h in self._handles:
98
+ h.remove()
99
+ self._handles.clear()
100
+
101
+
102
+ @torch.no_grad()
103
+ def viz_cross_attention(
104
+ model: torch.nn.Module,
105
+ img_hr: torch.Tensor, # (1,3,H,W) normalized with mean=std=0.5
106
+ img_lr: torch.Tensor, # (1,3,h,w) normalized
107
+ pixel_xy: Tuple[int, int],
108
+ save_path: str = 'attn_maps.png',
109
+ overlay_alpha: float = 0.55,
110
+ dpi: int = 180,
111
+ show_titles: bool = True,
112
+ ):
113
+ """
114
+ Visualize cross-attention maps for a given pixel location.
115
+
116
+ Runs a forward pass and saves a multi-panel PNG: one panel per cross-attn stage.
117
+ The attention is averaged over heads (default behavior of nn.MultiheadAttention).
118
+
119
+ Args:
120
+ model: CASWiT model (unwrap DDP if needed)
121
+ img_hr: HR input image [1, 3, H, W]
122
+ img_lr: LR input image [1, 3, h, w]
123
+ pixel_xy: (x, y) pixel coordinates in HR image space
124
+ save_path: Path to save visualization
125
+ overlay_alpha: Alpha transparency for attention overlay
126
+ dpi: DPI for saved figure
127
+ show_titles: Whether to show stage titles
128
+ """
129
+ was_training = model.training
130
+ model.eval()
131
+
132
+ # If user passed a DDP-wrapped model, unwrap
133
+ if hasattr(model, 'module') and not hasattr(model, 'cross_attn_blocks'):
134
+ model = model.module
135
+
136
+ tap = _CrossAttnTap()
137
+ tap.register(model)
138
+ tap.clear()
139
+
140
+ try:
141
+ # Forward to populate hooks
142
+ device = next(model.parameters()).device
143
+ img_hr = img_hr.to(device)
144
+ img_lr = img_lr.to(device)
145
+ _ = model(img_hr, img_lr)
146
+
147
+ H_img, W_img = img_hr.shape[-2:]
148
+ px, py = pixel_xy
149
+ px = int(np.clip(px, 0, W_img - 1))
150
+ py = int(np.clip(py, 0, H_img - 1))
151
+
152
+ # Prepare base images for overlays (H,W,3) in [0,255]
153
+ base_hr = _to_numpy(img_hr)
154
+ base_lr = _to_numpy(img_lr)
155
+
156
+ stages = sorted(tap.attn_by_stage.keys())
157
+ if len(stages) == 0:
158
+ raise RuntimeError('No attention captured. Ensure a forward pass reached the cross-attention blocks.')
159
+
160
+ n = len(stages)
161
+
162
+ # Figure with a dedicated column for colorbar
163
+ fig = plt.figure(figsize=(4.0*n, 4.2), dpi=dpi)
164
+ gs = fig.add_gridspec(nrows=1, ncols=n+1, width_ratios=[1]*n + [0.04], wspace=0.05)
165
+
166
+ axes = [fig.add_subplot(gs[0, i]) for i in range(n)]
167
+ cax = fig.add_subplot(gs[0, -1]) # Axis reserved for colorbar
168
+
169
+ hm = None
170
+ for i, s in enumerate(stages):
171
+ attn = tap.attn_by_stage[s] # (B, N_q, N_k)
172
+ (Hh, Wh) = tap.hr_grid_by_stage[s]
173
+ (Hl, Wl) = tap.lr_grid_by_stage[s]
174
+
175
+ # Pick batch 0
176
+ attn0 = attn[0] # (N_q, N_k)
177
+ q_idx = _pixel_to_token(px, py, W_img, H_img, Wh, Hh) # note W first in tokens
178
+ row = attn0[q_idx] # (N_k,)
179
+
180
+ attn_map = row.view(Hl, Wl) # reshape to LR grid because K comes from LR branch
181
+
182
+ # Normalize for visualization
183
+ attn_map = attn_map - attn_map.min()
184
+ denom = float(attn_map.max().item()) or 1.0  # falls back to 1.0 when the map is all zeros
185
+ attn_map = attn_map / denom
186
+
187
+ # Upsample to LR background size
188
+ attn_up = F.interpolate(
189
+ attn_map[None, None, ...],
190
+ size=base_lr.shape[:2],
191
+ mode='bilinear',
192
+ align_corners=False
193
+ )[0, 0]
194
+ attn_np = attn_up.detach().cpu().numpy()
195
+
196
+ ax = axes[i]
197
+ ax.imshow(base_lr)
198
+ hm = ax.imshow(attn_np, cmap='jet', alpha=overlay_alpha, vmin=0.0, vmax=1.0)
199
+
200
+ # Approx HR→LR pixel mapping for the marker (simple ratio)
201
+ hx, hy = base_hr.shape[1], base_hr.shape[0]
202
+ lx, ly = base_lr.shape[1], base_lr.shape[0]
203
+ px_lr = int(round(px * lx / max(hx, 1)))
204
+ py_lr = int(round(py * ly / max(hy, 1)))
205
+ ax.scatter([px_lr], [py_lr], s=18, c='white', marker='o',
206
+ linewidths=0.5, edgecolors='black')
207
+
208
+ if show_titles:
209
+ ax.set_title(f'Stage {s+1}: HR→LR attn', fontsize=10)
210
+ ax.set_axis_off()
211
+
212
+ # Colorbar in dedicated axis
213
+ cbar = fig.colorbar(hm, cax=cax)
214
+ cbar.set_label('Attention')
215
+
216
+ # Save PNG
217
+ fig.savefig(save_path, bbox_inches='tight', format='png')
218
+ plt.close(fig)
219
+
220
+ finally:
221
+ # Cleanup hooks, restore training mode if needed
222
+ tap.remove()
223
+ if was_training:
224
+ model.train()
225
+
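Typical usage — a hedged sketch that assumes the main CASWiT model exposes `cross_attn_blocks` whose attention submodule is named `attn` (as `_CrossAttnTap.register` expects) and that both inputs were normalized with mean=std=0.5:

```python
import torch
from model.CASWiT import CASWiT
from utils.attention_viz import viz_cross_attention

model = CASWiT(num_head_xa=1, num_classes=15,
               model_name="openmmlab/upernet-swin-tiny",
               mlp_ratio=4.0, drop_path=0.1).eval()
img_hr = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for a normalized HR crop
img_lr = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for a normalized LR context
viz_cross_attention(model, img_hr, img_lr, pixel_xy=(256, 256),
                    save_path="attn_maps.png")
```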
utils/logging.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ Logging utilities for CASWiT training.
3
+
4
+ Provides WandB integration and logging helpers.
5
+ """
6
+
7
+ from typing import Optional
8
+ try:
9
+ import wandb
10
+ except ImportError:
11
+ wandb = None
12
+
13
+
14
+ def setup_wandb_logging(project: str, entity: str, run_name: str, config: dict,
15
+ use_wandb: bool = True) -> bool:
16
+ """
17
+ Setup Weights & Biases logging.
18
+
19
+ Args:
20
+ project: WandB project name
21
+ entity: WandB entity/username
22
+ run_name: Run name
23
+ config: Configuration dictionary
24
+ use_wandb: Whether to use WandB
25
+
26
+ Returns:
27
+ True if WandB is initialized, False otherwise
28
+ """
29
+ if not use_wandb or wandb is None:
30
+ return False
31
+
32
+ wandb.init(project=project, entity=entity, config=config, name=run_name)
33
+ return True
34
+
35
+
36
+ def log_metrics(metrics: dict, step: Optional[int] = None):
37
+ """Log metrics to WandB."""
38
+ if wandb is not None:
39
+ wandb.log(metrics, step=step)
40
+
utils/metrics.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ Metrics computation for semantic segmentation.
3
+
4
+ Provides functions for computing IoU, mIoU, F1 score, and confusion matrices.
5
+ """
6
+
7
+ from typing import Dict
8
+ import numpy as np
9
+ from sklearn.metrics import confusion_matrix
10
+
11
+
12
+ def compute_confusion(preds: np.ndarray, targets: np.ndarray, num_classes: int) -> np.ndarray:
13
+ """
14
+ Compute confusion matrix.
15
+
16
+ Args:
17
+ preds: Predicted labels
18
+ targets: Ground truth labels
19
+ num_classes: Number of classes
20
+
21
+ Returns:
22
+ Confusion matrix [num_classes, num_classes]
23
+ """
24
+ return confusion_matrix(targets.flatten(), preds.flatten(), labels=np.arange(num_classes))
25
+
26
+
27
+ def compute_metrics_from_confusion(confmat: np.ndarray) -> Dict[str, np.ndarray]:
28
+ """
29
+ Compute IoU, mIoU, and F1 scores from confusion matrix.
30
+
31
+ Args:
32
+ confmat: Confusion matrix [num_classes, num_classes]
33
+
34
+ Returns:
35
+ Dictionary with 'mIoU', 'mF1', and 'IoUs' keys
36
+ """
37
+ with np.errstate(divide='ignore', invalid='ignore'):
38
+ intersection = np.diag(confmat)
39
+ ground_truth_set = confmat.sum(axis=1)
40
+ predicted_set = confmat.sum(axis=0)
41
+ union = ground_truth_set + predicted_set - intersection
42
+ ious = intersection / np.maximum(union, 1)
43
+ f1s = (2 * intersection) / np.maximum(ground_truth_set + predicted_set, 1)
44
+ miou = np.nanmean(ious)
45
+ mf1 = np.nanmean(f1s)
46
+ return {"mIoU": float(miou), "mF1": float(mf1), "IoUs": ious}
47
+
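A worked example on a 2-class confusion matrix, with values chosen so the arithmetic is easy to check by hand:

```python
import numpy as np
from utils.metrics import compute_metrics_from_confusion

cm = np.array([[50, 10],   # rows: ground truth, cols: predictions
               [ 5, 35]])
m = compute_metrics_from_confusion(cm)
# IoU_0 = 50 / (60 + 55 - 50) = 50/65 ~ 0.769
# IoU_1 = 35 / (40 + 45 - 35) = 35/50 = 0.700
print(round(m["mIoU"], 3))  # 0.735
print(round(m["mF1"], 3))   # 0.847
```

Note that a class absent from both the ground truth and the predictions gets IoU 0 here (the union is clamped to 1), which lowers the mean rather than excluding the class.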
weights/CASWiT-Base-SSL_FLAIRHUB_15classes.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a5c54b18a75ca3a462d31a9c70a7ed9575ceacc15e0576a691a3a1db59244f6
3
+ size 1036652553
weights/CASWiT-Base-SSL_URUR_8classes.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5bc2d3523c801305f65f0af07fbde4af5f3f651952f3689edaeb676581baf67
3
+ size 1036657681
weights/CASWiT-Base_FLAIRHUB_15classes.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b8d6fa7d28ed92e905a03d1fc753dd27963944c0d42aafb0a71f57144e48ce98
3
+ size 1036639033
weights/CASWiT-Base_URUR_8classes.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f498c0e8f0b5d139d332b0b534695d7c2b9128c4c68097a77bc93f304ea6e1a
3
+ size 1036647637
weights/Swin-Base_FLAIRHUB_15classes.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80aeef0022881619f221b9bf89d9a3174f27ba7f45f701963694e7077a8e4313
3
+ size 489532214