---
license: mit
language:
- en
tags:
- medical-imaging
- image-classification
- pressure-sore
- wound-classification
- pytorch
- torchvision
- cascade
- ensemble
- convnext
- efficientnet
- vit
- maxvit
- resnet
pipeline_tag: image-classification
---
# Pressure Sore Cascade Classifier: Torchvision 3-Level
This repository contains eight PyTorch weight files that form a **3-level hierarchical cascade** for automated pressure sore detection and staging. The cascade progressively narrows from detecting whether a pressure sore is present, to separating severity groups, to fine-grained stage classification, mirroring clinical decision-making.

These weights are used by [ps_classifier_torch_cascade.py](https://github.com/MrCzaro/PS_Classifier) in the PS Classifier web application.

---
## Cascade Structure
```
Image
  │
  ▼
[Level 1: PS vs No-PS]  BCEWithLogitsLoss · sigmoid
  MaxVit_T  (linear head)
  ResNet50  (mlp head)
  │
  ├─ NO  → "No pressure sore detected"
  └─ YES ▼
[Level 2: Early (Stage I/II) vs Advanced (III/IV)]  BCEWithLogitsLoss · sigmoid
  ConvNeXt_Base      (mlp head)
  EfficientNet_V2_L  (linear head)
  │
  ├─ EARLY ───────────────────────────────────────┐
  └─ ADVANCED ──────┐                             │
                    ▼                             ▼
                [Level 3b]                    [Level 3a]
                Stage III vs Stage IV         Stage I vs Stage II
                ConvNeXt_Large (MSH)          EfficientNet_V2_L (mlp)
                ViT_B_16 (mlp)                ConvNeXt_Tiny (linear)
                CrossEntropyLoss              BCEWithLogitsLoss · sigmoid
                WrappedModel pattern          Direct-attachment pattern
                ↓ Confidence gate 0.65 ↓      ↓ Confidence gate 0.65 ↓
```
> **Confidence gating**: if the Level 3 ensemble confidence falls below 0.65, the prediction is still returned but is annotated with an uncertainty warning and flagged for clinical review (`details["level_3"]["gated"] == True`).
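
The routing and gating in the diagram can be sketched in plain Python. This is a hypothetical illustration, not the code in `ps_classifier_torch_cascade.py`: it assumes each ensemble averages its members' probabilities, binary decisions use a 0.5 threshold, each sigmoid output is the probability of the positive class (PS, Advanced, Stage II), index 1 of the Level 3b softmax is Stage IV, and the label strings are placeholders. `run_level` is a made-up callback returning per-model logits, so only the levels the cascade actually reaches are evaluated.

```python
import math
from typing import Callable, Dict, List, Sequence

GATE = 0.65  # Level 3 confidence gate quoted in this model card


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def binary_vote(logits: Sequence[float]) -> float:
    """Assumed ensembling: mean sigmoid probability of the positive class."""
    return sum(sigmoid(z) for z in logits) / len(logits)


def decide(p_pos: float, pos: str, neg: str) -> Dict:
    """Threshold at 0.5 and report the confidence of the chosen label."""
    if p_pos >= 0.5:
        return {"label": pos, "confidence": p_pos}
    return {"label": neg, "confidence": 1.0 - p_pos}


def route(run_level: Callable[[str], list]) -> Dict:
    """run_level(name) returns one logit entry per ensemble member at that level."""
    details = {"level_1": decide(binary_vote(run_level("L1")), "PS", "No PS")}
    if details["level_1"]["label"] == "No PS":
        return details  # cascade stops at Level 1
    details["level_2"] = decide(binary_vote(run_level("L2")), "Advanced", "Early")
    if details["level_2"]["label"] == "Early":
        l3 = decide(binary_vote(run_level("L3a")), "Stage II", "Stage I")
    else:
        # L3b members are trained with CrossEntropyLoss: two logits each, averaged softmax.
        probs = [softmax(z) for z in run_level("L3b")]
        l3 = decide(sum(p[1] for p in probs) / len(probs), "Stage IV", "Stage III")
    l3["group"] = details["level_2"]["label"]
    l3["gated"] = l3["confidence"] < GATE  # still returned, but flagged for review
    details["level_3"] = l3
    return details


logits = {"L1": [4.0, 3.0], "L2": [-2.0, -1.0], "L3a": [0.2, 0.1]}
print(route(lambda name: logits[name])["level_3"]["gated"])  # True: Stage II at ~0.537
```
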
---
## Files
| File | Level | Architecture | Head | Loss |
|------|-------|-------------|------|------|
| `Level 1 Binary PS or not PS MaxVit_T.pth` | L1 | MaxVit-T | linear | BCE |
| `Level 1 Binary PS or not PS ResNet50.pth` | L1 | ResNet-50 | mlp | BCE |
| `Level 2 Early vs Advanced ConvNeXt_Base.pth` | L2 | ConvNeXt-Base | mlp | BCE |
| `Level 2 Early vs Advanced EfficientNet_V2_L.pth` | L2 | EfficientNet-V2-L | linear | BCE |
| `Level 3a Early EfficientNet_V2_L.pth` | L3a | EfficientNet-V2-L | mlp | BCE |
| `Level 3a Early ConvNeXt_Tiny.pth` | L3a | ConvNeXt-Tiny | linear | BCE |
| `Level 3b Advanced ConvNeXt_Large.pth` | L3b | ConvNeXt-Large | multi_stage_head | XEnt |
| `Level 3b Advanced ViT_B_16.pth` | L3b | ViT-B/16 | mlp | XEnt |

**MSH = MultiStageHead** (`multi_stage_head` in the table): Dropout → FC(in→in/2) → BN → ReLU → FC(in/2→2)

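
A minimal PyTorch sketch of the MSH layer order described above. The class and attribute names are assumptions, so the released ConvNeXt_Large checkpoint will not necessarily load into it without renaming state-dict keys. 1536 is the pooled feature width of torchvision's ConvNeXt-Large, and 0.6594 is the tuned dropout from the Level 3b performance table.

```python
import torch
import torch.nn as nn


class MultiStageHead(nn.Module):
    """Dropout -> FC(in -> in/2) -> BN -> ReLU -> FC(in/2 -> 2), as listed above."""

    def __init__(self, in_features: int, dropout: float, num_classes: int = 2):
        super().__init__()
        hidden = in_features // 2
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),  # two logits for CrossEntropyLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# 1536 = pooled feature width of torchvision ConvNeXt-Large; 0.6594 = tuned dropout.
head = MultiStageHead(in_features=1536, dropout=0.6594).eval()
logits = head(torch.randn(4, 1536))
print(logits.shape)  # torch.Size([4, 2])
```
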
---
## Model Performance
### Level 1: PS vs No-PS (test set: 261 images)
| Model | Head | Dropout | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|----------|----------|---------|
| MaxVit_T | linear | 0.2396 | CosineAnnealingLR | 0.9962 | 0.9962 | 1.0000 |
| ResNet50 | mlp | 0.5840 | CosineAnnealingLR | 1.0000 | 1.0000 | 1.0000 |
### Level 2: Early vs Advanced (test set: 125 images)
| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| ConvNeXt_Base | mlp | 0.1025 | AdamP | CosineAnnealingLR | 0.9520 | 0.9520 | 0.9857 |
| EfficientNet_V2_L | linear | 0.3564 | AdamP | CosineAnnealingLR | 0.9600 | 0.9600 | 0.9916 |
### Level 3a: Stage I vs Stage II (test set: 63 images)
| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| EfficientNet_V2_L | mlp | 0.1949 | AdamP | StepLR | 0.9048 | 0.9047 | 0.9849 |
| ConvNeXt_Tiny | linear | 0.1601 | Lion | CosineAnnealingLR | 0.9683 | 0.9682 | 0.9909 |
### Level 3b: Stage III vs Stage IV (test set: 63 images)
| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| ConvNeXt_Large | multi_stage_head | 0.6594 | Lion | ReduceLROnPlateau | 0.7778 | 0.7773 | 0.8861 |
| ViT_B_16 | mlp | 0.5445 | AdamW | ReduceLROnPlateau | 0.7937 | 0.7934 | 0.8569 |
> **Stage III vs Stage IV is the hardest sub-task.** The visual differences between full-thickness tissue loss with and without exposed bone or muscle are subtle, which makes this distinction challenging even for clinicians. The confidence gate at Level 3b (0.65) flags the most uncertain predictions.
---
## Training Details
All models were trained on a curated dataset of ~1,000 pressure sore images collected from public medical databases, with stratified 70/20/10 train/validation/test splits.

**Shared configuration**:
- Input: 224 × 224, ImageNet normalisation (mean `[0.485, 0.456, 0.406]`, std `[0.229, 0.224, 0.225]`)
- Augmentation: random flips, rotation, colour jitter, Gaussian blur, affine transforms (Albumentations)
- Freeze schedule: backbone frozen for initial epochs, then progressively unfrozen (2-stage)
- Early stopping: patience 8–10 epochs on validation loss
- Hyperparameters: selected by Optuna trials (learning rate, weight decay, dropout, head type)
- Mixed precision: fp16 via Accelerate
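
The 2-stage freeze schedule can be sketched by toggling `requires_grad`. This assumes stage 1 trains only the head and stage 2 unfreezes the whole backbone; the toy model, the learning rates, and the use of AdamW (one of the optimizers in the tables above) are placeholders, since the actual stage lengths and unfreezing order are not given here.

```python
import torch
import torch.nn as nn


def set_backbone_trainable(model: nn.Module, head: nn.Module, trainable: bool) -> None:
    """Freeze or unfreeze every parameter outside `head`; head parameters always train."""
    head_ids = {id(p) for p in head.parameters()}
    for p in model.parameters():
        p.requires_grad = trainable or id(p) in head_ids


def make_optimizer(model: nn.Module, lr: float) -> torch.optim.Optimizer:
    # Pass only trainable parameters so frozen weights are neither updated nor decayed.
    return torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)


def n_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Toy stand-in for a backbone whose last layer is the attached head (hypothetical).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))
head = model[3]

# Stage 1: backbone frozen, only the head trains.
set_backbone_trainable(model, head, trainable=False)
opt_stage1 = make_optimizer(model, lr=1e-3)
print(n_trainable(model))  # 9

# Stage 2: unfreeze the backbone and rebuild the optimizer with a smaller LR (illustrative).
set_backbone_trainable(model, head, trainable=True)
opt_stage2 = make_optimizer(model, lr=1e-4)
print(n_trainable(model))  # 233
```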
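
Early stopping on validation loss needs only a best-so-far value and a counter. This sketch uses `patience=8`, the low end of the stated range; the `min_delta` option and the example loss curve are hypothetical.

```python
class EarlyStopping:
    """Signal a stop once validation loss fails to improve for `patience` epochs in a row."""

    def __init__(self, patience: int = 8, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: a checkpoint would typically be saved here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=8)
val_losses = [0.70, 0.55, 0.48, 0.47] + [0.50] * 10  # made-up curve that plateaus
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        print(f"stopped after epoch {epoch}, best val loss {stopper.best}")
        break
# stopped after epoch 11, best val loss 0.47
```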
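
The Optuna search space might look roughly like this. `suggest_float` and `suggest_categorical` are real `optuna.trial.Trial` methods, but the ranges are assumptions (the dropout range merely brackets the tuned values reported above). To run a search, pass an objective such as `lambda t: train_and_validate(suggest_hparams(t))`, where `train_and_validate` is a hypothetical function returning validation loss, to `optuna.create_study(direction="minimize").optimize(..., n_trials=50)`.

```python
def suggest_hparams(trial) -> dict:
    """Sample one configuration from an optuna.trial.Trial (ranges are illustrative)."""
    return {
        "lr": trial.suggest_float("lr", 1e-5, 1e-3, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True),
        "dropout": trial.suggest_float("dropout", 0.1, 0.7),
        # linear/mlp are used at L1-L3a; mlp/multi_stage_head at L3b
        "head": trial.suggest_categorical("head", ["linear", "mlp", "multi_stage_head"]),
    }
```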

**Architecture notes**:
- L1, L2, L3a: head is attached directly to the backbone's native classifier slot (`model.classifier[2]` for ConvNeXt, `model.heads.head` for ViT, etc.). Saved state dict has flat keys.
- L3b: `WrappedModel` pattern. The backbone's classifier slot is replaced with `nn.Identity`, and a separate head receives the raw feature embeddings. Saved state dict has `backbone.*` / `head.*` key prefixes.
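
The `WrappedModel` pattern can be sketched as a two-attribute module; registering the parts as `backbone` and `head` is what produces the `backbone.*` / `head.*` key prefixes. A toy backbone stands in for torchvision's ViT/ConvNeXt here, and the exact class in `ps_classifier_torch_cascade.py` may differ.

```python
import torch
import torch.nn as nn


class WrappedModel(nn.Module):
    """Backbone (classifier slot set to nn.Identity) followed by a separate head."""

    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone  # parameters saved under "backbone.*"
        self.head = head          # parameters saved under "head.*"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)  # raw embeddings, e.g. (N, 768) for ViT-B/16
        return self.head(features)   # (N, 2) logits for CrossEntropyLoss


# Toy backbone standing in for a torchvision model (hypothetical).
backbone = nn.Sequential(nn.Conv2d(3, 16, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 1000))
backbone[3] = nn.Identity()  # for torchvision ViT_B_16: backbone.heads.head = nn.Identity()
model = WrappedModel(backbone, nn.Linear(16, 2)).eval()

keys = list(model.state_dict())
print(keys)  # ['backbone.0.weight', 'backbone.0.bias', 'head.weight', 'head.bias']
print(model(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 2])
```
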
---
## Usage
### Installation
```bash
pip install torch torchvision albumentations pillow
```
### Minimal inference example
```python
import torch
import torch.nn as nn
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import models
from PIL import Image
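# Helpers
# Final classifier layer of each torchvision family used at L1 / L2 / L3a.
# These are torchvision's default positions (assumed to be the slots this repo
# uses); confirm them against ps_classifier_torch_cascade.py.
HEAD_SLOTS = {
    "convnext": ("classifier", 2),         # ConvNeXt_Base, ConvNeXt_Tiny
    "efficientnet_v2": ("classifier", 1),  # EfficientNet_V2_L
    "maxvit": ("classifier", 5),           # MaxVit_T
    "resnet": ("fc", None),                # ResNet50
}


def load_standard(arch_fn, family, head_type, dropout, path, num_classes=1):
    """L1 / L2 / L3a: head attached directly to the backbone, BCE/sigmoid."""
    model = arch_fn(weights=None)
    attr, idx = HEAD_SLOTS[family]
    slot = getattr(model, attr)
    in_feat = (slot if idx is None else slot[idx]).in_features  # width of the replaced layer
    if head_type == "linear":
        head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    else:  # mlp
        head = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(in_feat, in_feat // 2),
            nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(in_feat // 2, num_classes))
    if idx is None:
        setattr(model, attr, head)
    else:
        slot[idx] = head
    # weights_only=False unpickles arbitrary objects: only load checkpoints you trust
    sd = torch.load(path, map_location="cpu", weights_only=False)
    # strict=False does not raise on mismatches, so surface them explicitly
    result = model.load_state_dict(sd, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)
    return model.eval()


# e.g. load_standard(models.convnext_tiny, "convnext", "linear", 0.1601,
#                    "Level 3a Early ConvNeXt_Tiny.pth")

transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


def preprocess(path):
    img = Image.open(path).convert("RGB")
    return transform(image=np.array(img))["image"].unsqueeze(0)  # (1, 3, 224, 224)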

# Full cascade (recommended: use ps_classifier_torch_cascade.py)
# Clone the repo and place the weights under models/torch_cascade/, then:
from ps_classifier_torch_cascade import classify_image_cascade, cascade_confidence

image, message, details = classify_image_cascade("path/to/wound.jpg")
print(message)
# ✅ Pressure sore detected
# Severity : early (0.96)
# Stage : Stage II (0.91)
print("Joint confidence:", cascade_confidence(details))  # e.g. 0.864
if details.get("level_3", {}).get("gated"):
    print("⚠ Low Level-3 confidence: recommend clinical review")
```
### `details` schema
```python
{
"level_1": {"label": str, "confidence": float},
"level_2": {"label": str, "confidence": float}, # only if PS detected
"level_3": {
"label": str, # "Stage I" / "Stage II" / "Stage III" / "Stage IV"
"confidence": float,
"group": str, # "Early" or "Advanced"
"gated": bool # True when confidence < 0.65
}
}
```
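
Downstream code might combine the per-level confidences and the gate flag like this. `joint_confidence` and `needs_review` are hypothetical helpers: the first multiplies the confidences of whichever levels are present, which may not match the formula behind the repository's `cascade_confidence`, and the `"PS"` label string is a placeholder.

```python
from math import prod
from typing import Any, Dict


def joint_confidence(details: Dict[str, Dict[str, Any]]) -> float:
    """Hypothetical: product of confidences for every level present in `details`."""
    levels = ("level_1", "level_2", "level_3")
    return prod(details[k]["confidence"] for k in levels if k in details)


def needs_review(details: Dict[str, Dict[str, Any]]) -> bool:
    """True when Level 3 ran and was flagged by the 0.65 confidence gate."""
    return bool(details.get("level_3", {}).get("gated", False))


example = {
    "level_1": {"label": "PS", "confidence": 0.99},
    "level_2": {"label": "Early", "confidence": 0.96},
    "level_3": {"label": "Stage II", "confidence": 0.91, "group": "Early", "gated": False},
}
print(round(joint_confidence(example), 3))  # 0.865
print(needs_review(example))  # False
```
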
---
## Limitations & Disclaimer
- Trained on ~1,000 images from public educational resources β€” **not a clinical-grade dataset**
- Stage III vs Stage IV accuracy (~0.78–0.79) and AUC (~0.86–0.89) reflect the inherent difficulty of this sub-task
- Confidence gating reduces but does not eliminate incorrect staging
- **This is a research/demonstration tool β€” not a medical device and not validated for clinical use**
- Always consult a licensed healthcare professional for diagnosis and treatment decisions
---
## Related Resources
- **GitHub**: [MrCzaro/PS_Classifier](https://github.com/MrCzaro/PS_Classifier)
- **YOLO 2-Stage weights**: [MrCzaro/Pressure_sore_classifier_YOLO](https://huggingface.co/MrCzaro/Pressure_sore_classifier_YOLO)
- **YOLO Cascade weights**: [MrCzaro/Pressure_sore_cascade_classifier_YOLO](https://huggingface.co/MrCzaro/Pressure_sore_cascade_classifier_YOLO)
---
## License
MIT. See [LICENSE](https://opensource.org/license/mit).