Pressure Sore Cascade Classifier - Torchvision 3-Level

This repository contains eight PyTorch model weight files forming a 3-level hierarchical cascade for automated pressure sore detection and staging, with two models at each decision point. The cascade progressively narrows from detecting a pressure sore, to separating severity groups, to fine-grained stage classification, mirroring clinical decision-making.

These weights are used by ps_classifier_torch_cascade.py in the PS Classifier web application.


Cascade Structure

Image
  │
  ▼
[Level 1: PS vs No-PS]                         BCEWithLogitsLoss · sigmoid
  MaxVit_T  (linear head)
  ResNet50  (mlp head)
  │
  ├─ NO  → "No pressure sore detected"
  └─ YES ▼
[Level 2: Early (Stage I/II) vs Advanced (III/IV)]   BCEWithLogitsLoss · sigmoid
  ConvNeXt_Base      (mlp head)
  EfficientNet_V2_L  (linear head)
  │
  ├─ EARLY ──────────────────────┐
  └─ ADVANCED ───────┐           │
                     ▼           ▼
  [Level 3b]                   [Level 3a]
  Stage III vs Stage IV         Stage I vs Stage II
  ConvNeXt_Large (MSH)          EfficientNet_V2_L (mlp)
  ViT_B_16 (mlp)                ConvNeXt_Tiny (linear)
  CrossEntropyLoss              BCEWithLogitsLoss · sigmoid
  WrappedModel pattern          Direct-attachment pattern
  ↓ Confidence gate 0.65 ↓      ↓ Confidence gate 0.65 ↓
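The routing in the diagram can be sketched in plain Python. This is a hypothetical illustration, not code from ps_classifier_torch_cascade.py: each level is stood in for by a callable returning the ensemble's probability for the positive class, and the 0.5 decision threshold, the positive-class choices and the Level 1/Level 2 label strings are assumptions. Only the 0.65 Level 3 gate and the "Stage I" to "Stage IV" / "Early" / "Advanced" strings come from this card.

```python
from typing import Callable, Dict

GATE = 0.65       # Level 3 confidence gate stated in the model card
THRESHOLD = 0.5   # assumed decision threshold on each level's probability

# Hypothetical signature: image -> P(positive class) from that level's ensemble
Prob = Callable[[object], float]


def _decide(p: float, pos: str, neg: str):
    """Turn P(positive) into (label, confidence of the chosen label)."""
    return (pos, p) if p >= THRESHOLD else (neg, 1.0 - p)


def route(image, level1: Prob, level2: Prob, level3a: Prob, level3b: Prob) -> Dict[str, dict]:
    """Walk the cascade top-down and return a details-style dict."""
    details = {}
    label, conf = _decide(level1(image), "PS", "No PS")         # Level 1
    details["level_1"] = {"label": label, "confidence": conf}
    if label == "No PS":
        return details                                          # cascade stops here

    group, conf = _decide(level2(image), "Advanced", "Early")   # Level 2
    details["level_2"] = {"label": group, "confidence": conf}

    if group == "Advanced":
        label, conf = _decide(level3b(image), "Stage IV", "Stage III")  # Level 3b
    else:
        label, conf = _decide(level3a(image), "Stage II", "Stage I")    # Level 3a
    details["level_3"] = {"label": label, "confidence": conf,
                          "group": group, "gated": conf < GATE}
    return details
```

Because each level only runs when the previous one routes to it, a No-PS image yields a dict containing only level_1.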

Confidence gating: if the Level 3 ensemble confidence falls below 0.65, the prediction is still returned, but it is annotated with an uncertainty warning and flagged for clinical review (details["level_3"]["gated"] == True).
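The gate needs one confidence value per level, and the two loss types produce it differently: BCE-trained heads emit a single logit (sigmoid), while the Level 3b CrossEntropy heads emit two logits (softmax). The sketch below averages probabilities across the two ensemble members; that averaging rule and the function names are assumptions, not the documented behaviour of the cascade script.

```python
import torch

GATE = 0.65  # Level 3 confidence gate from the model card


def bce_ensemble(logits):
    """BCE heads emit one logit each: sigmoid, then average P(positive) over the models."""
    p = torch.sigmoid(torch.stack(logits)).mean().item()
    return p >= 0.5, max(p, 1.0 - p)        # (is_positive, confidence of chosen class)


def ce_ensemble(logits):
    """CE heads (Level 3b) emit two logits each: softmax, then average class probabilities."""
    probs = torch.stack([z.softmax(dim=-1) for z in logits]).mean(dim=0)
    conf, idx = probs.max(dim=-1)
    return idx.item(), conf.item()          # (class index, confidence of chosen class)


def is_gated(confidence: float) -> bool:
    """True when a Level 3 prediction should be flagged for clinical review."""
    return confidence < GATE
```

Each element of `logits` is one model's output for a single image: shape (1, 1) for BCE heads, (1, 2) for CE heads.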


Files

File                                             Level  Architecture       Head              Loss
Level 1 Binary PS or not PS MaxVit_T.pth         L1     MaxVit-T           linear            BCE
Level 1 Binary PS or not PS ResNet50.pth         L1     ResNet-50          mlp               BCE
Level 2 Early vs Advanced ConvNeXt_Base.pth      L2     ConvNeXt-Base      mlp               BCE
Level 2 Early vs Advanced EfficientNet_V2_L.pth  L2     EfficientNet-V2-L  linear            BCE
Level 3a Early EfficientNet_V2_L.pth             L3a    EfficientNet-V2-L  mlp               BCE
Level 3a Early ConvNeXt_Tiny.pth                 L3a    ConvNeXt-Tiny      linear            BCE
Level 3b Advanced ConvNeXt_Large.pth             L3b    ConvNeXt-Large     multi_stage_head  XEnt
Level 3b Advanced ViT_B_16.pth                   L3b    ViT-B/16           mlp               XEnt

BCE = BCEWithLogitsLoss (one logit, sigmoid); XEnt = CrossEntropyLoss (two logits, softmax)

MSH = MultiStageHead (Dropout → FC(in→in/2) → BN → ReLU → FC(in/2→2))
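A PyTorch sketch of the MSH layout described above. The class body, the attribute name `net` and the default dropout (taken from the ConvNeXt_Large row of the Level 3b table) are assumptions; the real checkpoint's key names must match before `load_state_dict` will accept it.

```python
import torch
import torch.nn as nn


class MultiStageHead(nn.Module):
    """Dropout -> FC(in -> in/2) -> BN -> ReLU -> FC(in/2 -> 2), per the MSH definition.

    Attribute names and defaults are illustrative, not the repository's code.
    """

    def __init__(self, in_features: int, dropout: float = 0.6594, num_classes: int = 2):
        super().__init__()
        hidden = in_features // 2
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),  # two logits, consumed by CrossEntropyLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```

For ConvNeXt_Large, `in_features` would be 1536, the width of its pooled features in torchvision.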


Model Performance

Level 1: PS vs No-PS (test set: 261 images)

Model              Head              Dropout  Scheduler          Accuracy  Macro F1  AUC-ROC
MaxVit_T           linear            0.2396   CosineAnnealingLR  0.9962    0.9962    1.0000
ResNet50           mlp               0.5840   CosineAnnealingLR  1.0000    1.0000    1.0000

Level 2: Early vs Advanced (test set: 125 images)

Model              Head              Dropout  Optimizer  Scheduler          Accuracy  Macro F1  AUC-ROC
ConvNeXt_Base      mlp               0.1025   AdamP      CosineAnnealingLR  0.9520    0.9520    0.9857
EfficientNet_V2_L  linear            0.3564   AdamP      CosineAnnealingLR  0.9600    0.9600    0.9916

Level 3a: Stage I vs Stage II (test set: 63 images)

Model              Head              Dropout  Optimizer  Scheduler          Accuracy  Macro F1  AUC-ROC
EfficientNet_V2_L  mlp               0.1949   AdamP      StepLR             0.9048    0.9047    0.9849
ConvNeXt_Tiny      linear            0.1601   Lion       CosineAnnealingLR  0.9683    0.9682    0.9909

Level 3b: Stage III vs Stage IV (test set: 63 images)

Model              Head              Dropout  Optimizer  Scheduler          Accuracy  Macro F1  AUC-ROC
ConvNeXt_Large     multi_stage_head  0.6594   Lion       ReduceLROnPlateau  0.7778    0.7773    0.8861
ViT_B_16           mlp               0.5445   AdamW      ReduceLROnPlateau  0.7937    0.7934    0.8569
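The test sets are small, so each accuracy figure corresponds to a handful of images. Converting the reported accuracies back to counts, assuming accuracy = correct / test-set size, makes this concrete:

```python
# Reported test accuracy and test-set size per model (from the tables above).
REPORTED = {
    "L1 MaxVit_T": (0.9962, 261),
    "L1 ResNet50": (1.0000, 261),
    "L2 ConvNeXt_Base": (0.9520, 125),
    "L2 EfficientNet_V2_L": (0.9600, 125),
    "L3a EfficientNet_V2_L": (0.9048, 63),
    "L3a ConvNeXt_Tiny": (0.9683, 63),
    "L3b ConvNeXt_Large": (0.7778, 63),
    "L3b ViT_B_16": (0.7937, 63),
}


def correct_count(accuracy: float, n: int) -> int:
    """Recover the number of correct predictions, assuming accuracy = correct / n."""
    return round(accuracy * n)


for name, (acc, n) in REPORTED.items():
    k = correct_count(acc, n)
    print(f"{name:<22} {k:>3}/{n:<3} correct, {n - k:>2} misclassified")
```

One image is worth about 0.4 percentage points of accuracy at Level 1 (1/261) and about 1.6 points at Level 3 (1/63), so the gap between the paired models is one image at Level 2 and Level 3b, and four images at Level 3a.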

Stage III vs Stage IV is the hardest sub-task: the visual difference between full-thickness tissue loss with and without exposed bone/muscle is subtle enough to challenge even clinicians. The Level 3b confidence gate (0.65) flags the most uncertain of these predictions.


Training Details

All models were trained on a curated dataset of ~1,000 pressure sore images collected from public medical databases, with stratified 70/20/10 train/validation/test splits.
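A minimal sketch of a stratified 70/20/10 split, assuming a list of images with one label each. The per-class shuffle-and-slice approach, the seed and the rounding rule are illustrative choices; the card does not say which tool produced the actual splits.

```python
import random
from collections import defaultdict


def stratified_split(items, labels, fractions=(0.7, 0.2, 0.1), seed=42):
    """Return (train, val, test) lists in which each class keeps roughly the same proportions."""
    by_class = defaultdict(list)
    for item, label in zip(items, labels):
        by_class[label].append(item)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for label in sorted(by_class):          # sorted for a reproducible class order
        group = by_class[label][:]
        rng.shuffle(group)
        n_train = round(len(group) * fractions[0])
        n_val = round(len(group) * fractions[1])
        train.extend(group[:n_train])
        val.extend(group[n_train:n_train + n_val])
        test.extend(group[n_train + n_val:])  # remainder goes to the test split
    return train, val, test
```

With 600 early and 400 advanced images, this gives 700/200/100 images and every split keeps the 60/40 class ratio.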

Shared configuration:

  • Input: 224 × 224, ImageNet normalisation (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225])
  • Augmentation: random flips, rotation, colour jitter, Gaussian blur, affine transforms (Albumentations)
  • Freeze schedule: backbone frozen for initial epochs, then progressively unfrozen (2-stage)
  • Early stopping: patience 8–10 epochs on validation loss
  • Hyperparameters: selected by Optuna trials (learning rate, weight decay, dropout, head type)
  • Mixed precision: fp16 via Accelerate
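The two-stage freeze schedule can be expressed as a per-epoch toggle of `requires_grad`. This sketch assumes the model exposes `.backbone` and `.head` submodules (as in the WrappedModel layout); `freeze_epochs` is a hypothetical parameter, since the card does not give the number of frozen epochs.

```python
import torch.nn as nn


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Toggle requires_grad for every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad = trainable


def apply_freeze_schedule(model: nn.Module, epoch: int, freeze_epochs: int) -> None:
    """Stage 1 (epoch < freeze_epochs): only the head trains. Stage 2: everything trains."""
    set_trainable(model.backbone, epoch >= freeze_epochs)
    set_trainable(model.head, True)


def count_trainable(model: nn.Module) -> int:
    """Number of parameters that will receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

The optimizer can be built over all parameters from the start: PyTorch optimizers skip parameters whose `.grad` is None, which is the case for frozen ones.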

Architecture notes:

  • L1, L2, L3a: the head is attached directly in the backbone's native classifier slot (model.classifier[2] for ConvNeXt, model.fc for ResNet-50, etc.). The saved state dict has flat keys.
  • L3b: WrappedModel pattern. The backbone's classifier slot is replaced with nn.Identity, and a separate head receives the raw feature embeddings. The saved state dict has backbone.* / head.* key prefixes.

Usage

Installation

pip install torch torchvision albumentations pillow

Minimal inference example

import torch
import torch.nn as nn
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import models
from PIL import Image

# Helpers 

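def load_standard(arch_fn, in_feat, head_type, dropout, path, num_classes=1):
    """L1 / L2 / L3a: head attached directly to the backbone, BCE/sigmoid output.

    in_feat is the backbone feature width, e.g. 768 (ConvNeXt-Tiny), 1024 (ConvNeXt-Base),
    1280 (EfficientNet-V2-L), 2048 (ResNet-50), 512 (MaxVit-T).
    """
    model = arch_fn(weights=None)
    if head_type == "linear":
        head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    else:  # mlp
        head = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(in_feat, in_feat // 2),
            nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(in_feat // 2, num_classes))
    model.classifier[2] = head   # ConvNeXt slot; adjust per arch (see ps_classifier_torch_cascade.py)
    # weights_only=False unpickles arbitrary objects: only load files you trust
    sd = torch.load(path, map_location="cpu", weights_only=False)
    result = model.load_state_dict(sd, strict=False)
    # strict=False skips mismatched keys silently; a wrong slot would leave the head untrained
    if result.missing_keys:
        raise RuntimeError(f"Missing keys, check the head slot: {result.missing_keys[:5]}")
    return model.eval()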


transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

def preprocess(path):
    img = Image.open(path).convert("RGB")
    return transform(image=np.array(img))["image"].unsqueeze(0)

# Full cascade (recommended: use ps_classifier_torch_cascade.py) 

# Clone the repo and place weights under models/torch_cascade/
# then simply:

from ps_classifier_torch_cascade import classify_image_cascade, cascade_confidence

image, message, details = classify_image_cascade("path/to/wound.jpg")
print(message)
# ✅ Pressure sore detected
# Severity : early (0.96)
# Stage    : Stage II (0.91)

print("Joint confidence:", cascade_confidence(details))  # e.g. 0.864

if details.get("level_3", {}).get("gated"):
    print("⚠ Low Level-3 confidence: recommend clinical review")

details schema

{
  "level_1": {"label": str,  "confidence": float},
  "level_2": {"label": str,  "confidence": float},          # only if PS detected
  "level_3": {                                              # only if PS detected
      "label":      str,     # "Stage I" / "Stage II" / "Stage III" / "Stage IV"
      "confidence": float,
      "group":      str,     # "Early" or "Advanced"
      "gated":      bool     # True when confidence < 0.65
  }
}
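The card does not say how `cascade_confidence` combines the levels. One plausible reading, sketched here as a hypothetical `joint_confidence`, multiplies the confidences of whichever levels are present in `details` (treating them as independent):

```python
from math import prod


def joint_confidence(details: dict) -> float:
    """Hypothetical: product of the per-level confidences present in `details`.

    An assumption about what cascade_confidence computes, not a copy of it.
    """
    levels = ("level_1", "level_2", "level_3")
    return prod(details[lv]["confidence"] for lv in levels if lv in details)


example = {
    "level_1": {"label": "PS", "confidence": 0.99},      # Level 1/2 label strings are illustrative
    "level_2": {"label": "Early", "confidence": 0.96},
    "level_3": {"label": "Stage II", "confidence": 0.91,
                "group": "Early", "gated": False},
}
print(round(joint_confidence(example), 3))
```

A product never exceeds its smallest factor, so under this rule a gated Level 3 result (below 0.65) always yields a joint confidence below 0.65 as well.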

Limitations & Disclaimer

  • Trained on ~1,000 images from public educational resources; not a clinical-grade dataset
  • Stage III vs Stage IV performance (accuracy ~0.78–0.79, AUC-ROC ~0.86–0.89) reflects the inherent difficulty of this sub-task
  • Confidence gating reduces but does not eliminate incorrect staging
  • This is a research/demonstration tool β€” not a medical device and not validated for clinical use
  • Always consult a licensed healthcare professional for diagnosis and treatment decisions

Related Resources


License

MIT β€” see LICENSE
