---
license: mit
language:
- en
tags:
- medical-imaging
- image-classification
- pressure-sore
- wound-classification
- pytorch
- torchvision
- cascade
- ensemble
- convnext
- efficientnet
- vit
- maxvit
- resnet
pipeline_tag: image-classification
---

# Pressure Sore Cascade Classifier: Torchvision 3-Level

This repository contains 8 PyTorch model weights forming a **3-level hierarchical cascade** for automated pressure sore detection and staging. The cascade progressively narrows from detecting any wound, to separating severity groups, to fine-grained stage classification, mirroring clinical decision-making.

These weights are used by [ps_classifier_torch_cascade.py](https://github.com/MrCzaro/PS_Classifier) in the PS Classifier web application.

---

## Cascade Structure

```
Image
  │
  ▼
[Level 1: PS vs No-PS]   BCEWithLogitsLoss · sigmoid
  MaxVit_T           (linear head)
  ResNet50           (mlp head)
  │
  ├─ NO  → "No pressure sore detected"
  └─ YES ▼
[Level 2: Early (Stage I/II) vs Advanced (III/IV)]   BCEWithLogitsLoss · sigmoid
  ConvNeXt_Base      (mlp head)
  EfficientNet_V2_L  (linear head)
  │
  ├─ EARLY ─────────────────────────┐
  └─ ADVANCED ┐                     │
              ▼                     ▼
    [Level 3b]                [Level 3a]
    Stage III vs Stage IV     Stage I vs Stage II
    ConvNeXt_Large (MSH)      EfficientNet_V2_L (mlp)
    ViT_B_16 (mlp)            ConvNeXt_Tiny (linear)
    CrossEntropyLoss          BCEWithLogitsLoss · sigmoid
    WrappedModel pattern      Direct-attachment pattern
    Confidence gate 0.65      Confidence gate 0.65
```

> **Confidence gating**: if the Level 3 ensemble confidence falls below 0.65, the prediction is still returned, but it is annotated with an uncertainty warning and flagged for clinical review (`details["level_3"]["gated"] == True`).

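The routing and gating rule described above can be sketched in a few lines of plain Python. This is an illustration, not the code in `ps_classifier_torch_cascade.py`: the names `binary_decision` and `route_cascade`, the probability arguments, the label strings, and the 0.5 decision threshold at each binary level are all assumptions. Only the 0.65 gate and the `gated` flag come from this card. Each argument stands for the ensemble's positive-class probability at that level, and `p_stage_hi` is the probability of the higher stage from whichever Level 3 branch was selected.

```python
GATE_THRESHOLD = 0.65  # Level 3 confidence gate stated in this card


def binary_decision(p_positive, positive, negative, threshold=0.5):
    """Turn a positive-class probability into (label, confidence).

    The 0.5 threshold is an assumption; confidence is the probability
    of whichever label was chosen.
    """
    if p_positive >= threshold:
        return positive, p_positive
    return negative, 1.0 - p_positive


def route_cascade(p_ps, p_advanced, p_stage_hi):
    """Hypothetical routing through the three levels (label strings are illustrative)."""
    details = {}
    label, conf = binary_decision(p_ps, "PS", "No-PS")
    details["level_1"] = {"label": label, "confidence": conf}
    if label == "No-PS":
        return details  # cascade stops at Level 1

    group, conf = binary_decision(p_advanced, "Advanced", "Early")
    details["level_2"] = {"label": group, "confidence": conf}

    # Level 3b handles the Advanced group, Level 3a the Early group.
    stages = ("Stage IV", "Stage III") if group == "Advanced" else ("Stage II", "Stage I")
    stage, conf = binary_decision(p_stage_hi, *stages)
    details["level_3"] = {
        "label": stage,
        "confidence": conf,
        "group": group,
        "gated": conf < GATE_THRESHOLD,  # still returned, but flagged for review
    }
    return details


result = route_cascade(p_ps=0.99, p_advanced=0.10, p_stage_hi=0.60)
print(result["level_3"])
```
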
---

## Files

| File | Level | Architecture | Head | Loss |
|------|-------|-------------|------|------|
| `Level 1 Binary PS or not PS MaxVit_T.pth` | L1 | MaxVit-T | linear | BCE |
| `Level 1 Binary PS or not PS ResNet50.pth` | L1 | ResNet-50 | mlp | BCE |
| `Level 2 Early vs Advanced ConvNeXt_Base.pth` | L2 | ConvNeXt-Base | mlp | BCE |
| `Level 2 Early vs Advanced EfficientNet_V2_L.pth` | L2 | EfficientNet-V2-L | linear | BCE |
| `Level 3a Early EfficientNet_V2_L.pth` | L3a | EfficientNet-V2-L | mlp | BCE |
| `Level 3a Early ConvNeXt_Tiny.pth` | L3a | ConvNeXt-Tiny | linear | BCE |
| `Level 3b Advanced ConvNeXt_Large.pth` | L3b | ConvNeXt-Large | multi_stage_head | XEnt |
| `Level 3b Advanced ViT_B_16.pth` | L3b | ViT-B/16 | mlp | XEnt |

**BCE** = `BCEWithLogitsLoss`, **XEnt** = `CrossEntropyLoss`.

**MSH = MultiStageHead** (Dropout → FC(in → in/2) → BN → ReLU → FC(in/2 → 2))

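For readers who prefer to see the head layout as code, the following is a sketch of the MultiStageHead described above, assuming `torch` is installed. The layer order and sizes follow the description. The default dropout value, the use of `nn.BatchNorm1d` for the "BN" step, and the attribute name `layers` are assumptions, so the parameter names here will not necessarily match the keys in the released checkpoint. The feature width of 1536 is the torchvision ConvNeXt-Large embedding size.

```python
import torch
import torch.nn as nn


class MultiStageHead(nn.Module):
    """Dropout -> FC(in -> in/2) -> BN -> ReLU -> FC(in/2 -> 2)."""

    def __init__(self, in_features: int, dropout: float = 0.5, num_classes: int = 2):
        super().__init__()
        hidden = in_features // 2
        self.layers = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),  # assumed: 1-D batch norm over feature vectors
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, in_features) embeddings produced by a backbone
        return self.layers(features)


head = MultiStageHead(in_features=1536).eval()  # eval() disables dropout
with torch.no_grad():
    logits = head(torch.randn(4, 1536))
print(logits.shape)  # torch.Size([4, 2])
```
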
---

## Model Performance

### Level 1: PS vs No-PS (test set: 261 images)

| Model | Head | Dropout | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|----------|----------|---------|
| MaxVit_T | linear | 0.2396 | CosineAnnealingLR | 0.9962 | 0.9962 | 1.0000 |
| ResNet50 | mlp | 0.5840 | CosineAnnealingLR | 1.0000 | 1.0000 | 1.0000 |

### Level 2: Early vs Advanced (test set: 125 images)

| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| ConvNeXt_Base | mlp | 0.1025 | AdamP | CosineAnnealingLR | 0.9520 | 0.9520 | 0.9857 |
| EfficientNet_V2_L | linear | 0.3564 | AdamP | CosineAnnealingLR | 0.9600 | 0.9600 | 0.9916 |

### Level 3a: Stage I vs Stage II (test set: 63 images)

| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| EfficientNet_V2_L | mlp | 0.1949 | AdamP | StepLR | 0.9048 | 0.9047 | 0.9849 |
| ConvNeXt_Tiny | linear | 0.1601 | Lion | CosineAnnealingLR | 0.9683 | 0.9682 | 0.9909 |

### Level 3b: Stage III vs Stage IV (test set: 63 images)

| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| ConvNeXt_Large | multi_stage_head | 0.6594 | Lion | ReduceLROnPlateau | 0.7778 | 0.7773 | 0.8861 |
| ViT_B_16 | mlp | 0.5445 | AdamW | ReduceLROnPlateau | 0.7937 | 0.7934 | 0.8569 |

> **Stage III vs Stage IV is the hardest sub-task**: subtle visual differences between full-thickness tissue loss with and without exposed bone/muscle make it challenging even for clinicians. The confidence gate at Level 3b (0.65) flags the most uncertain predictions.

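Because the test sets are small, each reported accuracy corresponds to a small number of images; on a 63-image test set, a single image moves accuracy by about 1.6 percentage points. The snippet below is a worked check that recovers the implied number of misclassified images from the tables above, assuming accuracy is simply correct predictions divided by test-set size. The dictionary layout and variable names are illustrative.

```python
# (test-set size, reported accuracy) copied from the tables above
reported = {
    "L1 MaxVit_T": (261, 0.9962),
    "L1 ResNet50": (261, 1.0000),
    "L2 ConvNeXt_Base": (125, 0.9520),
    "L2 EfficientNet_V2_L": (125, 0.9600),
    "L3a EfficientNet_V2_L": (63, 0.9048),
    "L3a ConvNeXt_Tiny": (63, 0.9683),
    "L3b ConvNeXt_Large": (63, 0.7778),
    "L3b ViT_B_16": (63, 0.7937),
}

errors = {}
for name, (n_images, accuracy) in reported.items():
    correct = round(accuracy * n_images)  # nearest whole number of images
    # The rounded count should reproduce the 4-decimal accuracy in the table.
    assert round(correct / n_images, 4) == accuracy
    errors[name] = n_images - correct
    print(f"{name}: {correct}/{n_images} correct, {errors[name]} wrong")
```
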
---

## Training Details

All models were trained on a curated dataset of ~1,000 pressure sore images collected from public medical databases, with stratified 70/20/10 train/validation/test splits.

**Shared configuration**:
- Input: 224 × 224, ImageNet normalisation (mean `[0.485, 0.456, 0.406]`, std `[0.229, 0.224, 0.225]`)
- Augmentation: random flips, rotation, colour jitter, Gaussian blur, affine transforms (Albumentations)
- Freeze schedule: backbone frozen for the initial epochs, then progressively unfrozen (2-stage)
- Early stopping: patience 8–10 epochs on validation loss
- Hyperparameters: selected by Optuna trials (learning rate, weight decay, dropout, head type)
- Mixed precision: fp16 via Accelerate

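The early-stopping rule in the list above can be sketched as follows. This is a generic patience-based stopper, not the training code used for these weights: the class name `EarlyStopping`, the `min_delta` parameter, and the toy loss sequence are assumptions made for illustration.

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 8, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # assumed: smallest change counted as an improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy run: the loss improves for three epochs, then plateaus.
stopper = EarlyStopping(patience=8)
val_losses = [0.90, 0.70, 0.60] + [0.61] * 20
stopped_at = None
for epoch, loss in enumerate(val_losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print("stopped at epoch", stopped_at)  # 3 improving epochs + 8 without improvement = 11
```
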
**Architecture notes**:
- L1, L2, L3a: the head is attached directly to the backbone's native classifier slot (`model.classifier[2]` for ConvNeXt, `model.heads.head` for ViT, etc.). The saved state dict has flat keys.
- L3b: `WrappedModel` wrapper. The backbone classifier slot is replaced with `nn.Identity`, and a separate head receives the raw feature embeddings. The saved state dict has `backbone.*` / `head.*` key prefixes.

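One practical consequence is that the two checkpoint layouts can be told apart from their key names before any model is built. The helper below is a sketch under that assumption: `checkpoint_layout` is a hypothetical name, and the example keys are made up to show the two shapes rather than copied from the released files.

```python
def checkpoint_layout(state_dict) -> str:
    """Classify a state dict as 'wrapped' (L3b style) or 'direct' (L1/L2/L3a style).

    'wrapped' means every key starts with 'backbone.' or 'head.';
    anything else is treated as a direct-attachment checkpoint.
    """
    keys = list(state_dict)
    if keys and all(k.startswith(("backbone.", "head.")) for k in keys):
        return "wrapped"
    return "direct"


# Illustrative key names only; real checkpoints contain many more entries.
direct_like = {"features.0.0.weight": None, "classifier.2.1.weight": None}
wrapped_like = {"backbone.features.0.0.weight": None, "head.layers.1.weight": None}

print(checkpoint_layout(direct_like))   # direct
print(checkpoint_layout(wrapped_like))  # wrapped
```
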
---

## Usage

### Installation

```bash
pip install torch torchvision albumentations pillow
```

### Minimal inference example

```python
import torch
import torch.nn as nn
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import models
from PIL import Image

# Helpers

def load_standard(arch_fn, in_feat, head_type, dropout, path, num_classes=1):
    """L1 / L2 / L3a: head attached directly to the backbone, BCE/sigmoid."""
    model = arch_fn(weights=None)
    if head_type == "linear":
        head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    else:  # mlp
        head = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(in_feat, in_feat // 2),
            nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(in_feat // 2, num_classes))
    # ConvNeXt slot shown; adjust the slot per arch (see ps_classifier_torch_cascade.py)
    model.classifier[2] = head
    sd = torch.load(path, map_location="cpu", weights_only=False)
    # strict=False skips missing/unexpected keys, so check that the head matches the checkpoint
    model.load_state_dict(sd, strict=False)
    return model.eval()


transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

def preprocess(path):
    """Load an image and return a (1, 3, 224, 224) normalised tensor."""
    img = Image.open(path).convert("RGB")
    return transform(image=np.array(img))["image"].unsqueeze(0)

# Full cascade (recommended: use ps_classifier_torch_cascade.py)

# Clone the repo and place the weights under models/torch_cascade/,
# then simply:

from ps_classifier_torch_cascade import classify_image_cascade, cascade_confidence

image, message, details = classify_image_cascade("path/to/wound.jpg")
print(message)
# ✅ Pressure sore detected
# Severity : early (0.96)
# Stage : Stage II (0.91)

print("Joint confidence:", cascade_confidence(details))  # e.g. 0.864

if details.get("level_3", {}).get("gated"):
    print("⚠ Low Level-3 confidence - recommend clinical review")
```

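`load_standard` returns a model that emits a single raw logit, so the caller still has to turn that value into a label and a confidence. The sketch below shows one way to do this for the BCE/sigmoid levels and, for comparison, for the two-logit cross-entropy heads at Level 3b. It works on plain Python floats (for example the result of `model(x).item()`), so it runs without a checkpoint. The function names, the 0.5 threshold, and the mapping from class index to label are assumptions and should be checked against `ps_classifier_torch_cascade.py`.

```python
import math


def decode_sigmoid(logit, positive, negative, threshold=0.5):
    """Single-logit head (BCEWithLogitsLoss): sigmoid, then threshold."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return (positive, p) if p >= threshold else (negative, 1.0 - p)


def decode_softmax(logits, labels):
    """Multi-logit head (CrossEntropyLoss): softmax, then argmax."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best], probs[best]


# Assumed class order; the real mapping depends on how the datasets were labelled.
print(decode_sigmoid(2.0, "Stage II", "Stage I"))
print(decode_softmax([0.3, 1.1], ["Stage III", "Stage IV"]))
```
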
### `details` schema

```python
{
    "level_1": {"label": str, "confidence": float},
    "level_2": {"label": str, "confidence": float},  # only if PS detected
    "level_3": {
        "label": str,         # "Stage I" / "Stage II" / "Stage III" / "Stage IV"
        "confidence": float,
        "group": str,         # "Early" or "Advanced"
        "gated": bool         # True when confidence < 0.65
    }
}
```

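A dictionary with this schema might be consumed as sketched below. The sketch assumes the joint confidence is the product of the per-level confidences, which is consistent with the `0.864` example in the inference snippet but is not confirmed by this card. `joint_confidence` is a hypothetical stand-in rather than the repository's `cascade_confidence`, and the Level 1 label and confidence in the sample are invented.

```python
def joint_confidence(details: dict) -> float:
    """Multiply the confidences of every level present in `details`."""
    joint = 1.0
    for level in ("level_1", "level_2", "level_3"):
        if level in details:
            joint *= details[level]["confidence"]
    return joint


sample = {
    "level_1": {"label": "PS", "confidence": 0.99},  # invented values
    "level_2": {"label": "early", "confidence": 0.96},
    "level_3": {"label": "Stage II", "confidence": 0.91,
                "group": "Early", "gated": False},
}

print(round(joint_confidence(sample), 3))  # 0.99 * 0.96 * 0.91 -> 0.865
if sample.get("level_3", {}).get("gated"):
    print("Low Level-3 confidence: recommend clinical review")
```
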
---

## Limitations & Disclaimer

- Trained on ~1,000 images from public educational resources; this is **not a clinical-grade dataset**
- Stage III vs Stage IV accuracy (~0.78–0.79, AUC ~0.86–0.89) reflects the inherent difficulty of this sub-task
- Confidence gating reduces but does not eliminate incorrect staging
- **This is a research/demonstration tool, not a medical device, and it is not validated for clinical use**
- Always consult a licensed healthcare professional for diagnosis and treatment decisions

---

## Related Resources

- **GitHub**: [MrCzaro/PS_Classifier](https://github.com/MrCzaro/PS_Classifier)
- **YOLO 2-Stage weights**: [MrCzaro/Pressure_sore_classifier_YOLO](https://huggingface.co/MrCzaro/Pressure_sore_classifier_YOLO)
- **YOLO Cascade weights**: [MrCzaro/Pressure_sore_cascade_classifier_YOLO](https://huggingface.co/MrCzaro/Pressure_sore_cascade_classifier_YOLO)

---

## License

MIT. See [LICENSE](https://opensource.org/license/mit).