""" Skanner - Melanoma Classification Module (v2) ============================================== Uses a binary melanoma/benign classifier (Hemgg/Melanoma-Cancer-Image-classification) tuned to threshold 0.15 for screening-optimized sensitivity. Benchmarked on ISIC 2024 (n=50 balanced sample): - Sensitivity: 84% (catches 21 of 25 melanomas) - Specificity: 56% (clears 14 of 25 benign) - Threshold: 0.15 (tuned for screening use case) Why threshold 0.15 instead of 0.50: In melanoma screening, missing cancer is worse than a false alarm. A false alarm sends someone to a dermatologist who clears them. A missed cancer becomes an advanced tumor. We tune toward sensitivity. Usage (CLI): python classify.py path/to/lesion.jpg Usage (Python): from classify import SkannerClassifier clf = SkannerClassifier() result = clf.classify("lesion.jpg") print(result["risk_level"], result["melanoma_probability"]) """ from __future__ import annotations import sys from pathlib import Path from typing import Union import torch from PIL import Image from transformers import AutoImageProcessor, AutoModelForImageClassification # Model + threshold determined by compare_models.py benchmark DEFAULT_MODEL = "Hemgg/Melanoma-Cancer-Image-classification" MELANOMA_THRESHOLD = 0.15 # Screening-optimized (vs default 0.50) MODERATE_THRESHOLD = 0.08 # Below this is Low risk; above is Moderate def _detect_device() -> str: """Pick best available device. M-series Macs get MPS acceleration.""" if torch.cuda.is_available(): return "cuda" if torch.backends.mps.is_available(): return "mps" return "cpu" class SkannerClassifier: """Binary melanoma/benign classifier tuned for screening.""" def __init__(self, model_name: str = DEFAULT_MODEL, device: str | None = None): self.device = device or _detect_device() print(f"[Skanner] Loading model '{model_name}' on {self.device}...") self.processor = AutoImageProcessor.from_pretrained(model_name) self.model = AutoModelForImageClassification.from_pretrained(model_name) self.model.to(self.device) self.model.eval() self.id2label = self.model.config.id2label print(f"[Skanner] Ready. Classes: {list(self.id2label.values())}") def classify(self, image: Union[str, Path, Image.Image]) -> dict: """ Run classification on a single image. Returns: { "melanoma_probability": float, # 0.0 - 1.0 (primary output) "risk_level": "Low"|"Moderate"|"High", "top_prediction": str, "top_confidence": float, "all_probabilities": {class_name: prob, ...}, "threshold_used": float, } """ # Accept either a path or a pre-loaded PIL image if isinstance(image, (str, Path)): image = Image.open(image).convert("RGB") elif not isinstance(image, Image.Image): raise TypeError( f"Expected str, Path, or PIL.Image; got {type(image).__name__}" ) else: image = image.convert("RGB") # Preprocess and run the model inputs = self.processor(images=image, return_tensors="pt").to(self.device) with torch.no_grad(): logits = self.model(**inputs).logits probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu() # Build per-class probabilities dict all_probs = {self.id2label[i]: float(probs[i]) for i in range(len(probs))} # Find the melanoma-indicating probability. # Hemgg model uses labels ['Benign', 'Malignant'] -> malignant == melanoma here. melanoma_prob = 0.0 for label, prob in all_probs.items(): label_lower = label.lower() if "malignant" in label_lower or "melanoma" in label_lower: melanoma_prob = prob break # Top prediction (for display) top_idx = int(torch.argmax(probs)) top_class = self.id2label[top_idx] top_conf = float(probs[top_idx]) return { "melanoma_probability": melanoma_prob, "risk_level": self._triage(melanoma_prob), "top_prediction": top_class, "top_confidence": top_conf, "all_probabilities": all_probs, "threshold_used": MELANOMA_THRESHOLD, } @staticmethod def _triage(melanoma_prob: float) -> str: """Three-tier risk stratification, tuned for screening. High: prob >= 15% (flag for dermatologist referral) Moderate: prob >= 8% (monitor / follow-up recommended) Low: prob < 8% (routine self-monitoring) """ if melanoma_prob >= MELANOMA_THRESHOLD: return "High" if melanoma_prob >= MODERATE_THRESHOLD: return "Moderate" return "Low" def _print_result(result: dict) -> None: """Pretty-print a classification result to the terminal.""" print() print("=" * 60) print(" SKANNER CLASSIFICATION RESULT") print("=" * 60) print(f" Melanoma probability: {result['melanoma_probability']:.1%}") print(f" Risk level: {result['risk_level']}") print(f" Threshold used: {result['threshold_used']:.2f} (screening-tuned)") print() print(" Class breakdown:") sorted_probs = sorted( result["all_probabilities"].items(), key=lambda x: -x[1] ) for cls, prob in sorted_probs: bar = "█" * int(prob * 30) print(f" {cls:<24s} {prob:6.1%} {bar}") print("=" * 60) print() print(" REMINDER: This is a screening tool, NOT a medical diagnosis.") print(" Always consult a qualified dermatologist.") print() def main(): if len(sys.argv) < 2: print("Usage: python classify.py ") print("Example: python classify.py ISIC_2024_Permissive_Training_Input/ISIC_9855202.jpg") sys.exit(1) image_path = Path(sys.argv[1]) if not image_path.exists(): print(f"Error: file not found: {image_path}") sys.exit(1) classifier = SkannerClassifier() result = classifier.classify(image_path) _print_result(result) if __name__ == "__main__": main()