# CASWiT: Context-Aware Stage-Wise Transformer for Ultra-High Resolution Semantic Segmentation
Official implementation of CASWiT, a dual-branch architecture for ultra-high resolution semantic segmentation that leverages cross-attention fusion between high-resolution and low-resolution branches.
## 🎯 Overview
CASWiT addresses the challenge of semantic segmentation on ultra-high resolution images by introducing a dual-branch architecture:
- HR Branch: Processes high-resolution crops (512×512) for fine-grained detail
- LR Branch: Processes the 2×-downsampled image to provide surrounding context
- Cross-Attention Fusion: Enables HR features to attend to LR context at each encoder stage
This design allows the model to capture both local details and global context, leading to improved segmentation performance on large-scale datasets.
In particular, CASWiT achieves 65.83 mIoU on the FLAIR-HUB RGB-only UHR benchmark and 49.1 mIoU on URUR, outperforming prior RGB/UHR baselines while remaining memory-efficient.
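As a concrete illustration, the two branch inputs can be prepared roughly as follows; the tile size and crop location here are assumptions for the sketch, not the repo's exact pipeline:

```python
import torch
import torch.nn.functional as F

# One UHR tile; size and crop location are illustrative only.
tile = torch.randn(1, 3, 1024, 1024)
y, x = 256, 256
hr_crop = tile[:, :, y:y + 512, x:x + 512]       # 512x512 HR crop (fine detail)
lr_view = F.interpolate(tile, scale_factor=0.5,  # 2x-downsampled LR context
                        mode="bilinear", align_corners=False)
print(tuple(hr_crop.shape), tuple(lr_view.shape))  # (1, 3, 512, 512) (1, 3, 512, 512)
```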
## 🏗️ Architecture
Key components:
- Dual Swin Transformer Backbones: Two UPerNet-Swin encoders process HR and LR streams
- Cross-Attention Fusion Blocks: Multi-head cross-attention at each encoder stage
- Auxiliary LR Supervision: Additional supervision on LR branch for better training
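A minimal sketch of such a stage-wise cross-attention fusion block, assuming HR tokens act as queries over LR keys/values; the module structure, normalization placement, and dimensions are illustrative, not the repo's exact implementation:

```python
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    """Sketch: HR tokens attend to LR tokens at one encoder stage."""
    def __init__(self, dim, num_heads=1, mlp_ratio=4.0):
        super().__init__()
        self.norm_hr = nn.LayerNorm(dim)
        self.norm_lr = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, hr_tokens, lr_tokens):
        # Queries from the HR stream, keys/values from the LR context
        q = self.norm_hr(hr_tokens)
        kv = self.norm_lr(lr_tokens)
        ctx, _ = self.attn(q, kv, kv)
        hr_tokens = hr_tokens + ctx              # residual fusion of LR context
        hr_tokens = hr_tokens + self.mlp(hr_tokens)
        return hr_tokens

# HR has 4x more tokens than LR here, mimicking the resolution gap
fused = CrossAttentionFusion(dim=96)(torch.randn(2, 256, 96), torch.randn(2, 64, 96))
print(tuple(fused.shape))  # (2, 256, 96)
```

Note the output keeps the HR token shape: the LR branch only modulates HR features, so the decoder still operates at full crop resolution.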
## 📦 Installation

### Requirements

- Python 3.12+
- PyTorch 2+
- CUDA 12+ (for GPU training)

### Setup

1. Clone the repository:

```bash
git clone https://huggingface.co/heig-vd-geo/CASWiT
cd CASWiT
```

2. Install dependencies:

```bash
pip install -r requirements.txt
```
## 📊 Dataset Preparation

### FLAIR-HUB

FLAIR-HUB is a large-scale ultra-high resolution semantic segmentation dataset. To prepare the dataset:

1. Download the FLAIR-HUB dataset
2. Run the preparation script to merge tiles:

```bash
python dataset/prepareFlairHub.py
```

The script merges GeoTIFF tiles into larger mosaics suitable for training.
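For intuition, tile-to-mosaic merging amounts to placing each tile at its grid offset in a larger array; plain NumPy arrays stand in for GeoTIFF tiles here, and the 2×2 grid layout is an assumption for the sketch:

```python
import numpy as np

# Four 256x256 "tiles" on a regular 2x2 grid, each filled with its own id
tiles = {(r, c): np.full((256, 256), r * 2 + c) for r in range(2) for c in range(2)}

# Paste each tile at its grid offset to build the 512x512 mosaic
mosaic = np.zeros((512, 512), dtype=tiles[0, 0].dtype)
for (r, c), tile in tiles.items():
    mosaic[r * 256:(r + 1) * 256, c * 256:(c + 1) * 256] = tile

print(mosaic.shape)  # (512, 512)
```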
### URUR

The URUR dataset should be organized as:

```
URUR/
├── train/
│   ├── image/
│   └── label/
├── val/
│   ├── image/
│   └── label/
└── test/
    ├── image/
    └── label/
```

A re-hosted copy of the original URUR dataset is provided on HuggingFace to make it easier for the research community to access and download: URUR Dataset.
### SWISSIMAGE

For the SWISSIMAGE dataset, download the images using the provided CSV file:

```bash
python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv
```
## 🚀 Usage

### Training

Train CASWiT on FLAIR-HUB:

```bash
python train/train.py configs/config_FlairHub.yaml
```

For distributed (multi-GPU) training:

```bash
torchrun --nproc_per_node=4 train/train.py configs/config_FlairHub.yaml
```

### Evaluation

Evaluate a trained model on the test set:

```bash
python train/eval.py configs/config_FlairHub.yaml weights/checkpoint.pth test
```

### Inference

Run inference on a single image:

```bash
python train/inference.py configs/config_FlairHub.yaml weights/checkpoint.pth image.tif output.png
```
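Since a whole UHR image rarely fits in GPU memory at once, inference of this kind is typically tiled. Below is a hedged sketch of sliding-window prediction with overlap averaging; the window size, stride, and model callable are assumptions, not the repo's exact code:

```python
import torch

def sliding_window_predict(model, image, num_classes, win=512, stride=256):
    """Accumulate per-window logits over the full image, averaging overlaps."""
    _, _, H, W = image.shape
    logits = torch.zeros(1, num_classes, H, W)
    counts = torch.zeros(1, 1, H, W)
    for y in range(0, max(H - win, 0) + 1, stride):
        for x in range(0, max(W - win, 0) + 1, stride):
            patch = image[:, :, y:y + win, x:x + win]
            logits[:, :, y:y + win, x:x + win] += model(patch)
            counts[:, :, y:y + win, x:x + win] += 1
    return logits / counts.clamp(min=1)

# Dummy "model" returning per-pixel logits, just to exercise the tiling loop
dummy = lambda p: torch.zeros(1, 15, p.shape[-2], p.shape[-1])
out = sliding_window_predict(dummy, torch.randn(1, 3, 1024, 1024), num_classes=15)
print(tuple(out.shape))  # (1, 15, 1024, 1024)
```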
### Using the Main Entry Point

Alternatively, use the unified main script:

```bash
# Training
python main.py train --config configs/config_FlairHub.yaml

# Evaluation (on the test set)
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

# Inference on a single image
python main.py inference --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth --image image.tif --output pred.png
```
## ⚙️ Configuration

Configuration files are in YAML format. Example structure:

```yaml
paths:
  data_path: "/path/to/dataset"
  dataset_name: ""
  train_img_subdir: "train/img"
  train_msk_subdir: "train/msk"
  val_img_subdir: "val/img"
  val_msk_subdir: "val/msk"
  test_img_subdir: "test/img"
  test_msk_subdir: "test/msk"
  save_dir: "weights"
  pretrained_path: ""

model:
  model_name: "openmmlab/upernet-swin-base"  # or swin-tiny, swin-large
  num_classes: 15
  cross_attention_heads: 1
  fusion_mlp_ratio: 4.0
  fusion_drop_path: 0.1
  lr_supervision_weight: 0.5

training:
  batch_size: 4
  num_workers: 8
  num_epochs: 20
  learning_rate: 0.00006
  amp: true
  seed: 42
  eta_min: 0.000001

wandb:
  use_wandb: true
  project: "Fusion_HRLR"
  entity: "your_entity"
  run_name: "caswit_experiment"
```
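Loading such a config is straightforward with PyYAML; the snippet below uses an inline string mirroring a few of the keys above (the repo's own loader may differ):

```python
import yaml  # PyYAML

cfg_text = """
model:
  model_name: openmmlab/upernet-swin-base
  num_classes: 15
training:
  batch_size: 4
  learning_rate: 0.00006
"""
cfg = yaml.safe_load(cfg_text)
print(cfg["model"]["num_classes"], cfg["training"]["learning_rate"])  # 15 6e-05
```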
### Key Parameters

- `model_name`: Swin variant (`upernet-swin-tiny`, `upernet-swin-base`, `upernet-swin-large`)
- `cross_attention_heads`: number of attention heads in the cross-attention blocks
- `lr_supervision_weight`: weight of the LR-branch auxiliary supervision
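The effect of `lr_supervision_weight` can be sketched as a weighted sum of the main HR loss and the LR-branch auxiliary loss; the shapes and the choice of cross-entropy here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

hr_logits = torch.randn(2, 15, 64, 64)   # main (HR) segmentation head
lr_logits = torch.randn(2, 15, 32, 32)   # auxiliary (LR) head
hr_target = torch.randint(0, 15, (2, 64, 64))
lr_target = torch.randint(0, 15, (2, 32, 32))

w = 0.5  # lr_supervision_weight from the config
loss = F.cross_entropy(hr_logits, hr_target) + w * F.cross_entropy(lr_logits, lr_target)
print(float(loss) > 0)  # True
```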
## 📈 Results

### FLAIR-HUB (RGB-only UHR protocol)

We first evaluate CASWiT on the FLAIR-HUB ultra-high-resolution aerial benchmark under the RGB-only UHR protocol.

| Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ |
|---|---|---|---|
| **RGB Baselines (official FLAIR-HUB)** | | | |
| Swin-T + UPerNet | 62.01 | 75.27 | – |
| Swin-S + UPerNet | 61.87 | 75.11 | – |
| Swin-B + UPerNet | 64.05 | 76.88 | – |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 |
| Swin-L + UPerNet | 63.36 | 76.35 | – |
| **Ours (RGB-only UHR protocol)** | | | |
| CASWiT-Base | 65.11 | 77.71 | 35.87 |
| CASWiT-Base-SSL | 65.35 | 77.87 | 35.99 |
| CASWiT-Base-SSL-aug | 65.83 | 78.22 | 36.90 |
CASWiT-Base already improves over the retrained Swin-B + UPerNet baseline, and CASWiT-Base-SSL-aug further pushes performance to 65.83 mIoU and 78.22 mF1.
On mean boundary IoU, CASWiT-Base-SSL-aug reaches 36.90 mBIoU, which is a +4.33 mBIoU gain over the retrained Swin-B baseline (32.57).
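For reference, the mIoU reported above is the per-class intersection-over-union computed from a confusion matrix and averaged over classes; a small self-contained illustration:

```python
import numpy as np

def miou(conf):
    """Mean IoU from a (num_classes x num_classes) confusion matrix."""
    inter = np.diag(conf).astype(float)
    union = conf.sum(axis=0) + conf.sum(axis=1) - inter
    return float(np.mean(inter / np.maximum(union, 1)))

# Toy 2-class confusion matrix: rows = ground truth, cols = prediction
conf = np.array([[8, 2],
                 [1, 9]])
print(round(miou(conf), 4))  # 0.7386
```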
### URUR

We also evaluate CASWiT on the URUR ultra-high-resolution benchmark, comparing against both generic and UHR-specific segmentation models.

| Model | mIoU (%) ↑ | Mem (MB) ↓ |
|---|---|---|
| **Generic Models** | | |
| PSPNet | 32.0 | 5482 |
| ResNet18 + DeepLabv3+ | 33.1 | 5508 |
| STDC | 42.0 | 7617 |
| **UHR Models** | | |
| GLNet | 41.2 | 3063 |
| FCtL | 43.1 | 4508 |
| ISDNet | 45.8 | 4920 |
| WSDNet | 46.9 | 4510 |
| Boosting Dual-branch | 48.2 | 3682 |
| CASWiT-Base | 48.7 | 3530 |
| CASWiT-Base-SSL | 49.1 | 3530 |
On URUR, CASWiT-Base already matches and slightly surpasses prior UHR-specific methods, and CASWiT-Base-SSL achieves 49.1 mIoU, i.e. +2.2 mIoU over WSDNet and +0.9 mIoU over Boosting Dual-branch (UHRS), while remaining competitive in memory usage.
## 🔬 Self-Supervised Learning

CASWiT also supports SimMIM-style self-supervised pre-training (simple masked image modeling):

```python
from model.CASWiT_ssl import CASWiT_SSL

model_ssl = CASWiT_SSL(
    model_name="openmmlab/upernet-swin-base",
    mask_ratio_hr=0.75,
    mask_ratio_lr=0.5,
)
```
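SimMIM-style pre-training hides a random fraction of patch tokens and trains the model to reconstruct the masked content. The sketch below shows only the mask-sampling step; the patch count is chosen arbitrarily, while the ratios mirror the constructor call above:

```python
import torch

def random_mask(num_patches, mask_ratio):
    """Boolean mask marking a random mask_ratio fraction of patches as hidden."""
    n_mask = int(num_patches * mask_ratio)
    perm = torch.randperm(num_patches)
    mask = torch.zeros(num_patches, dtype=torch.bool)
    mask[perm[:n_mask]] = True
    return mask

m_hr = random_mask(16 * 16, 0.75)  # mask_ratio_hr: 75% of HR patches hidden
m_lr = random_mask(16 * 16, 0.5)   # mask_ratio_lr: 50% of LR patches hidden
print(int(m_hr.sum()), int(m_lr.sum()))  # 192 128
```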
## 🛠️ Project Structure

```
CASWiT/
├── model/
│   ├── CASWiT.py          # Main model architecture
│   └── CASWiT_ssl.py      # SSL variant
├── dataset/
│   ├── definition_dataset.py
│   ├── fusion_augment.py
│   ├── download_swissimage.py
│   └── prepareFlairHub.py
├── configs/
│   ├── config_FlairHub.yaml
│   ├── config_FlairHub_aug.yaml
│   ├── config_URUR.yaml
│   ├── config_URUR_aug.yaml
│   └── config_SWISSIMAGE.yaml
├── utils/
│   ├── metrics.py
│   ├── logging.py
│   └── attention_viz.py
├── train/
│   ├── train.py
│   ├── eval.py
│   └── inference.py
├── weights/               # Model checkpoints
├── main.py
├── requirements.txt
├── list_all_swiss_image_sept2025.csv
└── README.md
```
## 📄 Citation

If you use CASWiT in your research, please cite:

```bibtex
@misc{caswit,
  title={Context-Aware Semantic Segmentation via Stage-Wise Attention},
  author={Antoine Carreaud and Elias Naha and Arthur Chansel and Nina Lahellec and Jan Skaloud and Adrien Gressin},
  year={2026},
  eprint={2601.11310},
  archivePrefix={arXiv},
  url={https://arxiv.org/abs/2601.11310},
}
```
## 📜 License

This project is licensed under the MIT License - see the LICENSE file for details.
## 🙏 Acknowledgments

- UPerNet for the base segmentation architecture
- Swin Transformer for the backbone
- FLAIR-HUB for the dataset
- URUR for the dataset
## 📧 Contact

For questions and issues, please open an issue on this repository.
