# skanner / classify.py
# Committed by blmsrz — commit cf630b4 "Deploy clean Skanner v1" (verified)
"""
Skanner - Melanoma Classification Module (v2)
==============================================
Uses a binary melanoma/benign classifier (Hemgg/Melanoma-Cancer-Image-classification)
with a decision threshold of 0.15, tuned for screening-oriented sensitivity.

Benchmarked on ISIC 2024 (balanced sample, n=50: 25 melanoma, 25 benign):
    - Sensitivity: 84% (catches 21 of 25 melanomas)
    - Specificity: 56% (clears 14 of 25 benign lesions)
    - Threshold:   0.15 (tuned for the screening use case)

Why threshold 0.15 instead of 0.50:
    In melanoma screening, missing a cancer is worse than a false alarm. A false
    alarm sends someone to a dermatologist who clears them; a missed cancer can
    become an advanced tumor. We therefore tune toward sensitivity.

Usage (CLI):
    python classify.py path/to/lesion.jpg

Usage (Python):
    from classify import SkannerClassifier
    clf = SkannerClassifier()
    result = clf.classify("lesion.jpg")
    print(result["risk_level"], result["melanoma_probability"])
"""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Union
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face checkpoint used when no model name is supplied. The model and
# the high-risk cut-off below were selected by the compare_models.py benchmark.
DEFAULT_MODEL = "Hemgg/Melanoma-Cancer-Image-classification"

# Risk tiers applied to the melanoma probability p:
#   p >= MELANOMA_THRESHOLD                       -> "High"
#   MODERATE_THRESHOLD <= p < MELANOMA_THRESHOLD  -> "Moderate"
#   p < MODERATE_THRESHOLD                        -> "Low"
MELANOMA_THRESHOLD = 0.15  # screening-optimized, far below the usual 0.50
MODERATE_THRESHOLD = 0.08  # lower edge of the "Moderate" tier
def _detect_device() -> str:
    """Return the torch device name to run inference on.

    Preference order: "cuda" (NVIDIA GPU), then "mps" (Apple-silicon / M-series
    Mac acceleration), falling back to "cpu".
    """
    # Probes are wrapped in lambdas so torch.backends.mps is only touched when
    # CUDA is unavailable, exactly as a sequential if-chain would do.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, is_present in probes:
        if is_present():
            return device_name
    return "cpu"
class SkannerClassifier:
    """Binary melanoma/benign image classifier tuned for screening.

    Wraps a Hugging Face image-classification checkpoint, reports the
    probability of its melanoma/malignant class, and maps that probability
    onto a three-tier risk level via ``_triage``.
    """

    # Label substrings (matched after lower-casing and dropping every
    # non-alphanumeric character) that mark the melanoma-positive class.
    _MELANOMA_KEYWORDS = ("malignant", "melanoma")
    # Prefixes that negate a keyword, e.g. "non-melanoma" or "Not_Malignant".
    # Such labels name the NEGATIVE class and must never be picked.
    _NEGATION_PREFIXES = ("non", "not")

    def __init__(self, model_name: str = DEFAULT_MODEL, device: str | None = None):
        """Load the image processor and model, and move the model to a device.

        Args:
            model_name: Model id or path passed to ``from_pretrained``.
            device: torch device string. When falsy, it is chosen by
                ``_detect_device``.

        Raises:
            ValueError: if none of the model's class labels identifies the
                melanoma/malignant class. Continuing would report a melanoma
                probability of 0.0 ("Low" risk) for every image.
        """
        self.device = device or _detect_device()
        print(f"[Skanner] Loading model '{model_name}' on {self.device}...")
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        self.id2label = self.model.config.id2label
        # Resolve the positive class once, so a mismatched checkpoint fails
        # loudly here instead of silently triaging everything as "Low".
        self._melanoma_label = self._find_melanoma_label(list(self.id2label.values()))
        print(f"[Skanner] Ready. Classes: {list(self.id2label.values())}")

    @classmethod
    def _find_melanoma_label(cls, labels: list[str]) -> str:
        """Return the first label that names the melanoma/malignant class.

        Matching is case-insensitive and ignores punctuation and whitespace.
        Negated labels ("non-melanoma", "not malignant", "nonmalignant") are
        skipped, so the benign class of a melanoma/non-melanoma model is never
        mistaken for the positive class.

        Raises:
            ValueError: if no label qualifies.
        """
        for label in labels:
            compact = "".join(ch for ch in str(label).lower() if ch.isalnum())
            for keyword in cls._MELANOMA_KEYWORDS:
                negated = any(
                    prefix + keyword in compact for prefix in cls._NEGATION_PREFIXES
                )
                if keyword in compact and not negated:
                    return label
        raise ValueError(
            "Could not identify a melanoma/malignant class among the model's "
            f"labels {labels}; refusing to report a melanoma probability of 0."
        )

    def classify(self, image: Union[str, Path, Image.Image]) -> dict:
        """
        Run classification on a single image.

        Args:
            image: A file path (str or Path) or an already-loaded PIL image.
                Either way, it is converted to RGB before preprocessing.

        Returns:
            {
                "melanoma_probability": float,  # 0.0 - 1.0 (primary output)
                "risk_level": "Low"|"Moderate"|"High",
                "top_prediction": str,
                "top_confidence": float,
                "all_probabilities": {class_name: prob, ...},
                "threshold_used": float,
            }

        Raises:
            TypeError: if ``image`` is not a str, Path, or PIL image.
        """
        if isinstance(image, (str, Path)):
            # Context manager releases the file handle deterministically; PIL
            # does not close it after load() for multi-frame formats.
            with Image.open(image) as source:
                image = source.convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            raise TypeError(
                f"Expected str, Path, or PIL.Image; got {type(image).__name__}"
            )

        # Preprocess and run the model.
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Softmax over the last (class) axis; [0] takes the single image's row.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()

        # Per-class probabilities keyed by the model's label names.
        all_probs = {
            self.id2label[i]: float(p) for i, p in enumerate(probs.tolist())
        }
        melanoma_prob = all_probs[self._melanoma_label]

        # Top prediction (for display).
        top_idx = int(torch.argmax(probs))
        top_class = self.id2label[top_idx]
        top_conf = float(probs[top_idx])

        return {
            "melanoma_probability": melanoma_prob,
            "risk_level": self._triage(melanoma_prob),
            "top_prediction": top_class,
            "top_confidence": top_conf,
            "all_probabilities": all_probs,
            "threshold_used": MELANOMA_THRESHOLD,
        }

    @staticmethod
    def _triage(melanoma_prob: float) -> str:
        """Three-tier risk stratification, tuned for screening.

        High:     prob >= MELANOMA_THRESHOLD (15%)  flag for dermatologist referral
        Moderate: prob >= MODERATE_THRESHOLD (8%)   monitor / follow-up recommended
        Low:      prob <  MODERATE_THRESHOLD (8%)   routine self-monitoring
        """
        if melanoma_prob >= MELANOMA_THRESHOLD:
            return "High"
        if melanoma_prob >= MODERATE_THRESHOLD:
            return "Moderate"
        return "Low"
def _print_result(result: dict) -> None:
    """Print a human-readable report of one classification to stdout.

    ``result`` is read for its ``melanoma_probability``, ``risk_level``,
    ``threshold_used`` and ``all_probabilities`` entries. Classes are listed
    from most to least probable. Each class gets a bar of ``int(prob * 30)``
    block characters.
    """

    def report_lines():
        # Lines are produced lazily and printed one at a time, so any partial
        # output on a malformed ``result`` matches a plain sequence of prints.
        divider = "=" * 60
        yield ""
        yield divider
        yield " SKANNER CLASSIFICATION RESULT"
        yield divider
        yield f" Melanoma probability: {result['melanoma_probability']:.1%}"
        yield f" Risk level: {result['risk_level']}"
        yield f" Threshold used: {result['threshold_used']:.2f} (screening-tuned)"
        yield ""
        yield " Class breakdown:"
        ranked = sorted(
            result["all_probabilities"].items(),
            key=lambda pair: -pair[1],
        )
        for label, share in ranked:
            meter = "█" * int(share * 30)
            yield f" {label:<24s} {share:6.1%} {meter}"
        yield divider
        yield ""
        yield " REMINDER: This is a screening tool, NOT a medical diagnosis."
        yield " Always consult a qualified dermatologist."
        yield ""

    for line in report_lines():
        print(line)
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: classify one image and print the result.

    Args:
        argv: Command-line arguments, excluding the program name. Defaults to
            ``sys.argv[1:]``. Passing it explicitly lets callers and tests
            drive the CLI without patching ``sys.argv``.

    Exits with status 1, writing a message to stderr, when no image path is
    given or the path is not an existing regular file.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print("Usage: python classify.py <image_path>", file=sys.stderr)
        print(
            "Example: python classify.py ISIC_2024_Permissive_Training_Input/ISIC_9855202.jpg",
            file=sys.stderr,
        )
        sys.exit(1)

    image_path = Path(args[0])
    # is_file() rather than exists(): a directory passes exists() and would
    # only fail later, deep inside PIL, after the model has been loaded.
    if not image_path.is_file():
        print(f"Error: file not found: {image_path}", file=sys.stderr)
        sys.exit(1)

    classifier = SkannerClassifier()
    result = classifier.classify(image_path)
    _print_result(result)
# Script entry point: run the CLI only when executed directly, not on import.
# main() returns None on success, so sys.exit(None) ends with status 0.
if __name__ == "__main__":
    sys.exit(main())