| """ |
| Skanner - Melanoma Classification Module (v2) |
| ============================================== |
| |
| Uses a binary melanoma/benign classifier (Hemgg/Melanoma-Cancer-Image-classification) |
| tuned to threshold 0.15 for screening-optimized sensitivity. |
| |
| Benchmarked on ISIC 2024 (n=50 balanced sample): |
| - Sensitivity: 84% (catches 21 of 25 melanomas) |
| - Specificity: 56% (clears 14 of 25 benign) |
| - Threshold: 0.15 (tuned for screening use case) |
| |
| Why threshold 0.15 instead of 0.50: |
| In melanoma screening, missing cancer is worse than a false alarm. A false |
| alarm sends someone to a dermatologist who clears them. A missed cancer |
| becomes an advanced tumor. We tune toward sensitivity. |
| |
| Usage (CLI): |
| python classify.py path/to/lesion.jpg |
| |
| Usage (Python): |
| from classify import SkannerClassifier |
| clf = SkannerClassifier() |
| result = clf.classify("lesion.jpg") |
| print(result["risk_level"], result["melanoma_probability"]) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
| from typing import Union |
|
|
| import torch |
| from PIL import Image |
| from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
|
| |
# Model identifier handed to ``from_pretrained`` when the caller does not
# supply one.
DEFAULT_MODEL = "Hemgg/Melanoma-Cancer-Image-classification"
# Melanoma probability at or above this value is triaged as "High".
MELANOMA_THRESHOLD = 0.15
# Melanoma probability at or above this value (and below the "High" cut-off)
# is triaged as "Moderate"; anything lower is "Low".
MODERATE_THRESHOLD = 0.08
|
|
|
|
| def _detect_device() -> str: |
| """Pick best available device. M-series Macs get MPS acceleration.""" |
| if torch.cuda.is_available(): |
| return "cuda" |
| if torch.backends.mps.is_available(): |
| return "mps" |
| return "cpu" |
|
|
|
|
class SkannerClassifier:
    """Binary melanoma/benign image classifier tuned for screening.

    The processor and model are loaded once in ``__init__`` and reused by
    every ``classify`` call. The melanoma class is located by name: it is the
    first model label containing "malignant" or "melanoma" (case-insensitive).
    """

    # Lower-case substrings that mark a model label as the melanoma class.
    _MELANOMA_LABEL_HINTS = ("malignant", "melanoma")

    def __init__(self, model_name: str = DEFAULT_MODEL, device: str | None = None):
        """Load the processor and model and move the model to the device.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            device: Torch device string. When falsy, ``_detect_device()``
                chooses one.

        Raises:
            ValueError: If none of the model's labels looks like a melanoma
                class, since such a model cannot drive the triage.
        """
        self.device = device or _detect_device()
        print(f"[Skanner] Loading model '{model_name}' on {self.device}...")
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        self.id2label = self.model.config.id2label
        # Fail fast: without a recognisable melanoma label every image would
        # otherwise be reported with probability 0.0 and triaged as "Low".
        self._melanoma_label(self.id2label.values())
        print(f"[Skanner] Ready. Classes: {list(self.id2label.values())}")

    @classmethod
    def _melanoma_label(cls, labels) -> str:
        """Return the first label that names the melanoma class.

        Args:
            labels: Iterable of label strings, in the order to search.

        Raises:
            ValueError: If no label contains one of the melanoma hints.
        """
        candidates = list(labels)
        for label in candidates:
            lowered = label.lower()
            if any(hint in lowered for hint in cls._MELANOMA_LABEL_HINTS):
                return label
        raise ValueError(
            "Model has no label containing 'malignant' or 'melanoma'; "
            f"labels: {candidates}"
        )

    def classify(self, image: Union[str, Path, Image.Image]) -> dict:
        """Run classification on a single image.

        Args:
            image: A file path (``str`` or ``Path``) or an already opened
                PIL image. The image is converted to RGB before inference.

        Returns:
            {
                "melanoma_probability": float,   # 0.0 - 1.0 (primary output)
                "risk_level": "Low"|"Moderate"|"High",
                "top_prediction": str,
                "top_confidence": float,
                "all_probabilities": {class_name: prob, ...},
                "threshold_used": float,
            }

        Raises:
            TypeError: If ``image`` is not a ``str``, ``Path`` or PIL image.
            ValueError: If the model exposes no melanoma-like label.
        """
        if isinstance(image, (str, Path)):
            # ``convert`` returns a fully loaded copy, so the source file can
            # be closed as soon as the conversion is done.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            raise TypeError(
                f"Expected str, Path, or PIL.Image; got {type(image).__name__}"
            )

        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Softmax over the class axis; [0] selects the single image in the batch.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()

        all_probs = {
            self.id2label[idx]: prob for idx, prob in enumerate(probs.tolist())
        }

        # Raises instead of silently reporting 0.0 when the model has no
        # melanoma-like class.
        melanoma_prob = all_probs[self._melanoma_label(all_probs)]

        top_idx = int(torch.argmax(probs))
        top_class = self.id2label[top_idx]
        top_conf = float(probs[top_idx])

        return {
            "melanoma_probability": melanoma_prob,
            "risk_level": self._triage(melanoma_prob),
            "top_prediction": top_class,
            "top_confidence": top_conf,
            "all_probabilities": all_probs,
            "threshold_used": MELANOMA_THRESHOLD,
        }

    @staticmethod
    def _triage(melanoma_prob: float) -> str:
        """Three-tier risk stratification, tuned for screening.

        High:     prob >= MELANOMA_THRESHOLD (flag for dermatologist referral)
        Moderate: prob >= MODERATE_THRESHOLD (monitor / follow-up recommended)
        Low:      anything lower (routine self-monitoring)
        """
        if melanoma_prob >= MELANOMA_THRESHOLD:
            return "High"
        if melanoma_prob >= MODERATE_THRESHOLD:
            return "Moderate"
        return "Low"
|
|
|
|
| def _print_result(result: dict) -> None: |
| """Pretty-print a classification result to the terminal.""" |
| print() |
| print("=" * 60) |
| print(" SKANNER CLASSIFICATION RESULT") |
| print("=" * 60) |
| print(f" Melanoma probability: {result['melanoma_probability']:.1%}") |
| print(f" Risk level: {result['risk_level']}") |
| print(f" Threshold used: {result['threshold_used']:.2f} (screening-tuned)") |
| print() |
| print(" Class breakdown:") |
| sorted_probs = sorted( |
| result["all_probabilities"].items(), key=lambda x: -x[1] |
| ) |
| for cls, prob in sorted_probs: |
| bar = "█" * int(prob * 30) |
| print(f" {cls:<24s} {prob:6.1%} {bar}") |
| print("=" * 60) |
| print() |
| print(" REMINDER: This is a screening tool, NOT a medical diagnosis.") |
| print(" Always consult a qualified dermatologist.") |
| print() |
|
|
|
|
def main() -> None:
    """Command-line entry point: classify one image and print the report.

    Expects the image path as the first command-line argument. Exits with
    status 1, after writing a message to stderr, when the argument is missing
    or does not point to an existing regular file.
    """
    if len(sys.argv) < 2:
        print("Usage: python classify.py <image_path>", file=sys.stderr)
        print(
            "Example: python classify.py ISIC_2024_Permissive_Training_Input/ISIC_9855202.jpg",
            file=sys.stderr,
        )
        sys.exit(1)

    image_path = Path(sys.argv[1])
    if not image_path.exists():
        print(f"Error: file not found: {image_path}", file=sys.stderr)
        sys.exit(1)
    # A directory passes ``exists()`` but cannot be opened as an image; reject
    # it here before the model is loaded.
    if not image_path.is_file():
        print(f"Error: not a file: {image_path}", file=sys.stderr)
        sys.exit(1)

    classifier = SkannerClassifier()
    result = classifier.classify(image_path)
    _print_result(result)
|
|
|
|
# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|