---
license: mit
datasets:
- IGNF/FLAIR-HUB
- heig-vd-geo/URUR
language:
- en
metrics:
- mean_iou
pipeline_tag: image-segmentation
---
# CASWiT: Context-Aware Stage-Wise Transformer for Ultra-High Resolution Semantic Segmentation

*Context-Aware Semantic Segmentation via Stage-Wise Attention*

Official implementation of **CASWiT**, a dual-branch architecture for ultra-high resolution semantic segmentation that leverages stage-wise cross-attention fusion between a high-resolution and a low-resolution branch.
## 📋 Table of Contents

- Overview
- Architecture
- Installation
- Docker
- Dataset Preparation
- Usage
- Configuration
- Results
- Citation
- License
## 🎯 Overview

CASWiT addresses semantic segmentation on ultra-high resolution (UHR) imagery with a dual-resolution design:

- **HR branch**: processes high-resolution crops (e.g., 512×512) to capture fine details
- **LR branch**: processes low-resolution context (typically downsampled) to capture global information
- **Stage-wise cross-attention fusion**: HR features attend to LR context at each encoder stage

CASWiT generalizes beyond aerial imagery without architectural changes; configs and scripts are provided for several UHR datasets (FLAIR-HUB, URUR, SWISSIMAGE inference, and additional medical benchmarks).
## 🏗️ Architecture

Key components:

- **Dual Swin Transformer backbones**
- **Cross-attention fusion blocks** at each encoder stage
- **Auxiliary LR supervision** (optional, weighted by `model.lr_supervision_weight`)
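The fusion step can be sketched as single-head cross-attention in which HR tokens act as queries and LR tokens as keys/values, with a residual connection back onto the HR stream. The sketch below is illustrative only (shapes, channel width, and the plain residual fusion are assumptions; the actual blocks live in `model/`):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(hr_tokens, lr_tokens, wq, wk, wv):
    """Single-head cross-attention: HR tokens (queries) attend to LR tokens.

    hr_tokens: (N_hr, C), lr_tokens: (N_lr, C); returns (N_hr, C).
    """
    q, k, v = hr_tokens @ wq, lr_tokens @ wk, lr_tokens @ wv
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # (N_hr, N_lr) attention map
    return hr_tokens + attn @ v                     # residual fusion onto HR stream

rng = np.random.default_rng(0)
C = 96                                   # illustrative stage-1 channel width
hr = rng.standard_normal((256, C))       # fine-detail tokens from the HR crop
lr = rng.standard_normal((64, C))        # context tokens from the LR branch
wq, wk, wv = (rng.standard_normal((C, C)) * 0.05 for _ in range(3))
fused = cross_attention(hr, lr, wq, wk, wv)
print(fused.shape)  # (256, 96)
```

In CASWiT this fusion is applied at every encoder stage, so the HR branch receives global context before each downsampling step rather than only at the end.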
## 📦 Installation

### Requirements

- Python 3.12+
- PyTorch 2+
- CUDA 12+ (for GPU training)

### Setup

```bash
git clone https://huggingface.co/heig-vd-geo/CASWiT
cd CASWiT
pip install -r requirements.txt
```
## 🐳 Docker

A Dockerfile is provided for a reproducible environment.

### Build

```bash
docker build -t caswit:latest .
```

### Run (GPU)

```bash
docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  caswit:latest
```

If your datasets or checkpoints live outside the repository, mount them as well, e.g.:

```bash
docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  -v /path/to/data:/data \
  -v /path/to/checkpoints:/checkpoints \
  caswit:latest
```
## 📁 Dataset Preparation

### FLAIR-HUB

- Download the FLAIR-HUB dataset
- Merge the GeoTIFF tiles into mosaics:

```bash
python dataset/prepareFlairHub.py
```
### URUR

Expected structure:

```
URUR/
├── train/
│   ├── image/
│   └── label/
├── val/
│   ├── image/
│   └── label/
└── test/
    ├── image/
    └── label/
```

A re-hosted copy is available here: https://huggingface.co/datasets/heig-vd-geo/URUR
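To catch path mistakes before a long training run, the expected layout can be checked with a few lines of stdlib Python (a hypothetical helper, not part of the repository):

```python
from pathlib import Path

def check_urur_layout(root: str) -> list:
    """Return the expected sub-directories that are missing under `root`."""
    expected = [f"{split}/{sub}"
                for split in ("train", "val", "test")
                for sub in ("image", "label")]
    base = Path(root)
    return [rel for rel in expected if not (base / rel).is_dir()]

missing = check_urur_layout("URUR")
if missing:
    print("missing directories:", missing)
```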
### SWISSIMAGE

Download images using the provided CSV:

```bash
python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv
```

### SwissImage + SwissTLM3D Adaptation Pipeline

The repository also includes a full adaptation pipeline for Switzerland using:

- **SwissImage** as the imagery source
- **SwissTLM3D** as coarse / partial supervision
- **FLAIR-HUB** as the semantic anchor for mixed fine-tuning

#### 1. Download SwissImage

```bash
python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv
```
#### 2. Download SwissTLM3D

The official SwissTLM3D GeoPackage can be downloaded from swisstopo:

```bash
mkdir -p /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d
curl -L \
  https://data.geo.admin.ch/ch.swisstopo.swisstlm3d/swisstlm3d_2026-02-24/swisstlm3d_2026-02-24_2056_5728.gpkg.zip \
  -o /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/swisstlm3d_2026-02-24_2056_5728.gpkg.zip
unzip -o \
  /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/swisstlm3d_2026-02-24_2056_5728.gpkg.zip \
  -d /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d
```
#### 3. Rasterize SwissTLM3D masks on SwissImage

This creates masks in the FLAIR-HUB label space, with 255 marking unsupported / ignored classes.

```bash
python dataset/rasterize_swisstlm3d.py \
  --gpkg /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/SWISSTLM3D_2026_LV95_LN02.gpkg \
  --image_dir /mnt/CalcShare/datasets/Suisse_full/data/image \
  --output_dir /mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d \
  --workers 8
```
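The 255 value matches the `ignore_index: 255` used by the model config, so those pixels contribute neither to the loss nor to the metrics. A minimal numpy sketch of how such pixels are excluded from IoU (illustrative only; the repository's actual metrics live in `utils/metrics.py`):

```python
import numpy as np

IGNORE_INDEX = 255  # value written by the rasterization step for unsupported classes

def masked_iou(pred, target, num_classes):
    """Per-class IoU that skips pixels labeled IGNORE_INDEX (NaN for absent classes)."""
    valid = target != IGNORE_INDEX
    pred, target = pred[valid], target[valid]
    ious = np.full(num_classes, np.nan)
    for c in range(num_classes):
        inter = np.logical_and(pred == c, target == c).sum()
        union = np.logical_or(pred == c, target == c).sum()
        if union > 0:
            ious[c] = inter / union
    return ious

# 2x2 toy example: the top-right pixel is ignored, so the wrong
# prediction there does not count against either class.
target = np.array([[0, 255], [1, 1]])
pred = np.array([[0, 0], [1, 0]])
print(masked_iou(pred, target, num_classes=2))  # [0.5 0.5]
```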
#### 4. Create 2048 → 1024 training patches

SwissImage has a coarser GSD than FLAIR-HUB in our setup. We therefore extract 2048×2048 windows and resize them to 1024×1024 before training.

```bash
python dataset/prepare_swisstlm3d_patches.py \
  --image_dir /mnt/CalcShare/datasets/Suisse_full/data/image \
  --mask_dir /mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d \
  --output_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024 \
  --drop_all_ignore \
  --workers 48
```
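One detail worth noting when reproducing this step yourself: label masks must be resized with nearest-neighbor sampling, since interpolation would blend class ids into invalid values. A minimal sketch of the 2× case (illustrative; the repository's actual resizing is done inside `dataset/prepare_swisstlm3d_patches.py`):

```python
import numpy as np

def downsample_mask_2x(mask):
    """Nearest-neighbor 2x downsampling: keep every second pixel.

    Unlike bilinear interpolation, this never invents intermediate
    class ids (e.g. averaging class 3 and class 7 into a bogus 5),
    and it preserves the 255 ignore value exactly.
    """
    return mask[::2, ::2]

mask_2048 = np.random.randint(0, 15, size=(2048, 2048), dtype=np.uint8)
mask_1024 = downsample_mask_2x(mask_2048)
print(mask_1024.shape)  # (1024, 1024)
```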
#### 5. Create a geographic split

The split is deterministic and avoids spatial leakage by splitting along contiguous geographic bands.

```bash
python dataset/split_swisstlm3d_geographic.py \
  --image_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024/img \
  --mask_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024/msk \
  --output_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024_split \
  --overwrite
```
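The band idea can be sketched as follows: patches are bucketed by easting into contiguous bands, and whole bands are assigned to train/val/test, so validation tiles are never adjacent to training tiles. The band count and which bands go to val/test are illustrative assumptions here, not the script's actual parameters:

```python
def split_by_bands(eastings, n_bands=10, val_bands=(3,), test_bands=(7,)):
    """Deterministically assign each patch to train/val/test by its easting band.

    Patches in the same contiguous band always land in the same split,
    which prevents spatial leakage between splits.
    """
    lo, hi = min(eastings), max(eastings)
    width = (hi - lo) / n_bands or 1  # guard against a degenerate single-column case
    out = {}
    for e in eastings:
        band = min(int((e - lo) / width), n_bands - 1)
        out[e] = ("val" if band in val_bands
                  else "test" if band in test_bands
                  else "train")
    return out

# 100 patches spaced 1 km apart along the easting axis (LV95-like coordinates)
splits = split_by_bands([2500000 + 1000 * i for i in range(100)])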
## 🚀 Usage

All commands below work either directly with the scripts in `train/` or via the unified `main.py`.

### Training

Single GPU:

```bash
python train/train.py configs/config_FlairHub.yaml
```

Multi-GPU (DDP):

```bash
torchrun --nproc_per_node=4 train/train.py configs/config_FlairHub.yaml
```

### Evaluation

```bash
python train/eval.py configs/config_FlairHub.yaml weights/checkpoint.pth test
```

### Inference (single image)

```bash
python train/inference.py configs/config_FlairHub.yaml weights/checkpoint.pth image.tif output.png
```
### Inference on a VRT (DDP)

This runs tiled inference on a VRT using multiple GPUs:

```bash
torchrun --nproc_per_node=5 --master_port=29501 -m train.inference_vrt_ddp \
  --config configs/config_SWISSIMAGE_inf.yaml \
  --checkpoint weights/CASWiT-Base-SSL_FLAIRHUB_UN.pth \
  --vrt file.vrt \
  --out_dir output/ \
  --tile 1024 \
  --stride 512 \
  --lr_side 2048
```
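With `--tile 1024 --stride 512`, consecutive windows overlap by half a tile, which hides seam artifacts at tile borders. A minimal sketch of how such window offsets can be enumerated (a generic sliding-window scheme, not necessarily the exact logic in `train/inference_vrt_ddp.py`):

```python
def tile_offsets(length, tile, stride):
    """Start offsets of sliding windows covering [0, length).

    Steps by `stride`; the last window is shifted back so it ends
    exactly at `length`, guaranteeing full coverage without padding.
    """
    if length <= tile:
        return [0]
    offsets = list(range(0, length - tile + 1, stride))
    if offsets[-1] != length - tile:
        offsets.append(length - tile)  # edge-aligned final window
    return offsets

print(tile_offsets(2500, tile=1024, stride=512))  # [0, 512, 1024, 1476]
```

The 2D tiling is simply the Cartesian product of the row and column offsets, and the overlapping predictions can then be averaged where windows intersect.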
### Using main.py

```bash
# Train
python main.py train --config configs/config_FlairHub.yaml

# Eval (the script defaults to the test split if none is specified)
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

# Inference (single image)
python main.py inference --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth --image image.tif --output pred.png
```
## ⚙️ Configuration

Configs are YAML files in `configs/`.

### Model selection (decoder head)

You can select the decoder head directly in the config:

```yaml
model:
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet
```
Notes:

- `mask2former` may require `training.amp: false`, depending on your environment/precision settings.
- `ssl` is the SimMIM-style pretraining model and is not a drop-in replacement for the segmentation inference/eval scripts.
### Example config structure

```yaml
paths:
  data_path: "/path/to/dataset"
  dataset_name: "FLAIRHUB"
  train_img_subdir: "train/img"
  train_msk_subdir: "train/msk"
  val_img_subdir: "valid/img"
  val_msk_subdir: "valid/msk"
  test_img_subdir: "test/img"
  test_msk_subdir: "test/msk"
  save_dir: "weights"
  pretrained_path: ""

model:
  model_name: "openmmlab/upernet-swin-base"
  num_classes: 15
  cross_attention_heads: 1
  ignore_index: 255
  fusion_mlp_ratio: 4.0
  fusion_drop_path: 0.1
  lr_supervision_weight: 0.5
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet

training:
  batch_size: 4
  num_workers: 8
  num_epochs: 20
  learning_rate: 0.00006
  amp: true
  seed: 42
  eta_min: 0.000001

wandb:
  use_wandb: true
  project: "CASWiT"
  entity: "your-entity"
  run_name: "caswit_experiment"

augmentations:
  enable: false
  p_hflip: 0.5
  p_vflip: 0.5
  p_rot90: 0.5
  color_jitter:
    brightness: 0.2
    contrast: 0.2
    saturation: 0.2
    hue: 0.05
  blur:
    p: 0.1
    kernel: 3
```
## 🧪 Reproducing experiments (all provided configs)

Below are ready-to-run commands for each config. Replace `--checkpoint` with your own file when evaluating or running inference.

### FLAIR-HUB

```bash
# Train (no extra augmentations)
python main.py train --config configs/config_FlairHub.yaml

# Train (with augmentations)
python main.py train --config configs/config_FlairHub_aug.yaml

# Eval
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth
```

### URUR

```bash
python main.py train --config configs/config_URUR.yaml
python main.py train --config configs/config_URUR_aug.yaml
python main.py eval --config configs/config_URUR.yaml --checkpoint weights/checkpoint.pth
```

### ISIC (medical)

```bash
python main.py train --config configs/config_ISIC_aug.yaml
python main.py eval --config configs/config_ISIC_aug.yaml --checkpoint weights/checkpoint.pth
```

### CRAG (medical)

```bash
python main.py train --config configs/config_CRAG_aug.yaml
python main.py eval --config configs/config_CRAG_aug.yaml --checkpoint weights/checkpoint.pth
```
### Swiss Fine-Tuning

#### SwissTLM3D-only fine-tuning

This is the simplest adaptation setup: start from the FLAIR-HUB SegFormer checkpoint and fine-tune on the prepared Swiss patches only.

```bash
python main.py train --config configs/config_SWISSTLM3D_segformer_ft.yaml
```

#### Mixed FLAIR-HUB + SwissTLM3D fine-tuning

This is the recommended adaptation strategy:

- keep FLAIR-HUB as the strong semantic supervision source
- inject SwissTLM3D patches as domain-adaptation supervision
- validate separately on FLAIR-HUB and SwissTLM3D
- select the best checkpoint using the configured mixed validation score

Before training, update the Swiss path in `configs/config_mixed_domain_ft_segformer.yaml` to point to your geographic split of prepared patches, e.g.:

```yaml
paths:
  swiss_path: /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024_split
```

Then launch:

```bash
torchrun --nproc_per_node=4 --master_port=29501 \
  train/train_mixed_domain_ft.py \
  configs/config_mixed_domain_ft_segformer.yaml
```
The mixed script logs:

- `val_flair_mIoU`, `val_flair_mF1`
- `val_swiss_mIoU`, `val_swiss_mF1`
- `model_selection_score`

The default mixed strategy uses:

```yaml
swiss_ratio: 0.30
best_on: mean
```

This means the best checkpoint is selected using the mean of the FLAIR-HUB and SwissTLM3D validation mIoU.
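The `best_on: mean` selection rule boils down to a one-line score; the sketch below (function name and numbers illustrative, not the repository's actual API) shows why it prevents a checkpoint from winning by overfitting a single domain:

```python
def model_selection_score(val_flair_miou, val_swiss_miou):
    """best_on: mean -- both domains contribute equally to checkpoint selection."""
    return (val_flair_miou + val_swiss_miou) / 2

# Illustrative values: a large gain on one domain cannot offset
# an equally large regression on the other.
score = model_selection_score(66.4, 52.0)
print(score)
```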
## 📊 Results

### FLAIR-HUB (RGB-only UHR protocol)

| Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ | GFLOPs ↓ | FPS ↑ |
|---|---|---|---|---|---|
| **RGB baselines (official)** | | | | | |
| Swin-T + UPerNet | 62.01 | 75.27 | — | 237 | 69.2 |
| Swin-S + UPerNet | 61.87 | 75.11 | — | 261 | 41.5 |
| Swin-B + UPerNet | 64.05 | 76.88 | — | 306 | 36.3 |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | — | — |
| Swin-L + UPerNet | 63.36 | 76.35 | — | 420 | 27.8 |
| **Dual-branch baselines** | | | | | |
| ISDNet (trained with PyTorch implementation) | 52.77 | — | — | — | — |
| Dual Swin-Base (late fusion, no CA) | 64.25 | — | — | 398 | 19.4 |
| CASWiT-Base + UPerNet | 65.11 | 77.71 | 35.87 | 489 | 15.4 |
| CASWiT-Base-SSL + UPerNet | 65.35 | 77.87 | 35.99 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + UPerNet | 65.83 | 78.22 | 36.90 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + SegFormer | 66.37 | 78.58 | 36.51 | 298 | 17.9 |

CASWiT-Base already improves over the retrained Swin-B + UPerNet baseline, and CASWiT-Base-SSL-aug + SegFormer further pushes performance to **66.37 mIoU** and **78.58 mF1**.

On mean boundary IoU, CASWiT-Base-SSL-aug + SegFormer reaches **36.51 mBIoU**, a **3.94-point** gain over the retrained Swin-B baseline (32.57).
### ISIC / CRAG (test sets)

| Method | ISIC | CRAG |
|---|---|---|
| GPWFormer | 80.7 | 89.9 |
| Boosting Dual-Branch | 83.4 | 90.3 |
| CASWiT-Base-SSL-aug + UPerNet (ours) | 85.4 | 90.3 |
| CASWiT-Base-SSL-aug + SegFormer (ours) | 86.5 | 90.7 |
### URUR

We also evaluate CASWiT on the URUR ultra-high-resolution benchmark, comparing against both generic and UHR-specific segmentation models.

| Model | mIoU (%) ↑ | Mem (MB) ↓ |
|---|---|---|
| **Generic models** | | |
| PSPNet | 32.0 | 5482 |
| ResNet18 + DeepLabv3+ | 33.1 | 5508 |
| STDC | 42.0 | 7617 |
| **UHR models** | | |
| GLNet | 41.2 | 3063 |
| FCtL | 43.1 | 4508 |
| ISDNet | 45.8 | 4920 |
| WSDNet | 46.9 | 4510 |
| Boosting Dual-stream | 48.2 | 3682 |
| CASWiT-Base + UPer | 48.7 | 2996 |
| CASWiT-Base-SSL-aug + UPer | 49.1 | 2996 |
| CASWiT-Base-SSL-aug + SegF | 49.2 | 2878 |

On URUR, CASWiT-Base + UPer already slightly surpasses prior UHR-specific methods, CASWiT-Base-SSL-aug + UPer reaches 49.1 mIoU, and CASWiT-Base-SSL-aug + SegF further improves to 49.2 mIoU while remaining competitive in memory usage.
### Switzerland Adaptation (SwissImage + SwissTLM3D)

We provide the full preparation and fine-tuning pipeline for SwissImage + SwissTLM3D adaptation. The recommended training setup is mixed fine-tuning:

- initialize from `CASWiT-Base-SSL-aug_FLAIRHUB_SF`
- train on a mixture of FLAIR-HUB and SwissTLM3D patches
- keep separate validation metrics for the two domains

#### Mixed Fine-Tuning Results

To be filled once experiments are complete.

| Model | Val FLAIR-HUB mIoU (%) ↑ | Val SwissTLM3D mIoU (%) ↑ | Test FLAIR-HUB mIoU (%) ↑ | Test SwissTLM3D mIoU (%) ↑ |
|---|---|---|---|---|
| CASWiT-Base-SSL-aug + SegFormer + mixed FT | — | — | — | — |
## 🔬 Self-Supervised Learning

CASWiT also supports self-supervised pre-training with SimMIM-style SSL (Simple Masked Image Modeling). We used this configuration to pretrain CASWiT on the entire SWISSIMAGE dataset.
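The core SimMIM idea is to hide random image patches and train the network to reconstruct the pixels of the masked patches only. A minimal numpy sketch of the masking and loss target (patch size, mask ratio, and the zero-filled reconstruction stand-in are illustrative, not the repository's training settings):

```python
import numpy as np

def random_patch_mask(h, w, patch=32, mask_ratio=0.6, rng=None):
    """Boolean (h, w) mask: True = pixel belongs to a masked patch."""
    rng = rng or np.random.default_rng(0)
    gh, gw = h // patch, w // patch
    n_mask = int(gh * gw * mask_ratio)
    flat = np.zeros(gh * gw, dtype=bool)
    flat[rng.choice(gh * gw, size=n_mask, replace=False)] = True
    # upsample the patch-level mask to pixel resolution
    return np.repeat(np.repeat(flat.reshape(gh, gw), patch, 0), patch, 1)

mask = random_patch_mask(512, 512)
image = np.random.rand(512, 512, 3)
masked_input = image * (~mask)[..., None]  # masked patches zeroed before encoding
# SimMIM-style objective: L1 reconstruction loss computed on masked pixels only
recon = np.zeros_like(image)               # stand-in for the decoder's output
loss = np.abs(recon - image)[mask].mean()
```

Because the loss is restricted to masked pixels, the encoder must use the visible context to infer the hidden content, which is what makes the pretext task useful for dense prediction.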
## 🛠️ Project Structure

```
CASWiT/
├── model/
│   ├── CASWiT_upernet.py
│   ├── CASWiT_segformer.py
│   ├── CASWiT_m2f.py
│   ├── CASWiT_fusion_last_stage_add.py
│   └── CASWiT_ssl.py
├── dataset/
│   ├── definition_dataset.py
│   ├── fusion_augment.py
│   ├── download_swissimage.py
│   └── prepareFlairHub.py
├── configs/
│   ├── config_FlairHub.yaml
│   ├── config_FlairHub_aug.yaml
│   ├── config_URUR.yaml
│   ├── config_URUR_aug.yaml
│   ├── config_SWISSIMAGE.yaml
│   ├── config_SWISSIMAGE_inf.yaml
│   ├── config_ISIC_aug.yaml
│   └── config_CRAG_aug.yaml
├── utils/
│   ├── metrics.py
│   └── attention_viz.py
├── train/
│   ├── train.py
│   ├── eval.py
│   ├── inference.py
│   └── inference_vrt_ddp.py
├── weights/
├── Dockerfile
├── main.py
├── requirements.txt
└── README.md
```
## 📄 Citation

```bibtex
@misc{caswit,
  title={Context-Aware Semantic Segmentation via Stage-Wise Attention},
  author={Antoine Carreaud and Elias Naha and Arthur Chansel and Nina Lahellec and Jan Skaloud and Adrien Gressin},
  year={2026},
  eprint={2601.11310},
  url={https://arxiv.org/abs/2601.11310},
}
```
## 📜 License

This project is licensed under the MIT License. See the LICENSE file for details.
## 🙏 Acknowledgments

- UPerNet for the base segmentation architecture
- Swin Transformer for the backbone
- FLAIR-HUB for the dataset
- URUR for the dataset
- CRAG for the dataset
- ISIC for the dataset
- swisstopo for making SwissImage and SwissTLM3D openly available
