""" Chest Drain Detector — detects chest drain presence on chest X-rays. Usage: # Option 1: one-liner from HuggingFace (recommended) from chest_drain_detector import load_model detector = load_model() # Option 2: from a local directory from chest_drain_detector import ChestDrainDetector detector = ChestDrainDetector.from_pretrained("/path/to/model_dir") # Predict result = detector("path/to/cxr.png") # {'prediction': 'chest_drain_present', 'probability': 0.987, 'label': 1} results = detector.predict_batch(["img1.png", "img2.png"]) """ import os import sys import json import importlib import torch import torch.nn as nn import timm import numpy as np from PIL import Image from torchvision import transforms REPO_ID = "sindri101/chest-drain-detector" def load_model(repo_id=REPO_ID, device="auto"): """Download (if needed) and load the chest drain detector from HuggingFace. Args: repo_id: HuggingFace repo id (default: sindri101/chest-drain-detector) device: 'auto', 'cuda', or 'cpu' Returns: ChestDrainDetector ready for inference. Example: >>> from chest_drain_detector import load_model >>> detector = load_model() >>> detector("path/to/chest_xray.png") {'prediction': 'chest_drain_present', 'probability': 0.9993, 'label': 1} """ from huggingface_hub import snapshot_download model_dir = snapshot_download(repo_id) return ChestDrainDetector.from_pretrained(model_dir, device=device) class ChestDrainDetector(nn.Module): def __init__(self, config): super().__init__() self.config = config self.model = timm.create_model( config["backbone"], pretrained=False, num_classes=config["num_classes"] ) self.threshold = config["threshold"] self.labels = config["labels"] self.device = torch.device("cpu") self.transform = transforms.Compose([ transforms.Resize((config["image_size"], config["image_size"])), transforms.ToTensor(), transforms.Normalize( mean=config["normalize_mean"], std=config["normalize_std"], ), ]) @classmethod def from_pretrained(cls, model_dir, device="auto"): """Load model from a directory containing config.json and pytorch_model.bin.""" config_path = os.path.join(model_dir, "config.json") weights_path = os.path.join(model_dir, "pytorch_model.bin") with open(config_path) as f: config = json.load(f) detector = cls(config) state_dict = torch.load(weights_path, map_location="cpu", weights_only=True) detector.model.load_state_dict(state_dict) if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" detector.device = torch.device(device) detector.model.to(detector.device) detector.model.eval() return detector def _load_image(self, image_path): """Load and preprocess a single image from path.""" img = Image.open(image_path).convert("RGB") return self.transform(img) @torch.no_grad() def __call__(self, image_path): """Predict chest drain presence for a single image. Args: image_path: path to a chest X-ray image (PNG, JPEG, DICOM not supported) Returns: dict with keys: - prediction: 'chest_drain_present' or 'no_chest_drain' - probability: float, probability of chest drain presence - label: int, 1 if present, 0 if absent """ img = self._load_image(image_path).unsqueeze(0).to(self.device) logit = self.model(img).squeeze() prob = torch.sigmoid(logit).item() label = int(prob >= self.threshold) return { "prediction": self.labels[str(label)], "probability": round(prob, 4), "label": label, } @torch.no_grad() def predict_batch(self, image_paths, batch_size=32): """Predict chest drain presence for a list of images. Args: image_paths: list of paths to chest X-ray images batch_size: number of images to process at once Returns: list of dicts, same format as __call__ """ results = [] for i in range(0, len(image_paths), batch_size): batch_paths = image_paths[i:i + batch_size] imgs = torch.stack([self._load_image(p) for p in batch_paths]) imgs = imgs.to(self.device) logits = self.model(imgs).squeeze(-1) probs = torch.sigmoid(logits).cpu().numpy() for path, prob in zip(batch_paths, probs): label = int(prob >= self.threshold) results.append({ "image_path": path, "prediction": self.labels[str(label)], "probability": round(float(prob), 4), "label": label, }) return results if __name__ == "__main__": import sys model_dir = os.path.dirname(os.path.abspath(__file__)) detector = ChestDrainDetector.from_pretrained(model_dir) if len(sys.argv) > 1: for path in sys.argv[1:]: result = detector(path) print(f"{path}: {result['prediction']} (prob={result['probability']:.4f})") else: print("Usage: python model.py [image_path2 ...]") print(f"Model loaded from {model_dir}") print(f"Threshold: {detector.threshold}") print(f"Device: {detector.device}")