"""Image-level fracture classification using a pretrained HF vision model. YOLO is kept for localization only. This classifier provides the image-level second opinion, which is safer than treating detector "Not_Fracture" boxes as proof that the whole scan is normal. """ from __future__ import annotations import logging from PIL import Image logger = logging.getLogger(__name__) _processor = None _model = None def _get_model(): """Load the fracture classifier and processor lazily.""" global _processor, _model if _model is None: logger.info("Loading pretrained fracture classifier...") try: from transformers import AutoImageProcessor, AutoModelForImageClassification from app.config import get_settings model_name = get_settings().fracture_classifier_model_name try: _processor = AutoImageProcessor.from_pretrained(model_name) except Exception: from transformers import AutoProcessor _processor = AutoProcessor.from_pretrained(model_name) _model = AutoModelForImageClassification.from_pretrained(model_name) _model.eval() logger.info(f"Fracture classifier loaded: {model_name}") except Exception as exc: logger.error(f"Failed to load fracture classifier: {exc}") raise return _processor, _model def predict_fracture_presence(image: Image.Image) -> list[dict]: """Classify whether the full image likely contains a fracture.""" import torch processor, model = _get_model() inputs = processor(images=image.convert("RGB"), return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0] id2label = model.config.id2label ranked = sorted( ((float(probs[i]) * 100, str(id2label[int(i)])) for i in range(len(probs))), reverse=True, ) top_conf, top_label = ranked[0] top_key = _normalize_label(top_label) if _is_fracture_label(top_key): if top_conf >= 85: severity, color = "high", "destructive" elif top_conf >= 65: severity, color = "moderate", "warning" else: severity, color = "low", "info" return [{ "name": "Fracture suspected", "confidence": round(top_conf, 1), "severity": severity, "model": "FractureClassifier", "region": "Full image", "icd_code": "S02-S92", "color": color, }] return [{ "name": "No fracture suspected by classifier", "confidence": round(top_conf, 1), "severity": "clear", "model": "FractureClassifier", "region": "Full image", "icd_code": "", "color": "success", }] def _normalize_label(label: str) -> str: return label.lower().replace("-", "_").replace(" ", "_") def _is_fracture_label(label: str) -> bool: if any(token in label for token in ("not", "normal", "negative", "no_fracture", "nofracture")): return False return "fracture" in label or "fractured" in label or label in {"positive", "label_1", "1"}