Pressure Sore Cascade Classifier: Torchvision 3-Level
This repository contains 8 PyTorch model weight files that form a 3-level hierarchical cascade for automated pressure sore detection and staging. The cascade narrows step by step, mirroring clinical decision-making: first it detects whether a pressure sore is present, then it separates severity groups, and finally it assigns a fine-grained stage.
These weights are used by ps_classifier_torch_cascade.py in the PS Classifier web application.
Cascade Structure
```text
Image
  │
  ▼
[Level 1: PS vs No-PS]  BCEWithLogitsLoss · sigmoid
  MaxVit_T (linear head)
  ResNet50 (mlp head)
  │
  ├─ NO  → "No pressure sore detected"
  └─ YES ▼
[Level 2: Early (Stage I/II) vs Advanced (III/IV)]  BCEWithLogitsLoss · sigmoid
  ConvNeXt_Base (mlp head)
  EfficientNet_V2_L (linear head)
  │
  ├─ EARLY ─────────────────────────────────┐
  └─ ADVANCED ──────┐                       │
                    ▼                       ▼
               [Level 3b]              [Level 3a]
            Stage III vs Stage IV   Stage I vs Stage II
            ConvNeXt_Large (MSH)    EfficientNet_V2_L (mlp)
            ViT_B_16 (mlp)          ConvNeXt_Tiny (linear)
            CrossEntropyLoss        BCEWithLogitsLoss · sigmoid
            WrappedModel pattern    Direct-attachment pattern
            Confidence gate 0.65    Confidence gate 0.65
```
Confidence gating: if the Level 3 ensemble confidence falls below 0.65, the prediction is still returned. It is annotated with an uncertainty warning and flagged for clinical review (`details["level_3"]["gated"] == True`).
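The routing and gating rules in the diagram can be sketched as a small pure-Python function. This is a hypothetical illustration, not the logic of `ps_classifier_torch_cascade.py`.

- Each level is modelled as a callable that returns the ensemble's positive-class probability.
- The 0.5 decision threshold is an assumption, and so is the choice of "Stage II" / "Stage IV" as the positive classes.
- The Level 1 label strings are also assumptions.
- Only the 0.65 gate, the Early/Advanced group names and the `details` keys come from this README.

```python
from typing import Callable, Dict

GATE_THRESHOLD = 0.65  # Level 3 confidence gate from the cascade description

# Hypothetical interface: each callable returns P(positive class) for the current image
ProbFn = Callable[[], float]


def _binary(p_pos: float, pos_label: str, neg_label: str) -> dict:
    """Turn a positive-class probability into a label plus its confidence."""
    if p_pos >= 0.5:
        return {"label": pos_label, "confidence": p_pos}
    return {"label": neg_label, "confidence": 1.0 - p_pos}


def run_cascade(level1: ProbFn, level2: ProbFn, level3a: ProbFn, level3b: ProbFn,
                gate: float = GATE_THRESHOLD) -> Dict[str, dict]:
    details: Dict[str, dict] = {}
    details["level_1"] = _binary(level1(), "Pressure sore", "No pressure sore")
    if details["level_1"]["label"] == "No pressure sore":
        return details  # cascade stops: no Level 2 / Level 3 entries

    details["level_2"] = _binary(level2(), "Advanced", "Early")
    if details["level_2"]["label"] == "Early":
        level_3 = _binary(level3a(), "Stage II", "Stage I")    # Level 3a
    else:
        level_3 = _binary(level3b(), "Stage IV", "Stage III")  # Level 3b

    level_3["group"] = details["level_2"]["label"]
    # The gate never suppresses the prediction; it only flags it for review
    level_3["gated"] = level_3["confidence"] < gate
    details["level_3"] = level_3
    return details
```

The callables are lazy, so only the Level 3 branch selected by Level 2 is evaluated. This mirrors the cascade idea that later models run only when they are needed.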
Files
| File | Level | Architecture | Head | Loss |
|---|---|---|---|---|
| `Level 1 Binary PS or not PS MaxVit_T.pth` | L1 | MaxVit-T | linear | BCE |
| `Level 1 Binary PS or not PS ResNet50.pth` | L1 | ResNet-50 | mlp | BCE |
| `Level 2 Early vs Advanced ConvNeXt_Base.pth` | L2 | ConvNeXt-Base | mlp | BCE |
| `Level 2 Early vs Advanced EfficientNet_V2_L.pth` | L2 | EfficientNet-V2-L | linear | BCE |
| `Level 3a Early EfficientNet_V2_L.pth` | L3a | EfficientNet-V2-L | mlp | BCE |
| `Level 3a Early ConvNeXt_Tiny.pth` | L3a | ConvNeXt-Tiny | linear | BCE |
| `Level 3b Advanced ConvNeXt_Large.pth` | L3b | ConvNeXt-Large | multi_stage_head | XEnt |
| `Level 3b Advanced ViT_B_16.pth` | L3b | ViT-B/16 | mlp | XEnt |
MSH = MultiStageHead (Dropout → FC(in → in/2) → BN → ReLU → FC(in/2 → 2))
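Below is a minimal PyTorch sketch of the MultiStageHead layout described above. It assumes that `in/2` means integer division and that the head has two output classes, matching the CrossEntropy-trained Level 3b models. The class body and the `net` attribute name are assumptions, so its state-dict keys may not match the shipped `Level 3b Advanced ConvNeXt_Large.pth`.

```python
import torch
import torch.nn as nn


class MultiStageHead(nn.Module):
    """Dropout -> FC(in -> in/2) -> BN -> ReLU -> FC(in/2 -> num_classes)."""

    def __init__(self, in_features: int, dropout: float = 0.5, num_classes: int = 2):
        super().__init__()
        hidden = in_features // 2
        # Attribute name `net` is an assumption; the shipped weights may use other keys
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        # embeddings: (N, in_features) -> logits: (N, num_classes) for CrossEntropyLoss
        return self.net(embeddings)
```

For ConvNeXt-Large the head receives 1536-dimensional embeddings, so a matching instance would be `MultiStageHead(1536, dropout=0.6594)`, using the dropout value from the Level 3b table. `BatchNorm1d` rejects a batch of one sample in training mode, so call `.eval()` before single-image inference.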
Model Performance
Level 1: PS vs No-PS (test set: 261 images)
| Model | Head | Dropout | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|---|---|---|---|---|---|---|
| MaxVit_T | linear | 0.2396 | CosineAnnealingLR | 0.9962 | 0.9962 | 1.0000 |
| ResNet50 | mlp | 0.5840 | CosineAnnealingLR | 1.0000 | 1.0000 | 1.0000 |
Level 2: Early vs Advanced (test set: 125 images)
| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|---|---|---|---|---|---|---|---|
| ConvNeXt_Base | mlp | 0.1025 | AdamP | CosineAnnealingLR | 0.9520 | 0.9520 | 0.9857 |
| EfficientNet_V2_L | linear | 0.3564 | AdamP | CosineAnnealingLR | 0.9600 | 0.9600 | 0.9916 |
Level 3a: Stage I vs Stage II (test set: 63 images)
| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|---|---|---|---|---|---|---|---|
| EfficientNet_V2_L | mlp | 0.1949 | AdamP | StepLR | 0.9048 | 0.9047 | 0.9849 |
| ConvNeXt_Tiny | linear | 0.1601 | Lion | CosineAnnealingLR | 0.9683 | 0.9682 | 0.9909 |
Level 3b: Stage III vs Stage IV (test set: 63 images)
| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|---|---|---|---|---|---|---|---|
| ConvNeXt_Large | multi_stage_head | 0.6594 | Lion | ReduceLROnPlateau | 0.7778 | 0.7773 | 0.8861 |
| ViT_B_16 | mlp | 0.5445 | AdamW | ReduceLROnPlateau | 0.7937 | 0.7934 | 0.8569 |
Stage III vs Stage IV is the hardest sub-task. The visual difference between full-thickness tissue loss with and without exposed bone or muscle is subtle, which makes the distinction challenging even for clinicians. The Level 3b confidence gate (0.65) flags the most uncertain of these predictions.
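Each level pairs two models, but this card does not say how their outputs are merged. The sketch below assumes simple soft voting (averaging probabilities) and shows how the loss type changes the combination:

- BCE-trained pairs (L1, L2, L3a) average the sigmoid of a single logit per model.
- The CrossEntropy-trained L3b pair averages per-model softmax vectors.

The function names are hypothetical.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow for large |z|)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_vote_binary(logits: Sequence[float]) -> float:
    """BCE-trained pair (L1, L2, L3a): mean of per-model sigmoid probabilities."""
    return sum(sigmoid(z) for z in logits) / len(logits)


def soft_vote_multiclass(logit_rows: Sequence[Sequence[float]]) -> List[float]:
    """CrossEntropy-trained pair (L3b): mean of per-model softmax vectors."""
    probs = [softmax(row) for row in logit_rows]
    return [sum(col) / len(probs) for col in zip(*probs)]
```

Under this assumption, the ensemble confidence checked against the 0.65 gate would be the larger averaged class probability.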
Training Details
All models were trained on a curated dataset of ~1,000 pressure sore images collected from public medical databases, with stratified 70/20/10 train/validation/test splits.
Shared configuration:
- Input: 224 × 224, ImageNet normalisation (mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225])
- Augmentation: random flips, rotation, colour jitter, Gaussian blur, affine transforms (Albumentations)
- Freeze schedule: backbone frozen for initial epochs, then progressively unfrozen (2-stage)
- Early stopping: patience of 8–10 epochs on validation loss
- Hyperparameters: selected by Optuna trials (learning rate, weight decay, dropout, head type)
- Mixed precision: fp16 via Accelerate
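The two-stage freeze schedule above can be sketched in plain PyTorch. This is a minimal illustration, not the training script:

- The helper names `set_backbone_trainable` and `count_trainable` are hypothetical.
- A toy `nn.Sequential` stands in for a torchvision backbone.
- The number of epochs in each stage is not stated in this card.

```python
import torch.nn as nn


def set_backbone_trainable(model: nn.Module, head: nn.Module, trainable: bool) -> None:
    """Freeze or unfreeze every parameter outside `head`; the head always trains."""
    head_ids = {id(p) for p in head.parameters()}
    for p in model.parameters():
        p.requires_grad = trainable or id(p) in head_ids


def count_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Toy stand-in for backbone + head (hypothetical; a torchvision model works the same way)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
head = model[2]

set_backbone_trainable(model, head, trainable=False)  # stage 1: train the head only
print(count_trainable(model))  # 9
set_backbone_trainable(model, head, trainable=True)   # stage 2: fine-tune everything
print(count_trainable(model))  # 49
```

Once the backbone is unfrozen, its parameters must also be registered with the optimizer. Rebuilding the optimizer at stage 2, possibly with a lower backbone learning rate, is one common choice. Which approach the author used is not stated.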
Architecture notes:
- L1, L2, L3a: the head is attached directly to the backbone's native classifier slot (`model.classifier[2]` for ConvNeXt, `model.heads.head` for ViT, etc.). The saved state dict has flat keys.
- L3b: `WrappedModel` wrapper. The backbone's classifier slot is replaced with `nn.Identity`, and a separate head receives the raw feature embeddings. The saved state dict has `backbone.*` / `head.*` key prefixes.
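The L3b wrapper pattern can be sketched as follows. The attribute names `backbone` and `head` are inferred from the `backbone.*` / `head.*` key prefixes stated above. Everything else, including the toy backbone and its shapes, is a hypothetical stand-in so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


class WrappedModel(nn.Module):
    """Backbone emits raw embeddings (classifier slot = nn.Identity); a separate head classifies."""

    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone  # state-dict keys get the "backbone." prefix
        self.head = head          # state-dict keys get the "head." prefix

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))


# Toy backbone ending in an Identity "classifier slot" (hypothetical stand-in)
backbone = nn.Sequential(nn.Flatten(), nn.Linear(12, 8), nn.Identity())
model = WrappedModel(backbone, nn.Linear(8, 2))
print(sorted(model.state_dict()))
# ['backbone.1.bias', 'backbone.1.weight', 'head.bias', 'head.weight']
```

With torchvision's ViT-B/16, the same pattern takes three steps:

1. Read `vit.heads.head.in_features` (768).
2. Set `vit.heads.head = nn.Identity()`.
3. Wrap the result with `WrappedModel(vit, head)`.

Because the keys are prefixed, a Level 3b state dict loads into the wrapped model but not into the bare backbone.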
Usage
Installation
```bash
pip install torch torchvision albumentations pillow
```
Minimal inference example
```python
import torch
import torch.nn as nn
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import models
from PIL import Image

# Helpers
def load_standard(arch_fn, in_feat, head_type, dropout, path, num_classes=1):
    """L1 / L2 / L3a: head attached directly to the backbone, BCE/sigmoid output."""
    model = arch_fn(weights=None)
    if head_type == "linear":
        head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    else:  # mlp
        head = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(in_feat, in_feat // 2),
            nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(in_feat // 2, num_classes))
    model.classifier[2] = head  # ConvNeXt slot; adjust per arch (see ps_classifier_torch_cascade.py)
    sd = torch.load(path, map_location="cpu", weights_only=False)
    # strict=False tolerates extra keys; report mismatches so an uninitialised head is noticed
    result = model.load_state_dict(sd, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)
    return model.eval()

# e.g. the Level 2 ConvNeXt model (mlp head, dropout from the Level 2 table):
# l2_convnext = load_standard(models.convnext_base, 1024, "mlp", 0.1025,
#                             "Level 2 Early vs Advanced ConvNeXt_Base.pth")

# Resize + ImageNet normalisation, matching the training input pipeline
transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

def preprocess(path):
    """Load an image file and return a (1, 3, 224, 224) float tensor."""
    img = Image.open(path).convert("RGB")
    return transform(image=np.array(img))["image"].unsqueeze(0)

# Full cascade (recommended: use ps_classifier_torch_cascade.py)
# Clone the repo and place the weights under models/torch_cascade/, then:
from ps_classifier_torch_cascade import classify_image_cascade, cascade_confidence

image, message, details = classify_image_cascade("path/to/wound.jpg")
print(message)
# Pressure sore detected
# Severity : early (0.96)
# Stage    : Stage II (0.91)

print("Joint confidence:", cascade_confidence(details))  # e.g. 0.864
if details.get("level_3", {}).get("gated"):
    print("Warning: low Level-3 confidence, recommend clinical review")
```
details schema
```
{
  "level_1": {"label": str, "confidence": float},
  "level_2": {"label": str, "confidence": float},  # only if PS detected
  "level_3": {                                      # only if PS detected
    "label": str,          # "Stage I" / "Stage II" / "Stage III" / "Stage IV"
    "confidence": float,
    "group": str,          # "Early" or "Advanced"
    "gated": bool          # True when confidence < 0.65
  }
}
```
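For static type checking, the schema above might be expressed as `TypedDict`s. This is a sketch with a few assumptions:

- The class and helper names are hypothetical.
- `total=False` models the fact that `level_2` and `level_3` are absent when no pressure sore is detected.
- The Level 1 and Level 2 labels stay plain `str` because their values are not listed here.

```python
from typing import Literal, Optional, TypedDict


class LevelResult(TypedDict):
    label: str
    confidence: float


class Level3Result(LevelResult):
    group: Literal["Early", "Advanced"]
    gated: bool  # True when confidence < 0.65


class CascadeDetails(TypedDict, total=False):
    # total=False: level_2 / level_3 are missing when the cascade stops at Level 1
    level_1: LevelResult
    level_2: LevelResult
    level_3: Level3Result


def final_stage(details: CascadeDetails) -> Optional[str]:
    """Return the predicted stage, or None if the cascade stopped at Level 1."""
    level_3 = details.get("level_3")
    return level_3["label"] if level_3 is not None else None


def needs_review(details: CascadeDetails) -> bool:
    """True when the Level 3 confidence gate fired."""
    level_3 = details.get("level_3")
    return level_3 is not None and level_3["gated"]
```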
Limitations & Disclaimer
- Trained on ~1,000 images from public educational resources β not a clinical-grade dataset
- Stage III vs Stage IV accuracy (~0.78–0.79, AUC ~0.86–0.89) reflects the inherent difficulty of this sub-task
- Confidence gating reduces but does not eliminate incorrect staging
- This is a research/demonstration tool β not a medical device and not validated for clinical use
- Always consult a licensed healthcare professional for diagnosis and treatment decisions
Related Resources
- GitHub: MrCzaro/PS_Classifier
- YOLO 2-Stage weights: MrCzaro/Pressure_sore_classifier_YOLO
- YOLO Cascade weights: MrCzaro/Pressure_sore_cascade_classifier_YOLO
License
MIT (see LICENSE)