---
license: mit
language:
- en
tags:
- medical-imaging
- image-classification
- pressure-sore
- wound-classification
- pytorch
- torchvision
- cascade
- ensemble
- convnext
- efficientnet
- vit
- maxvit
- resnet
pipeline_tag: image-classification
---

# Pressure Sore Cascade Classifier: Torchvision 3-Level

This repository contains 8 PyTorch model weight files that form a **3-level hierarchical cascade** for automated pressure sore detection and staging. The cascade progressively narrows from detecting any wound, to separating severity groups, to fine-grained stage classification, mirroring clinical decision-making.

These weights are used by [ps_classifier_torch_cascade.py](https://github.com/MrCzaro/PS_Classifier) in the PS Classifier web application.

---

## Cascade Structure

```
Image
  │
  ▼
[Level 1: PS vs No-PS]  BCEWithLogitsLoss · sigmoid
  MaxVit_T  (linear head)
  ResNet50  (mlp head)
  │
  ├─ NO  → "No pressure sore detected"
  └─ YES
       ▼
[Level 2: Early (Stage I/II) vs Advanced (III/IV)]  BCEWithLogitsLoss · sigmoid
  ConvNeXt_Base      (mlp head)
  EfficientNet_V2_L  (linear head)
  │
  ├─ EARLY ─────────────────────────────────────┐
  └─ ADVANCED ──────┐                           │
                    ▼                           ▼
          [Level 3b]                  [Level 3a]
          Stage III vs Stage IV       Stage I vs Stage II
          ConvNeXt_Large (MSH)        EfficientNet_V2_L (mlp)
          ViT_B_16 (mlp)              ConvNeXt_Tiny (linear)
          CrossEntropyLoss            BCEWithLogitsLoss · sigmoid
          WrappedModel pattern        Direct-attachment pattern
          ↓ Confidence gate 0.65 ↓    ↓ Confidence gate 0.65 ↓
```

> **Confidence gating**: if the Level 3 ensemble confidence falls below 0.65, the prediction is still returned, but annotated with an uncertainty warning and flagged for clinical review (`details["level_3"]["gated"] == True`).
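
The gating rule can be sketched in a few lines. This is a minimal illustration, not the code in `ps_classifier_torch_cascade.py`: the helper names `ensemble_confidence` and `gate_level3`, the averaging of per-model probabilities, and the choice of the second label as the positive class are all assumptions. Only the 0.65 threshold and the `gated` flag come from the description above.

```python
GATE_THRESHOLD = 0.65  # Level 3 confidence gate quoted in this card


def ensemble_confidence(probs):
    """Average the per-model positive-class probabilities (assumed ensembling rule)."""
    return sum(probs) / len(probs)


def gate_level3(p_positive, labels=("Stage I", "Stage II"), group="Early",
                threshold=GATE_THRESHOLD):
    """Build a level_3 entry; low-confidence predictions are kept but flagged."""
    positive = p_positive >= 0.5
    # Assumption: labels[1] is the positive class of the sigmoid output.
    label = labels[1] if positive else labels[0]
    # Confidence is the probability of whichever class was chosen.
    confidence = p_positive if positive else 1.0 - p_positive
    return {
        "label": label,
        "confidence": confidence,
        "group": group,
        "gated": confidence < threshold,  # flagged, never suppressed
    }


# Two hypothetical model outputs that agree on the class but only weakly.
entry = gate_level3(ensemble_confidence([0.58, 0.66]))
print(entry["label"], entry["gated"])
```

The prediction is returned in both cases; the gate only sets a flag that the caller can use to show a warning.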

---

## Files

| File | Level | Architecture | Head | Loss |
|------|-------|--------------|------|------|
| `Level 1 Binary PS or not PS MaxVit_T.pth` | L1 | MaxVit-T | linear | BCE |
| `Level 1 Binary PS or not PS ResNet50.pth` | L1 | ResNet-50 | mlp | BCE |
| `Level 2 Early vs Advanced ConvNeXt_Base.pth` | L2 | ConvNeXt-Base | mlp | BCE |
| `Level 2 Early vs Advanced EfficientNet_V2_L.pth` | L2 | EfficientNet-V2-L | linear | BCE |
| `Level 3a Early EfficientNet_V2_L.pth` | L3a | EfficientNet-V2-L | mlp | BCE |
| `Level 3a Early ConvNeXt_Tiny.pth` | L3a | ConvNeXt-Tiny | linear | BCE |
| `Level 3b Advanced ConvNeXt_Large.pth` | L3b | ConvNeXt-Large | multi_stage_head | XEnt |
| `Level 3b Advanced ViT_B_16.pth` | L3b | ViT-B/16 | mlp | XEnt |

**MSH = MultiStageHead** (Dropout → FC(in→in/2) → BN → ReLU → FC(in/2→2)), listed as `multi_stage_head` in the table. **BCE** = `BCEWithLogitsLoss`; **XEnt** = `CrossEntropyLoss`.

---

## Model Performance

### Level 1: PS vs No-PS (test set: 261 images)

| Model | Head | Dropout | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|----------|----------|---------|
| MaxVit_T | linear | 0.2396 | CosineAnnealingLR | 0.9962 | 0.9962 | 1.0000 |
| ResNet50 | mlp | 0.5840 | CosineAnnealingLR | 1.0000 | 1.0000 | 1.0000 |

### Level 2: Early vs Advanced (test set: 125 images)

| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| ConvNeXt_Base | mlp | 0.1025 | AdamP | CosineAnnealingLR | 0.9520 | 0.9520 | 0.9857 |
| EfficientNet_V2_L | linear | 0.3564 | AdamP | CosineAnnealingLR | 0.9600 | 0.9600 | 0.9916 |

### Level 3a: Stage I vs Stage II (test set: 63 images)

| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| EfficientNet_V2_L | mlp | 0.1949 | AdamP | StepLR | 0.9048 | 0.9047 | 0.9849 |
| ConvNeXt_Tiny | linear | 0.1601 | Lion | CosineAnnealingLR | 0.9683 | 0.9682 | 0.9909 |

### Level 3b: Stage III vs Stage IV (test set: 63 images)

| Model | Head | Dropout | Optimizer | Scheduler | Accuracy | Macro F1 | AUC-ROC |
|-------|------|---------|-----------|-----------|----------|----------|---------|
| ConvNeXt_Large | multi_stage_head | 0.6594 | Lion | ReduceLROnPlateau | 0.7778 | 0.7773 | 0.8861 |
| ViT_B_16 | mlp | 0.5445 | AdamW | ReduceLROnPlateau | 0.7937 | 0.7934 | 0.8569 |

> **Stage III vs Stage IV is the hardest sub-task**: subtle visual differences between full-thickness tissue loss with and without exposed bone/muscle make it challenging even for clinicians. The confidence gate at Level 3b (0.65) flags the most uncertain predictions.

---

## Training Details

All models were trained on a curated dataset of ~1,000 pressure sore images collected from public medical databases, with stratified 70/20/10 train/validation/test splits.

**Shared configuration**:
- Input: 224 × 224, ImageNet normalisation (mean `[0.485, 0.456, 0.406]`, std `[0.229, 0.224, 0.225]`)
- Augmentation: random flips, rotation, colour jitter, Gaussian blur, affine transforms (Albumentations)
- Freeze schedule: backbone frozen for initial epochs, then progressively unfrozen (2-stage)
- Early stopping: patience 8–10 epochs on validation loss
- Hyperparameters: selected by Optuna trials (learning rate, weight decay, dropout, head type)
- Mixed precision: fp16 via Accelerate

**Architecture notes**:
- L1, L2, L3a: head is attached directly to the backbone's native classifier slot (`model.classifier[2]` for ConvNeXt, `model.heads.head` for ViT, etc.). Saved state dict has flat keys.
- L3b: `WrappedModel` wrapper. The backbone classifier slot is replaced with `nn.Identity`, and a separate head receives raw feature embeddings. Saved state dict has `backbone.*` / `head.*` key prefixes.

---

## Usage

### Installation

```bash
pip install torch torchvision albumentations pillow
```

### Minimal inference example

```python
import torch
import torch.nn as nn
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import models
from PIL import Image


# Helpers
def load_standard(arch_fn, in_feat, head_type, dropout, path, num_classes=1):
    """L1 / L2 / L3a: head directly on backbone, BCE/sigmoid."""
    model = arch_fn(weights=None)
    if head_type == "linear":
        head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    else:  # mlp
        head = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(in_feat, in_feat // 2),
            nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(in_feat // 2, num_classes))
    # ConvNeXt slot; adjust per arch (see ps_classifier_torch_cascade.py)
    model.classifier[2] = head
    # weights_only=False unpickles arbitrary objects: only load files you trust
    sd = torch.load(path, map_location="cpu", weights_only=False)
    # strict=False skips mismatched keys silently, so report them explicitly
    result = model.load_state_dict(sd, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print("State dict mismatch:", result)
    return model.eval()


transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])


def preprocess(path):
    """Load an image file and return a normalised 1 x 3 x 224 x 224 tensor."""
    img = Image.open(path).convert("RGB")
    return transform(image=np.array(img))["image"].unsqueeze(0)


# Full cascade (recommended: use ps_classifier_torch_cascade.py)
# Clone the repo and place weights under models/torch_cascade/
# then simply:
from ps_classifier_torch_cascade import classify_image_cascade, cascade_confidence

image, message, details = classify_image_cascade("path/to/wound.jpg")
print(message)
# ✅ Pressure sore detected
#    Severity : early (0.96)
#    Stage    : Stage II (0.91)
print("Joint confidence:", cascade_confidence(details))  # e.g. 0.864
if details.get("level_3", {}).get("gated"):
    print("⚠ Low Level-3 confidence: recommend clinical review")
```

### `details` schema

```python
{
    "level_1": {"label": str, "confidence": float},
    "level_2": {"label": str, "confidence": float},  # only if PS detected
    "level_3": {
        "label": str,        # "Stage I" / "Stage II" / "Stage III" / "Stage IV"
        "confidence": float,
        "group": str,        # "Early" or "Advanced"
        "gated": bool        # True when confidence < 0.65
    }
}
```

---

## Limitations & Disclaimer

- Trained on ~1,000 images from public educational resources: **not a clinical-grade dataset**
- Stage III vs Stage IV accuracy (~0.78–0.79; AUC ~0.86–0.89) reflects the inherent difficulty of this sub-task
- Confidence gating reduces but does not eliminate incorrect staging
- **This is a research/demonstration tool, not a medical device, and it is not validated for clinical use**
- Always consult a licensed healthcare professional for diagnosis and treatment decisions

---

## Related Resources

- **GitHub**: [MrCzaro/PS_Classifier](https://github.com/MrCzaro/PS_Classifier)
- **YOLO 2-Stage weights**: [MrCzaro/Pressure_sore_classifier_YOLO](https://huggingface.co/MrCzaro/Pressure_sore_classifier_YOLO)
- **YOLO Cascade weights**: [MrCzaro/Pressure_sore_cascade_classifier_YOLO](https://huggingface.co/MrCzaro/Pressure_sore_cascade_classifier_YOLO)

---

## License

MIT; see [LICENSE](https://opensource.org/license/mit)