---
license: mit
tags:
  - image-segmentation
  - multilabel
  - unet
  - pytorch
  - medical-imaging
library_name: transformers
pipeline_tag: image-segmentation
---

# unet_windowz

An attention U-Net (`attention_unet`) model for multilabel image segmentation, trained with a sliding-window approach.

## Model Description

- Architecture: `attention_unet`
- Input Channels: 3
- Output Classes: 4
- Base Filters: 32
- Window Size: 128
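"Multilabel" means each of the 4 output channels is an independent yes/no decision per pixel, so a pixel can belong to several classes at once, or to none. The sketch below assumes (this card does not confirm it) that every channel goes through its own sigmoid, which is consistent with `probs` lying in [0, 1] and being thresholded per class at 0.5 in the Usage section. The function `multilabel_masks` and the toy logits are hypothetical, for illustration only.

```python
import numpy as np

def multilabel_masks(logits: np.ndarray, threshold: float = 0.5):
    """Turn per-class logits of shape (C, H, W) into independent per-class
    probabilities and binary masks (one sigmoid per channel, no softmax)."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # element-wise sigmoid
    masks = (probs > threshold).astype(np.uint8)
    return probs, masks

# Toy logits for 4 classes on a 2x2 image (illustrative values only).
logits = np.array([
    [[3.0, -3.0], [3.0, -3.0]],    # class 0
    [[2.0, 2.0], [-2.0, -2.0]],    # class 1
    [[-1.0, -1.0], [-1.0, -1.0]],  # class 2
    [[0.5, -0.5], [0.5, -0.5]],    # class 3
])
probs, masks = multilabel_masks(logits)
print(masks[:, 0, 0])  # [1 1 0 1]: pixel (0, 0) is in classes 0, 1 and 3 at once
print(masks[:, 1, 1])  # [0 0 0 0]: pixel (1, 1) is in no class
```

A softmax over the class axis would force exactly one label per pixel, which is why a per-channel sigmoid is the usual choice for multilabel segmentation.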


## Training Configuration

| Parameter | Value |
|---|---|
| Batch Size | 24 |
| Learning Rate | 0.0006966596027953339 |
| Weight Decay | 0.0032033857192098666 |
| Epochs | 100 |
| Patience | 10 |
| Dataset | `GleghornLab/Semi-Automated_LN_Segmentation_10_11_2025` |
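The card lists a maximum of 100 epochs with a patience of 10. Assuming "patience" refers to early stopping (the card does not say which quantity is monitored; validation loss is used here as a stand-in), training halts once 10 consecutive epochs pass without improvement. The `EarlyStopping` class below is a hypothetical sketch of that rule, not the project's training code, and the validation losses are made-up numbers.

```python
class EarlyStopping:
    """Hypothetical early-stopping helper: stop when the monitored
    validation score has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 10, mode: str = "min", min_delta: float = 0.0):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode            # "min" for losses, "max" for scores like Dice
        self.min_delta = min_delta  # smallest change that counts as improvement
        self.best = None
        self.bad_epochs = 0

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value: float) -> bool:
        """Record one epoch's score; return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

MAX_EPOCHS = 100
stopper = EarlyStopping(patience=10, mode="min")
val_losses = [1.0, 0.8, 0.7] + [0.75] * 20  # made-up numbers for illustration
stopped_at = None
for epoch, loss in enumerate(val_losses[:MAX_EPOCHS]):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at)  # 12: ten epochs in a row without beating 0.7
```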

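## Performance Metrics

| Metric | Mean | Class 0 | Class 1 | Class 2 | Class 3 |
|---|---|---|---|---|---|
| Dice | 0.7068 | 0.7111 | 0.3810 | 0.7933 | 0.9417 |
| IoU | 0.6057 | 0.5602 | 0.3079 | 0.6629 | 0.8919 |
| F1 | 0.7068 | 0.7111 | 0.3810 | 0.7933 | 0.9417 |
| MCC | 0.7055 | 0.7222 | 0.3770 | 0.7938 | 0.9289 |
| ROC AUC | 0.9239 | 0.9917 | 0.7119 | 0.9955 | 0.9963 |
| PR AUC | 0.7900 | 0.8115 | 0.4304 | 0.9336 | 0.9844 |

The Mean column is the unweighted average of the four per-class values. Dice and F1 are identical because, for binary masks, the Dice coefficient and the F1 score are the same quantity. Class 1 is the weakest class on every metric.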
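The card does not state how these metrics were aggregated (per image, per window, or over all pooled pixels) or which threshold produced the masks. As a hedged sketch, the function below computes per-class Dice, IoU and MCC from pooled confusion counts on binary masks of shape (C, H, W), plus their unweighted means; `per_class_metrics` is a hypothetical name and its pooled aggregation is an assumption, so it may not reproduce the table exactly.

```python
import numpy as np

def per_class_metrics(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7):
    """Per-class Dice, IoU and MCC for binary masks of shape (C, H, W),
    pooling all pixels of each class; also returns unweighted means."""
    pred = pred.astype(bool).reshape(pred.shape[0], -1)
    target = target.astype(bool).reshape(target.shape[0], -1)

    # Confusion counts per class, as floats to avoid integer overflow.
    tp = (pred & target).sum(axis=1).astype(np.float64)
    fp = (pred & ~target).sum(axis=1).astype(np.float64)
    fn = (~pred & target).sum(axis=1).astype(np.float64)
    tn = (~pred & ~target).sum(axis=1).astype(np.float64)

    dice = 2 * tp / (2 * tp + fp + fn + eps)  # same as F1 for binary masks
    iou = tp / (tp + fp + fn + eps)
    mcc = (tp * tn - fp * fn) / (
        np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) + eps
    )

    metrics = {"dice": dice, "iou": iou, "mcc": mcc}
    # Macro means: unweighted average over classes, like the Mean column.
    metrics.update({f"mean_{k}": float(v.mean()) for k, v in list(metrics.items())})
    return metrics

# Two classes on a 2x2 image: class 0 half right, class 1 exactly right.
pred = np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]])
target = np.array([[[1, 1], [0, 0]], [[1, 1], [0, 0]]])
m = per_class_metrics(pred, target)
print(m["dice"].round(3), round(m["mean_dice"], 3))
```

ROC AUC and PR AUC need the continuous probabilities rather than thresholded masks; scikit-learn's `roc_auc_score` and `average_precision_score` are common choices (average precision is one estimator of PR AUC).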

## Usage

```python
import numpy as np
import torch

# MODEL_REGISTRY and SegmentationConfig come from the project's `model` module
from model import MODEL_REGISTRY, SegmentationConfig

# Load the configuration and the pretrained weights from the Hub
config = SegmentationConfig.from_pretrained("lhallee/unet_windowz")
model = MODEL_REGISTRY["attention_unet"].from_pretrained("lhallee/unet_windowz")
model.eval()

# Run sliding-window inference on a full-size image of shape (H, W, 3)
image = np.random.rand(2048, 2048, 3).astype(np.float32)  # replace with your image
probs = model.predict_full_image(
    image,
    dim=128,  # window size, matching the training window of 128
    batch_size=16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# probs shape: (num_classes, H, W) with values in [0, 1]

# Threshold each class independently to get binary masks
masks = (probs > 0.5).astype(np.uint8)
```
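`predict_full_image` is the project's own method and its internals are not documented on this card. As a hedged illustration of what sliding-window inference generally involves, the sketch below tiles an (H, W, C) image into `dim` × `dim` windows, runs a caller-supplied batch predictor on them, and averages overlapping outputs into a (C_out, H, W) map. The names `window_starts`, `sliding_window_predict` and `predict_batch`, the `stride` parameter, the zero padding and the overlap averaging are all assumptions, not necessarily what `predict_full_image` does.

```python
import numpy as np

def window_starts(length: int, dim: int, stride: int) -> list[int]:
    """Start offsets whose windows of size `dim` cover [0, length)."""
    starts = list(range(0, length - dim + 1, stride))
    if starts[-1] != length - dim:
        starts.append(length - dim)  # last window sits flush with the edge
    return starts

def sliding_window_predict(image, predict_batch, dim=128, stride=64, batch_size=16):
    """Hypothetical sketch: tile an (H, W, C_in) image, predict each tile with
    `predict_batch` ((N, dim, dim, C_in) -> (N, C_out, dim, dim)), and average
    overlapping predictions into a (C_out, H, W) array."""
    assert 0 < stride <= dim, "stride > dim would leave uncovered pixels"
    h, w = image.shape[:2]
    # Zero-pad images smaller than one window on the bottom/right edges.
    padded = np.pad(image, ((0, max(dim - h, 0)), (0, max(dim - w, 0)), (0, 0)))
    H, W = padded.shape[:2]

    coords = [(y, x) for y in window_starts(H, dim, stride)
              for x in window_starts(W, dim, stride)]
    acc = None
    counts = np.zeros((H, W), dtype=np.float32)  # predictions per pixel

    for i in range(0, len(coords), batch_size):
        chunk = coords[i:i + batch_size]
        tiles = np.stack([padded[y:y + dim, x:x + dim] for y, x in chunk])
        preds = predict_batch(tiles)
        if acc is None:
            acc = np.zeros((preds.shape[1], H, W), dtype=np.float32)
        for (y, x), p in zip(chunk, preds):
            acc[:, y:y + dim, x:x + dim] += p
            counts[y:y + dim, x:x + dim] += 1

    return (acc / counts)[:, :h, :w]  # average overlaps, drop the padding

# Sanity check with an identity "model" that returns its input channel-first.
img = np.random.rand(300, 200, 3).astype(np.float32)
out = sliding_window_predict(img, lambda b: b.transpose(0, 3, 1, 2), dim=128, stride=64)
print(out.shape, np.allclose(out, img.transpose(2, 0, 1)))
```

With `stride < dim` every pixel is predicted by several windows and the results are averaged, which tends to soften seams at tile borders; `stride == dim` gives non-overlapping tiles and fewer forward passes.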

## Training Plots

Panels: Training Loss, Dice Curves, MCC Curves, Best Validation.

## Citation

If you use this model, please cite:

```bibtex
@software{windowz_segmentation,
  title={Multilabel Image Segmentation with Sliding Window U-Net},
  author={Gleghorn Lab},
  year={2025},
  url={https://github.com/GleghornLab/ComputerVision2}
}
```