Image Segmentation

CASWiT: Context-Aware Stage Wise Transformer for Ultra-High Resolution Semantic Segmentation

License: MIT · SOTA: FLAIR-HUB (RGB) · SOTA: URUR

Official implementation of CASWiT, a dual-branch architecture for ultra-high resolution semantic segmentation that leverages cross-attention fusion between high-resolution and low-resolution branches.


🎯 Overview

CASWiT addresses the challenge of semantic segmentation on ultra-high resolution images by introducing a dual-branch architecture:

  • HR Branch: Processes high-resolution crops (512×512) to capture fine-grained detail
  • LR Branch: Processes a low-resolution view (downsampled by 2×) to capture the surrounding context
  • Cross-Attention Fusion: Lets HR features attend to LR context at each encoder stage

This design allows the model to capture both local details and global context, leading to improved segmentation performance on large-scale datasets.
In particular, CASWiT achieves 65.83 mIoU on the FLAIR-HUB RGB-only UHR benchmark and 49.1 mIoU on URUR, outperforming prior RGB/UHR baselines while remaining memory-efficient.
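To make the HR/LR pairing concrete, here is a small sketch of how an HR crop and its matching LR context window could be computed. The function name and the image size are hypothetical, not taken from the repo; the repo's own data pipeline may differ:

```python
def lr_context_window(x0, y0, crop=512, context_scale=2, img_w=4096, img_h=4096):
    """Given the top-left corner of a 512x512 HR crop, return the bounds of a
    larger window (crop * context_scale per side) centered on the crop and
    clamped to the image. Downsampling that window by `context_scale` yields
    an LR context patch with the same spatial size as the HR crop."""
    ctx = crop * context_scale
    cx, cy = x0 + crop // 2, y0 + crop // 2
    left = min(max(cx - ctx // 2, 0), img_w - ctx)
    top = min(max(cy - ctx // 2, 0), img_h - ctx)
    return left, top, left + ctx, top + ctx

# Example: a crop in the top-left corner of a 4096x4096 image
print(lr_context_window(0, 0))  # (0, 0, 1024, 1024)
```

Clamping at the image border keeps the LR window fully inside the image, so corner crops are off-center within their context rather than padded.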

πŸ—οΈ Architecture

CASWiT architecture

Key components:

  • Dual Swin Transformer Backbones: Two UPerNet-Swin encoders process HR and LR streams
  • Cross-Attention Fusion Blocks: Multi-head cross-attention at each encoder stage
  • Auxiliary LR Supervision: Additional supervision on LR branch for better training
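Concretely, at each encoder stage the HR stream supplies the queries while the LR stream supplies the keys and values; in standard (per-head) cross-attention notation, with the exact projections defined in the paper:

$$\mathrm{CrossAttn}(F_{\mathrm{HR}}, F_{\mathrm{LR}}) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V, \qquad Q = F_{\mathrm{HR}} W_Q,\; K = F_{\mathrm{LR}} W_K,\; V = F_{\mathrm{LR}} W_V$$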

📦 Installation

Requirements

  • Python 3.12+
  • PyTorch 2+
  • CUDA 12+ (for GPU training)

Setup

  1. Clone the repository:
git clone https://huggingface.co/heig-vd-geo/CASWiT
cd CASWiT
  2. Install dependencies:
pip install -r requirements.txt

📊 Dataset Preparation

FLAIR-HUB

FLAIR-HUB is a large-scale ultra-high-resolution semantic segmentation dataset. To prepare the dataset:

  1. Download the FlairHub dataset
  2. Run the preparation script to merge tiles:
python dataset/prepareFlairHub.py

The script will merge GeoTIFF tiles into larger mosaics suitable for training.
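The offset arithmetic behind tile merging can be sketched in pure Python. The filename convention (`name_ROW_COL.tif`) is an illustrative assumption, not the repo's; the real script reads positions from the GeoTIFF geo-transforms:

```python
from pathlib import Path

def mosaic_shape(tile_paths, tile_size=512):
    """Infer merged-mosaic pixel dimensions from tile filenames of the
    hypothetical form 'name_ROW_COL.tif' (illustrative only; the actual
    prepareFlairHub.py script uses georeferencing metadata instead)."""
    rows, cols = set(), set()
    for p in tile_paths:
        _, r, c = Path(p).stem.rsplit("_", 2)  # split off ROW and COL
        rows.add(int(r))
        cols.add(int(c))
    # Mosaic height spans max row index + 1 tiles; likewise for width.
    return (max(rows) + 1) * tile_size, (max(cols) + 1) * tile_size

print(mosaic_shape(["t_0_0.tif", "t_0_1.tif", "t_1_0.tif", "t_1_1.tif"]))
# (1024, 1024)
```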

URUR

URUR dataset should be organized as:

URUR/
├── train/
│   ├── image/
│   └── label/
├── val/
│   ├── image/
│   └── label/
└── test/
    ├── image/
    └── label/

A re-hosted copy of the original URUR dataset is available on Hugging Face for easier access and download by the research community: URUR Dataset.

SWISSIMAGE

For SWISSIMAGE dataset:

  1. Download images using the provided CSV file:
python dataset/download_swissimage.py list_all_swiss_image_sept2025.csv

🚀 Usage

Training

Train CASWiT on FlairHub:

python train/train.py configs/config_FlairHub.yaml

For distributed training (multi-GPU):

torchrun --nproc_per_node=4 train/train.py configs/config_FlairHub.yaml

Evaluation

Evaluate a trained model (on the test set):

python train/eval.py configs/config_FlairHub.yaml weights/checkpoint.pth test

Inference

Run inference on a single image:

python train/inference.py configs/config_FlairHub.yaml weights/checkpoint.pth image.tif output.png

Using Main Entry Point

Alternatively, use the unified main script:

# Training
python main.py train --config configs/config_FlairHub.yaml

# Evaluation (on test set)
python main.py eval --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth

# Inference on a single image
python main.py inference --config configs/config_FlairHub.yaml --checkpoint weights/checkpoint.pth --image image.tif --output pred.png

⚙️ Configuration

Configuration files are in YAML format. Example structure:

paths:
  data_path: "/path/to/dataset"
  dataset_name: ""
  train_img_subdir: "train/img"
  train_msk_subdir: "train/msk"
  val_img_subdir: "val/img"
  val_msk_subdir: "val/msk"
  test_img_subdir: "test/img"
  test_msk_subdir: "test/msk"
  save_dir: "weights"
  pretrained_path: ""

model:
  model_name: "openmmlab/upernet-swin-base"  # or swin-tiny, swin-large
  num_classes: 15
  cross_attention_heads: 1
  fusion_mlp_ratio: 4.0
  fusion_drop_path: 0.1
  lr_supervision_weight: 0.5

training:
  batch_size: 4
  num_workers: 8
  num_epochs: 20
  learning_rate: 0.00006
  amp: true
  seed: 42
  eta_min: 0.000001

wandb:
  use_wandb: true
  project: "Fusion_HRLR"
  entity: "your_entity"
  run_name: "caswit_experiment"

Key Parameters

  • model_name: Swin variant (upernet-swin-tiny, upernet-swin-base, upernet-swin-large)
  • cross_attention_heads: Number of attention heads in cross-attention blocks
  • lr_supervision_weight: Weight for LR branch auxiliary supervision
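Once the YAML above is parsed into a dict (e.g. with `yaml.safe_load`), a minimal sanity check can catch missing keys early. This validator is an illustrative sketch, not the repo's config loader:

```python
def validate_config(cfg):
    """Check that a parsed config dict contains the key parameters listed
    above; returns a list of problems (empty list means the config is ok)."""
    required = {
        "model": ["model_name", "num_classes", "cross_attention_heads"],
        "training": ["batch_size", "learning_rate", "num_epochs"],
    }
    problems = []
    for section, keys in required.items():
        for key in keys:
            if key not in cfg.get(section, {}):
                problems.append(f"{section}.{key} missing")
    if not problems and cfg["model"]["cross_attention_heads"] < 1:
        problems.append("model.cross_attention_heads must be >= 1")
    return problems
```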

📈 Results

FLAIR-HUB (RGB-only UHR protocol)

We first evaluate CASWiT on the FLAIR-HUB ultra-high-resolution aerial benchmark under the RGB-only UHR protocol.

| Model | mIoU (%) ↑ | mF1 (%) ↑ | mBIoU (%) ↑ |
|---|---|---|---|
| *RGB Baselines (official FLAIR-HUB)* | | | |
| Swin-T + UPerNet | 62.01 | 75.27 | – |
| Swin-S + UPerNet | 61.87 | 75.11 | – |
| Swin-B + UPerNet | 64.05 | 76.88 | – |
| Swin-B + UPerNet (retrained) | 64.02 | 76.64 | 32.57 |
| Swin-L + UPerNet | 63.36 | 76.35 | – |
| *Ours (RGB-only UHR protocol)* | | | |
| CASWiT-Base | 65.11 | 77.71 | 35.87 |
| CASWiT-Base-SSL | 65.35 | 77.87 | 35.99 |
| CASWiT-Base-SSL-aug | 65.83 | 78.22 | 36.90 |

CASWiT-Base already improves over the retrained Swin-B + UPerNet baseline, and CASWiT-Base-SSL-aug further pushes performance to 65.83 mIoU and 78.22 mF1.
On mean boundary IoU, CASWiT-Base-SSL-aug reaches 36.90 mBIoU, which is a +4.33 mBIoU gain over the retrained Swin-B baseline (32.57).


URUR

We also evaluate CASWiT on the URUR ultra-high-resolution benchmark, comparing to both generic and UHR-specific segmentation models.

| Model | mIoU (%) ↑ | Mem (MB) ↓ |
|---|---|---|
| *Generic Models* | | |
| PSPNet | 32.0 | 5482 |
| ResNet18 + DeepLabv3+ | 33.1 | 5508 |
| STDC | 42.0 | 7617 |
| *UHR Models* | | |
| GLNet | 41.2 | 3063 |
| FCLt | 43.1 | 4508 |
| ISDNet | 45.8 | 4920 |
| WSDNet | 46.9 | 4510 |
| Boosting Dual-branch | 48.2 | 3682 |
| CASWiT-Base | 48.7 | 3530 |
| CASWiT-Base-SSL | 49.1 | 3530 |

On URUR, CASWiT-Base already matches and slightly surpasses prior UHR-specific methods, and CASWiT-Base-SSL achieves 49.1 mIoU, i.e. +2.2 mIoU over WSDNet and +0.9 mIoU over Boosting Dual-branch (UHRS), while remaining competitive in memory usage.

🔬 Self-Supervised Learning

CASWiT also supports self-supervised pre-training with SimMIM-style masked image modeling:

from model.CASWiT_ssl import CASWiT_SSL

model_ssl = CASWiT_SSL(
    model_name="openmmlab/upernet-swin-base",
    mask_ratio_hr=0.75,
    mask_ratio_lr=0.5
)
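The `mask_ratio_hr` / `mask_ratio_lr` parameters control what fraction of patches is hidden during pre-training. A minimal sketch of SimMIM-style random patch masking (illustrative, not the repo's implementation):

```python
import random

def random_patch_mask(num_patches, mask_ratio, seed=0):
    """Return a flat SimMIM-style patch mask: 1 = masked, 0 = visible."""
    rng = random.Random(seed)
    n_mask = int(num_patches * mask_ratio)
    mask = [0] * num_patches
    for i in rng.sample(range(num_patches), n_mask):
        mask[i] = 1
    return mask

# 256 patches with the HR mask ratio from the snippet above
m = random_patch_mask(256, 0.75)
print(sum(m) / len(m))  # 0.75
```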

🛠️ Project Structure

CASWiT/
├── model/
│   ├── CASWiT.py          # Main model architecture
│   └── CASWiT_ssl.py      # SSL variant
├── dataset/
│   ├── definition_dataset.py
│   ├── fusion_augment.py
│   ├── download_swissimage.py
│   └── prepareFlairHub.py
├── configs/
│   ├── config_FlairHub.yaml
│   ├── config_FlairHub_aug.yaml
│   ├── config_URUR.yaml
│   ├── config_URUR_aug.yaml
│   └── config_SWISSIMAGE.yaml
├── utils/
│   ├── metrics.py
│   ├── logging.py
│   └── attention_viz.py
├── train/
│   ├── train.py
│   ├── eval.py
│   └── inference.py
├── weights/               # Model checkpoints
├── main.py
├── requirements.txt
├── list_all_swiss_image_sept2025.csv
└── README.md

📝 Citation

If you use CASWiT in your research, please cite:

@misc{caswit,
      title={Context-Aware Semantic Segmentation via Stage-Wise Attention}, 
      author={Antoine Carreaud and Elias Naha and Arthur Chansel and Nina Lahellec and Jan Skaloud and Adrien Gressin},
      year={2026},
      eprint={2601.11310},
      url={https://arxiv.org/abs/2601.11310}, 
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

🙏 Acknowledgments

📧 Contact

For questions and issues, please open an issue on this repo.

