U-Mamba_Enc on bruniss's Dataset059 (Scroll surface prediction)
This repository hosts a U-Mamba_Enc checkpoint trained for surface (papyrus sheet) prediction in carbonized Vesuvius scroll volumes, motivated by ScrollPrize/villa#191. It is a baseline from the architecture comparison documented in https://github.com/ciscoriordan/mednext-vs-umamba-scroll, not the production winner.
For the production model that addresses #191 (+0.27 high_compressed IoU vs d058 on held-out S1 cubes), see the companion Spec 1.5 checkpoint (MedNeXt-L kernel5 + SkeletonRecall).
Status: training stopped at plateau (epoch 319)
| metric | value |
|---|---|
| Best validation pseudo-dice | 0.5883 |
| Best epoch | 263 |
| Total epochs trained | 319 (out of 1000 default) |
| Mean validation dice (last 30 epochs) | 0.5651 |
| Wall time | ~10.5 hours |
| Hardware | Vast.ai A100 SXM4 80GB |
| Cost | ~$11.55 |
Training was stopped early when both plateau criteria were met simultaneously (a minimal sketch of the rule follows the list):
- 50+ epochs since the last best-validation-dice checkpoint (55 epochs had elapsed since the epoch-263 best)
- Mean of last 30 epochs' validation dice ≤ 0.57 (was 0.5651)
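As a sketch, the stopping rule amounts to the following (the helper and its names are hypothetical, not taken from the training code):

```python
def should_stop(val_dices, best_epoch, patience=50, window=30, dice_floor=0.57):
    # Hypothetical plateau rule: stop only when BOTH criteria hold
    current_epoch = len(val_dices) - 1
    stale = (current_epoch - best_epoch) >= patience     # 50+ epochs since the best
    recent = val_dices[-window:]
    plateaued = sum(recent) / len(recent) <= dice_floor  # 30-epoch mean <= 0.57
    return stale and plateaued
```

With this run's numbers (55+ epochs past the epoch-263 best, 30-epoch mean 0.5651), both conditions held.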
Files
- `checkpoint_best.pth` (327 MB) — best validation-dice checkpoint, epoch 263
- `plans.json` — nnUNet plans file (architecture config, normalization, target spacing)
- `dataset.json` — dataset descriptor from preprocessing
- `training_log.txt` — per-epoch loss / val_loss / pseudo-dice trace through all 319 epochs
- `progress.png` — loss-curve plot from nnUNet's built-in plotter
Why U-Mamba_Enc
bruniss's existing d058 surface model (nnUNet ResEnc-L, 3³ kernels) underperforms in compressed and highly curved scroll regions, where adjacent papyrus sheets get merged into a single predicted surface. Issue #191 calls for help on this failure mode.
U-Mamba_Enc (Ma et al., 2024) places Mamba state-space-model blocks at every encoder stage of an nnUNet-style 3D U-Net. A single Mamba block can integrate context across the entire flattened spatial sequence in linear time, which is in theory exactly what is needed when disambiguating adjacent sheets requires evidence from millimeters away.
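To make that concrete, here is a minimal sketch of the flatten-mix-reshape pattern using mamba-ssm's Mamba module (an illustration of the idea, not the exact UMambaEnc block; mamba-ssm requires CUDA):

```python
import torch
from mamba_ssm import Mamba  # mamba-ssm 2.3.1

# A 3D feature map as produced by a convolutional encoder stage
feat = torch.randn(2, 64, 16, 16, 16, device="cuda")   # (B, C, D, H, W)
b, c, d, h, w = feat.shape

mamba = Mamba(d_model=c).cuda()                         # linear-time sequence mixer
tokens = feat.flatten(2).transpose(1, 2)                # (B, D*H*W, C) token sequence
tokens = mamba(tokens)                                  # global context in one pass
feat = tokens.transpose(1, 2).reshape(b, c, d, h, w)    # back to (B, C, D, H, W)
```

Because the cost is linear in sequence length, this stays tractable at the shallow, high-resolution encoder stages where quadratic attention would not.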
Training data
bruniss's Dataset059_s1_s4_s5_patches_frangiedt — 1754 volume/label pairs from Scrolls 1, 4, 5 with Frangi vesselness + distance-transform-enhanced surface labels. Same training config as bruniss's d058:
- 3d_fullres, batch 2, 128³ patches, fold 0 (1403 train / 351 val)
- 6-level U-Net, 32 base features, Mamba blocks at every encoder stage
- Dice + cross-entropy loss
- SGD with Nesterov momentum, initial LR 0.01, PolyLR schedule (sketched after this list)
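The PolyLR decay is simple enough to state inline. A sketch (exponent 0.9 is nnUNet's default):

```python
def poly_lr(epoch, initial_lr=0.01, max_epochs=1000, exponent=0.9):
    # nnUNet-style polynomial decay from initial_lr toward 0 at max_epochs
    return initial_lr * (1 - epoch / max_epochs) ** exponent

poly_lr(263)  # ~0.0076 at this run's best epoch
```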
Loading
```python
import json
import torch

from nnunetv2.training.nnUNetTrainer.nnUNetTrainerUMambaEnc import nnUNetTrainerUMambaEnc
from nnunetv2.nets.UMambaEnc_3d import get_umamba_enc_3d_from_plans

# Load the weights and the nnUNet plans describing the architecture
state = torch.load("checkpoint_best.pth", map_location="cpu")
with open("plans.json") as f:
    plans = json.load(f)

# Build the U-Mamba_Enc network from the 3d_fullres configuration
network = get_umamba_enc_3d_from_plans(
    plans["configurations"]["3d_fullres"],
    num_input_channels=1,
    num_classes=2,
)
network.load_state_dict(state["network_weights"])
network.eval().cuda()
```
For sliding-window inference on a full scroll cube, point nnUNetv2_predict at a model directory containing this checkpoint (one way to stage that directory is sketched below); see the parent training repo for inference scripts.
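nnUNetv2_predict expects an nnUNet v2 results-style model folder. A sketch of staging the shipped files into that layout (the `<Trainer>__<Plans>__<configuration>` folder name follows the standard nnUNet v2 convention and is an assumption, since this repo ships loose files):

```python
from pathlib import Path
import shutil

# Assumed standard nnUNet v2 results layout for this dataset/trainer
model_dir = Path("nnUNet_results/Dataset059_s1_s4_s5_patches_frangiedt/"
                 "nnUNetTrainerUMambaEnc__nnUNetPlans__3d_fullres")
(model_dir / "fold_0").mkdir(parents=True, exist_ok=True)

shutil.copy("plans.json", model_dir / "plans.json")
shutil.copy("dataset.json", model_dir / "dataset.json")
# nnUNetv2_predict loads checkpoint_final.pth by default; either select
# checkpoint_best.pth via its -chk flag or copy it under that name
shutil.copy("checkpoint_best.pth", model_dir / "fold_0" / "checkpoint_best.pth")
```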
Dependencies
- torch 2.5.1+cu124 (NOT cu130 — see gotcha #4 in the parent repo's `setup-recipe.md`)
- mamba-ssm 2.3.1, causal-conv1d 1.6.1 (built via `pip install --no-build-isolation`)
- U-Mamba's bundled nnunetv2 fork (2.1.1)
Inputs and outputs
- Input: 3D CT cube, single-channel uint8, isotropic 1×1×1 voxel spacing. Patch size 128³.
- Output: 2-class softmax (background, fiber/surface). Take the argmax for a binary mask, or threshold the foreground probability at 0.5 (minimal sketch after this list).
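A minimal single-patch forward pass tying the contract together (assumes `network` from the Loading section; the z-score normalization is an assumption, so verify it against the scheme recorded in plans.json):

```python
import numpy as np
import torch

# One 128^3 uint8 patch from a scroll volume (random stand-in here)
patch = np.random.randint(0, 256, (128, 128, 128), dtype=np.uint8)

# Assumed z-score normalization; check plans.json for the real scheme
x = patch.astype(np.float32)
x = (x - x.mean()) / (x.std() + 1e-8)
x = torch.from_numpy(x)[None, None].cuda()   # (1, 1, 128, 128, 128)

with torch.no_grad():
    logits = network(x)
    if isinstance(logits, (list, tuple)):    # deep supervision returns a list
        logits = logits[0]                   # highest-resolution output
    probs = torch.softmax(logits, dim=1)     # (1, 2, 128, 128, 128)

mask = probs.argmax(dim=1)                   # binary surface mask
# or: mask = probs[:, 1] > 0.5               # threshold foreground probability
```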
Honest caveat
This checkpoint plateaued at a training-time pseudo-dice (0.5883) below the level the companion MedNeXt-L kernel5 + plain Dice+CE run reached (0.6017). On the held-out high_compressed IoU benchmark (the real headline number for issue #191), U-Mamba_Enc + plain Dice+CE was not directly evaluated; the architecture-only experiment on Dataset059 favored MedNeXt-L for this task.
For the actual production model that addresses #191, see the companion Spec 1.5 (MedNeXt-L kernel5 + SkeletonRecall) checkpoint, which reaches 0.671 mean high_compressed IoU on truly-held-out Scroll-1 cubes versus d058's 0.404.
This U-Mamba_Enc checkpoint is published for transparency and as a starting point for anyone wanting to test U-Mamba combined with topology-aware losses (clDice, SkeletonRecall), a combination we did not test here.
License
MIT for code/weights. Underlying training data (Dataset059) carries bruniss's license; we don't redistribute it.
Citation
```bibtex
@article{ma2024umamba,
  title={U-Mamba: Enhancing long-range dependency for biomedical image segmentation},
  author={Ma, Jun and Li, Feifei and Wang, Bo},
  journal={arXiv preprint arXiv:2401.04722},
  year={2024}
}
```
This work also depends on the publicly published bruniss Dataset059, the nnUNet framework, and the open Vesuvius Challenge data infrastructure.