---
license: mit
datasets:
  - IGNF/FLAIR-HUB
  - heig-vd-geo/URUR
language:
  - en
metrics:
  - mean_iou
pipeline_tag: image-segmentation
---

CASWiT: Context-Aware Stage-Wise Transformer for Ultra-High Resolution Semantic Segmentation

Context-Aware Semantic Segmentation via Stage-Wise Attention

License: MIT Β· SOTA: FLAIR-HUB @ RGB Β· SOTA: URUR Β· SOTA: ISIC Β· SOTA: CRAG

Official implementation of CASWiT, a dual-branch architecture for ultra-high resolution semantic segmentation that leverages stage-wise cross-attention fusion between high-resolution and low-resolution branches.


🎯 Overview

CASWiT addresses semantic segmentation on ultra-high resolution (UHR) imagery with a dual-resolution design:

  • HR Branch: processes high-resolution crops (e.g., 512Γ—512) for fine details
  • LR Branch: processes low-resolution context (typically downsampled) for global information
  • Stage-wise Cross-Attention Fusion: HR features attend to LR context at each encoder stage

CASWiT generalizes beyond aerial imagery without architectural changes, and we provide configs/scripts for several UHR datasets (FLAIR-HUB, URUR, SWISSIMAGE inference, and additional medical benchmarks).

πŸ—οΈ Architecture

CASWiT architecture

Key components:

  • Dual Swin Transformer Backbones
  • Cross-Attention Fusion Blocks at each encoder stage
  • Auxiliary LR Supervision (optional, weighted by model.lr_supervision_weight)
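The HR→LR attention at the heart of the fusion blocks can be illustrated with a minimal single-head NumPy sketch, where HR tokens act as queries and LR tokens as keys/values. The function name, shapes, and residual connection below are illustrative assumptions, not the repository's actual API:

```python
import numpy as np

def cross_attention_fusion(hr_tokens, lr_tokens):
    """HR tokens (queries) attend to LR context tokens (keys/values).

    hr_tokens: (N_hr, C) stage features from the high-resolution branch
    lr_tokens: (N_lr, C) stage features from the low-resolution branch
    Returns fused HR features, shape (N_hr, C).
    """
    d = hr_tokens.shape[-1]
    logits = hr_tokens @ lr_tokens.T / np.sqrt(d)   # (N_hr, N_lr)
    logits -= logits.max(axis=-1, keepdims=True)    # numerical stability
    attn = np.exp(logits)
    attn /= attn.sum(axis=-1, keepdims=True)        # softmax over LR tokens
    context = attn @ lr_tokens                      # LR context per HR token
    return hr_tokens + context                      # residual fusion

hr = np.random.randn(64, 32)   # e.g. an 8x8 grid of HR tokens, C=32
lr = np.random.randn(16, 32)   # e.g. a 4x4 grid of LR context tokens
print(cross_attention_fusion(hr, lr).shape)  # (64, 32)
```

In the actual model this happens once per encoder stage with multi-head attention (`model.cross_attention_heads`); the sketch keeps a single head for clarity.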

πŸ“¦ Installation

Requirements

  • Python 3.12+
  • PyTorch 2+
  • CUDA 12+ (for GPU training)

Setup

git clone https://huggingface.co/heig-vd-geo/CASWiT
cd CASWiT
pip install -r requirements.txt

🐳 Docker

A Dockerfile is provided for a reproducible environment.

Build

docker build -t caswit:latest .

Run (GPU)

docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  caswit:latest

If your datasets/checkpoints are outside the repo, mount them too, e.g.:

docker run --gpus all -it --rm \
  -v $(pwd):/workspace \
  -v /path/to/data:/data \
  -v /path/to/checkpoints:/checkpoints \
  caswit:latest

πŸ“Š Dataset Preparation

FLAIR-HUB

  1. Download the FLAIR-HUB dataset
  2. Merge GeoTIFF tiles into mosaics:
python dataset/prepareFlairHub.py

URUR

Expected structure:

URUR/
β”œβ”€β”€ train/
β”‚   β”œβ”€β”€ image/
β”‚   └── label/
β”œβ”€β”€ val/
β”‚   β”œβ”€β”€ image/
β”‚   └── label/
└── test/
    β”œβ”€β”€ image/
    └── label/

A re-hosted copy is available here: https://huggingface.co/datasets/heig-vd-geo/URUR

SWISSIMAGE

Download images using the provided CSV:

python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv

SwissImage + SwissTLM3D Adaptation Pipeline

The repository also includes a full adaptation pipeline for Switzerland using:

  • SwissImage as imagery source
  • SwissTLM3D as coarse / partial supervision
  • FLAIR-HUB as the semantic anchor for mixed fine-tuning

1. Download SwissImage

python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv

2. Download SwissTLM3D

The official SwissTLM3D GeoPackage can be downloaded from swisstopo:

mkdir -p /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d
curl -L \
  https://data.geo.admin.ch/ch.swisstopo.swisstlm3d/swisstlm3d_2026-02-24/swisstlm3d_2026-02-24_2056_5728.gpkg.zip \
  -o /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/swisstlm3d_2026-02-24_2056_5728.gpkg.zip

unzip -o \
  /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/swisstlm3d_2026-02-24_2056_5728.gpkg.zip \
  -d /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d

3. Rasterize SwissTLM3D masks on SwissImage

This creates masks in the FLAIR-HUB label space with 255 for unsupported / ignored classes.

python dataset/rasterize_swisstlm3d.py \
  --gpkg /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/SWISSTLM3D_2026_LV95_LN02.gpkg \
  --image_dir /mnt/CalcShare/datasets/Suisse_full/data/image \
  --output_dir /mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d \
  --workers 8
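The key detail of this step is the remapping into the FLAIR-HUB label space, with 255 reserved for classes that have no counterpart. A minimal sketch of that remapping, using a made-up class mapping purely for illustration (the real SwissTLM3D → FLAIR-HUB correspondence lives in the script):

```python
import numpy as np

# Hypothetical mapping from SwissTLM3D object classes to FLAIR-HUB label IDs.
# Classes without a FLAIR-HUB counterpart fall back to the ignore index 255.
IGNORE_INDEX = 255
TLM3D_TO_FLAIR = {1: 0, 2: 3, 3: 7}  # illustrative IDs only

def remap_mask(tlm3d_mask):
    """Remap a rasterized SwissTLM3D mask into the FLAIR-HUB label space."""
    out = np.full(tlm3d_mask.shape, IGNORE_INDEX, dtype=np.uint8)
    for src, dst in TLM3D_TO_FLAIR.items():
        out[tlm3d_mask == src] = dst
    return out

mask = np.array([[1, 2], [3, 9]])   # class 9 has no mapping -> ignored
print(remap_mask(mask))
# [[  0   3]
#  [  7 255]]
```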

4. Create 2048 -> 1024 training patches

SwissImage has a coarser GSD than FLAIR-HUB in our setup. We therefore extract 2048x2048 windows and resize them to 1024x1024 before training.

python dataset/prepare_swisstlm3d_patches.py \
  --image_dir /mnt/CalcShare/datasets/Suisse_full/data/image \
  --mask_dir /mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d \
  --output_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024 \
  --drop_all_ignore \
  --workers 48
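The window-then-resize step can be sketched as below, assuming 2x average pooling for the image and nearest-neighbour striding for the mask so that class IDs and the 255 ignore value are never interpolated (the actual script may use a different resampling method):

```python
import numpy as np

def downsample_patch(image, mask):
    """Resize a 2048x2048 training window to 1024x1024.

    The image is 2x average-pooled; the label mask uses nearest-neighbour
    (plain striding) so class IDs and the 255 ignore value stay exact.
    """
    h, w = image.shape[:2]
    img_small = image.reshape(h // 2, 2, w // 2, 2, -1).mean(axis=(1, 3))
    msk_small = mask[::2, ::2]
    return img_small, msk_small

img = np.random.rand(2048, 2048, 3).astype(np.float32)
msk = np.random.randint(0, 15, (2048, 2048), dtype=np.uint8)
small_img, small_msk = downsample_patch(img, msk)
print(small_img.shape, small_msk.shape)  # (1024, 1024, 3) (1024, 1024)
```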

5. Create a geographic split

This split is deterministic and avoids spatial leakage by splitting along contiguous geographic bands.

python dataset/split_swisstlm3d_geographic.py \
  --image_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024/img \
  --mask_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024/msk \
  --output_dir /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024_split \
  --overwrite
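A band-based split of this kind can be sketched as follows, with hypothetical band counts and coordinates (the script's actual banding parameters may differ). Because the split depends only on each patch's easting, it is deterministic and keeps neighbouring patches in the same subset:

```python
def geographic_split(tiles, n_bands=10, val_bands={8}, test_bands={9}):
    """Deterministic split along contiguous west-east bands.

    tiles: list of (name, easting) pairs; the band index is derived from
    the easting, so nearby patches always land in the same split.
    """
    xs = [x for _, x in tiles]
    x_min, x_max = min(xs), max(xs)
    width = (x_max - x_min) / n_bands or 1.0
    split = {"train": [], "val": [], "test": []}
    for name, x in tiles:
        band = min(int((x - x_min) / width), n_bands - 1)
        if band in test_bands:
            split["test"].append(name)
        elif band in val_bands:
            split["val"].append(name)
        else:
            split["train"].append(name)
    return split

tiles = [(f"tile_{i}", 2_480_000 + 5_000 * i) for i in range(20)]
split = geographic_split(tiles)
print(len(split["train"]), len(split["val"]), len(split["test"]))  # 16 2 2
```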

πŸš€ Usage

All commands below work either directly with the scripts in train/ or via the unified main.py.

Training

Single GPU:

python train/train.py configs/config_FlairHub.yaml

Multi-GPU (DDP):

torchrun --nproc_per_node=4 train/train.py configs/config_FlairHub.yaml

Evaluation

python train/eval.py configs/config_FlairHub.yaml weights/checkpoint.pth test

Inference (single image)

python train/inference.py configs/config_FlairHub.yaml weights/checkpoint.pth image.tif output.png

Inference on a VRT (DDP)

This runs tiled inference on a VRT using multiple GPUs:

torchrun --nproc_per_node=5 --master_port=29501 -m train.inference_vrt_ddp \
  --config configs/config_SWISSIMAGE_inf.yaml \
  --checkpoint weights/CASWiT-Base-SSL_FLAIRHUB_UN.pth \
  --vrt file.vrt \
  --out_dir output/ \
  --tile 1024 \
  --stride 512 \
  --lr_side 2048
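Conceptually, this corresponds to a sliding window over the mosaic with logit averaging across the --tile/--stride overlap. A minimal single-process NumPy sketch, assuming the image is at least one tile on each side (the DDP script's exact blending and LR-context handling are not reproduced here):

```python
import numpy as np

def tiled_inference(image, predict, tile=1024, stride=512, n_classes=15):
    """Sliding-window inference with overlap averaging.

    predict(patch) -> (n_classes, tile, tile) logits; overlapping tile
    logits are accumulated and normalised by a per-pixel hit count.
    Assumes image height/width >= tile.
    """
    H, W = image.shape[:2]
    logits = np.zeros((n_classes, H, W), dtype=np.float32)
    hits = np.zeros((H, W), dtype=np.float32)
    ys = list(range(0, H - tile + 1, stride))
    xs = list(range(0, W - tile + 1, stride))
    # make sure the bottom/right borders are always covered
    if ys[-1] != H - tile: ys.append(H - tile)
    if xs[-1] != W - tile: xs.append(W - tile)
    for y in ys:
        for x in xs:
            patch = image[y:y + tile, x:x + tile]
            logits[:, y:y + tile, x:x + tile] += predict(patch)
            hits[y:y + tile, x:x + tile] += 1
    return (logits / hits).argmax(axis=0)

img = np.random.rand(2048, 2048)
pred = tiled_inference(img, lambda p: np.random.rand(15, *p.shape))
print(pred.shape)  # (2048, 2048)
```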

Using main.py

# Train
python main.py train --config configs/config_FlairHub.yaml

# Eval (the script defaults to the test split if none is specified)
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

# Inference (single image)
python main.py inference --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth --image image.tif --output pred.png

βš™οΈ Configuration

Configs are YAML files in configs/.

Model selection (decoder head)

You can select the decoder head directly in the config:

model:
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet

Notes:

  • mask2former may require training.amp: false depending on your environment/precision settings.
  • ssl is the SimMIM-style pretraining model and is not a drop-in replacement for segmentation inference/eval scripts.

Example config structure

paths:
  data_path: "/path/to/dataset"
  dataset_name: "FLAIRHUB"
  train_img_subdir: "train/img"
  train_msk_subdir: "train/msk"
  val_img_subdir: "valid/img"
  val_msk_subdir: "valid/msk"
  test_img_subdir: "test/img"
  test_msk_subdir: "test/msk"
  save_dir: "weights"
  pretrained_path: ""

model:
  model_name: "openmmlab/upernet-swin-base"
  num_classes: 15
  cross_attention_heads: 1
  ignore_index: 255
  fusion_mlp_ratio: 4.0
  fusion_drop_path: 0.1
  lr_supervision_weight: 0.5
  # options: upernet | segformer | mask2former | fusion_last_stage_add | ssl
  head: upernet

training:
  batch_size: 4
  num_workers: 8
  num_epochs: 20
  learning_rate: 0.00006
  amp: true
  seed: 42
  eta_min: 0.000001

wandb:
  use_wandb: true
  project: "CASWiT"
  entity: "your-entity"
  run_name: "caswit_experiment"

augmentations:
  enable: false
  p_hflip: 0.5
  p_vflip: 0.5
  p_rot90: 0.5
  color_jitter:
    brightness: 0.2
    contrast: 0.2
    saturation: 0.2
    hue: 0.05
  blur:
    p: 0.1
    kernel: 3
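The geometric augmentations above must be applied jointly to image and mask so labels stay aligned. A minimal NumPy sketch using the config's probabilities (the function name is illustrative, not the repository's actual API):

```python
import numpy as np

def geometric_augment(img, msk, rng, p_hflip=0.5, p_vflip=0.5, p_rot90=0.5):
    """Apply the config's flip/rot90 augmentations jointly to image and mask.

    The same random draws drive both tensors, so labels stay aligned.
    """
    if rng.random() < p_hflip:
        img, msk = img[:, ::-1], msk[:, ::-1]
    if rng.random() < p_vflip:
        img, msk = img[::-1], msk[::-1]
    if rng.random() < p_rot90:
        k = rng.integers(1, 4)  # 90, 180 or 270 degrees
        img, msk = np.rot90(img, k), np.rot90(msk, k)
    return np.ascontiguousarray(img), np.ascontiguousarray(msk)

rng = np.random.default_rng(0)
img = np.random.rand(512, 512, 3)
msk = np.random.randint(0, 15, (512, 512))
aug_img, aug_msk = geometric_augment(img, msk, rng)
print(aug_img.shape, aug_msk.shape)  # (512, 512, 3) (512, 512)
```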

πŸ§ͺ Reproducing experiments (all provided configs)

Below are ready-to-run commands for each config. Replace --checkpoint with your own file when evaluating or running inference.

FLAIR-HUB

# Train (no extra aug)
python main.py train --config configs/config_FlairHub.yaml

# Train (with augmentations)
python main.py train --config configs/config_FlairHub_aug.yaml

# Eval
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

URUR

python main.py train --config configs/config_URUR.yaml
python main.py train --config configs/config_URUR_aug.yaml

python main.py eval --config configs/config_URUR.yaml --checkpoint weights/checkpoint.pth

ISIC (medical)

python main.py train --config configs/config_ISIC_aug.yaml
python main.py eval  --config configs/config_ISIC_aug.yaml --checkpoint weights/checkpoint.pth

CRAG (medical)

python main.py train --config configs/config_CRAG_aug.yaml
python main.py eval  --config configs/config_CRAG_aug.yaml --checkpoint weights/checkpoint.pth

Swiss Fine-Tuning

SwissTLM3D-only fine-tuning

This is the simplest adaptation setup: start from the FLAIR-HUB SegFormer checkpoint and fine-tune on the prepared Swiss patches only.

python main.py train --config configs/config_SWISSTLM3D_segformer_ft.yaml

Mixed FLAIR-HUB + SwissTLM3D fine-tuning

This is the recommended strategy for adaptation:

  • keep FLAIR-HUB as the strong semantic supervision source
  • inject SwissTLM3D patches as domain adaptation supervision
  • validate separately on FLAIR-HUB and SwissTLM3D
  • select the best checkpoint using the configured mixed validation score

Before training, update the Swiss path in configs/config_mixed_domain_ft_segformer.yaml to point to your geographic split of prepared patches, e.g.:

paths:
  swiss_path: /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024_split

Then launch:

torchrun --nproc_per_node=4 --master_port=29501 \
  train/train_mixed_domain_ft.py \
  configs/config_mixed_domain_ft_segformer.yaml

The mixed script logs:

  • val_flair_mIoU, val_flair_mF1
  • val_swiss_mIoU, val_swiss_mF1
  • model_selection_score

The default mixed strategy uses:

  • swiss_ratio: 0.30
  • best_on: mean

This means the best checkpoint is selected using the mean of FLAIR-HUB and SwissTLM3D validation mIoU.
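A minimal sketch of that selection rule (the alternative best_on values shown are assumptions for illustration, not necessarily the script's full option set):

```python
def model_selection_score(val_flair_miou, val_swiss_miou, best_on="mean"):
    """Checkpoint-selection score for mixed fine-tuning.

    With best_on="mean" (the default strategy), the score is the mean of
    the FLAIR-HUB and SwissTLM3D validation mIoU.
    """
    if best_on == "mean":
        return 0.5 * (val_flair_miou + val_swiss_miou)
    if best_on == "flair":   # hypothetical single-domain alternatives
        return val_flair_miou
    if best_on == "swiss":
        return val_swiss_miou
    raise ValueError(f"unknown best_on: {best_on}")

print(model_selection_score(65.0, 55.0))  # 60.0
```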

πŸ“ˆ Results

FLAIR-HUB (RGB-only UHR protocol)

| Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ | GFLOPs ↓ | FPS ↑ |
|---|---|---|---|---|---|
| **RGB Baselines (official)** | | | | | |
| Swin-T + UPerNet | 62.01 | 75.27 | – | 237 | 69.2 |
| Swin-S + UPerNet | 61.87 | 75.11 | – | 261 | 41.5 |
| Swin-B + UPerNet | 64.05 | 76.88 | – | 306 | 36.3 |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 | – | – |
| Swin-L + UPerNet | 63.36 | 76.35 | – | 420 | 27.8 |
| **Dual-branch baselines** | | | | | |
| ISDNet (retrained, PyTorch implementation) | 52.77 | – | – | – | – |
| Dual Swin-Base (late fusion, no CA) | 64.25 | – | – | 398 | 19.4 |
| CASWiT-Base + UPerNet | 65.11 | 77.71 | 35.87 | 489 | 15.4 |
| CASWiT-Base-SSL + UPerNet | 65.35 | 77.87 | 35.99 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + UPerNet | 65.83 | 78.22 | 36.90 | 489 | 15.4 |
| CASWiT-Base-SSL-aug + SegFormer | 66.37 | 78.58 | 36.51 | 298 | 17.9 |

CASWiT-Base already improves over the retrained Swin-B + UPerNet baseline, and CASWiT-Base-SSL-aug + SegFormer pushes performance further to 66.37 mIoU and 78.58 mF1.
On mean boundary IoU, CASWiT-Base-SSL-aug + SegFormer reaches 36.51 mBIoU, a gain of 3.94 points over the retrained Swin-B baseline (32.57).

ISIC / CRAG (test sets)

| Method | ISIC | CRAG |
|---|---|---|
| GPWFormer | 80.7 | 89.9 |
| Boosting Dual-Branch | 83.4 | 90.3 |
| CASWiT-Base-SSL-aug + UPerNet (ours) | 85.4 | 90.3 |
| CASWiT-Base-SSL-aug + SegFormer (ours) | 86.5 | 90.7 |

URUR

We also evaluate CASWiT on the URUR ultra-high-resolution benchmark, comparing to both generic and UHR-specific segmentation models.

| Model | mIoU (%) ↑ | Mem (MB) ↓ |
|---|---|---|
| **Generic Models** | | |
| PSPNet | 32.0 | 5482 |
| ResNet18 + DeepLabv3+ | 33.1 | 5508 |
| STDC | 42.0 | 7617 |
| **UHR Models** | | |
| GLNet | 41.2 | 3063 |
| FCLt | 43.1 | 4508 |
| ISDNet | 45.8 | 4920 |
| WSDNet | 46.9 | 4510 |
| Boosting Dual-stream | 48.2 | 3682 |
| CASWiT-Base + UPer | 48.7 | 2996 |
| CASWiT-Base-SSL-aug + UPer | 49.1 | 2996 |
| CASWiT-Base-SSL-aug + SegF | 49.2 | 2878 |

On URUR, CASWiT-Base + UPer already matches and slightly surpasses prior UHR-specific methods; CASWiT-Base-SSL-aug + UPer reaches 49.1 mIoU, and CASWiT-Base-SSL-aug + SegF further improves to 49.2 mIoU while remaining competitive in memory usage.

Switzerland Adaptation (SwissImage + SwissTLM3D)

We provide the full preparation and fine-tuning pipeline for SwissImage + SwissTLM3D adaptation. The recommended training setup is mixed fine-tuning:

  • initialize from CASWiT-Base-SSL-aug_FLAIRHUB_SF
  • train on a mixture of FLAIR-HUB and SwissTLM3D patches
  • keep separate validation metrics for the two domains

Mixed Fine-Tuning Results

To be filled once experiments are complete.

| Model | Val FLAIR-HUB mIoU (%) ↑ | Val SwissTLM3D mIoU (%) ↑ | Test FLAIR-HUB mIoU (%) ↑ | Test SwissTLM3D mIoU (%) ↑ |
|---|---|---|---|---|
| CASWiT-Base-SSL-aug + SegFormer + mixed FT | – | – | – | – |

πŸ”¬ Self-Supervised Learning

CASWiT also supports self-supervised pre-training using SimMIM-style SSL (Simple Masked Image Modeling). We used this configuration on the entire SWISSIMAGE dataset to pretrain CASWiT.
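Conceptually, SimMIM-style pretraining masks random patches of the input and penalises reconstruction error only on the masked pixels. A minimal NumPy sketch of that objective (the patch size and mask ratio here are illustrative, not the repository's pretraining settings):

```python
import numpy as np

def simmim_loss(image, reconstruction, rng, patch=32, mask_ratio=0.6):
    """SimMIM-style objective: L1 reconstruction loss on masked patches only.

    image, reconstruction: (H, W, C) arrays with H and W divisible by patch;
    a random subset of patch-size blocks is masked and only those pixels
    contribute to the loss.
    """
    H, W, _ = image.shape
    gh, gw = H // patch, W // patch
    keep = rng.random((gh, gw)) < mask_ratio            # True = masked patch
    pixel_mask = np.kron(keep, np.ones((patch, patch), dtype=bool)).astype(bool)
    diff = np.abs(image - reconstruction)[pixel_mask]   # masked pixels only
    return diff.mean() if diff.size else 0.0

rng = np.random.default_rng(0)
img = np.random.rand(256, 256, 3)
rec = np.random.rand(256, 256, 3)
print(round(simmim_loss(img, rec, rng), 3))
```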

πŸ› οΈ Project Structure

CASWiT/
β”œβ”€β”€ model/
β”‚   β”œβ”€β”€ CASWiT_upernet.py
β”‚   β”œβ”€β”€ CASWiT_segformer.py
β”‚   β”œβ”€β”€ CASWiT_m2f.py
β”‚   β”œβ”€β”€ CASWiT_fusion_last_stage_add.py
β”‚   └── CASWiT_ssl.py
β”œβ”€β”€ dataset/
β”‚   β”œβ”€β”€ definition_dataset.py
β”‚   β”œβ”€β”€ fusion_augment.py
β”‚   β”œβ”€β”€ download_swissimage.py
β”‚   └── prepareFlairHub.py
β”œβ”€β”€ configs/
β”‚   β”œβ”€β”€ config_FlairHub.yaml
β”‚   β”œβ”€β”€ config_FlairHub_aug.yaml
β”‚   β”œβ”€β”€ config_URUR.yaml
β”‚   β”œβ”€β”€ config_URUR_aug.yaml
β”‚   β”œβ”€β”€ config_SWISSIMAGE.yaml
β”‚   β”œβ”€β”€ config_SWISSIMAGE_inf.yaml
β”‚   β”œβ”€β”€ config_ISIC_aug.yaml
β”‚   └── config_CRAG_aug.yaml
β”œβ”€β”€ utils/
β”‚   β”œβ”€β”€ metrics.py
β”‚   └── attention_viz.py
β”œβ”€β”€ train/
β”‚   β”œβ”€β”€ train.py
β”‚   β”œβ”€β”€ eval.py
β”‚   β”œβ”€β”€ inference.py
β”‚   └── inference_vrt_ddp.py
β”œβ”€β”€ weights/
β”œβ”€β”€ Dockerfile
β”œβ”€β”€ main.py
β”œβ”€β”€ requirements.txt
└── README.md

πŸ“ Citation

@misc{caswit,
  title={Context-Aware Semantic Segmentation via Stage-Wise Attention},
  author={Antoine Carreaud and Elias Naha and Arthur Chansel and Nina Lahellec and Jan Skaloud and Adrien Gressin},
  year={2026},
  eprint={2601.11310},
  archivePrefix={arXiv},
  url={https://arxiv.org/abs/2601.11310},
}

πŸ“„ License

This project is licensed under the MIT License - see the LICENSE file for details.

πŸ™ Acknowledgments