Image Classification
timm
PyTorch
chest-x-ray
medical-imaging
binary-classification
chest-drain
confounder-detection
Instructions to use sindri101/chest-drain-predictor with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- timm
How to use sindri101/chest-drain-predictor with timm:
import timm
model = timm.create_model("hf_hub:sindri101/chest-drain-predictor", pretrained=True)
- Notebooks
- Google Colab
- Kaggle
"""
Chest Drain Detector — detects chest drain presence on chest X-rays.

Usage:
    # Option 1: one-liner from HuggingFace (recommended)
    from chest_drain_detector import load_model
    detector = load_model()

    # Option 2: from a local directory containing config.json and
    # pytorch_model.bin
    from chest_drain_detector import ChestDrainDetector
    detector = ChestDrainDetector.from_pretrained("/path/to/model_dir")

    # Predict on a single image (PNG/JPEG; DICOM is not supported)
    result = detector("path/to/cxr.png")
    # {'prediction': 'chest_drain_present', 'probability': 0.987, 'label': 1}

    # Predict on many images; each result dict also carries 'image_path'
    results = detector.predict_batch(["img1.png", "img2.png"])
"""
| import os | |
| import sys | |
| import json | |
| import importlib | |
| import torch | |
| import torch.nn as nn | |
| import timm | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision import transforms | |
# Default Hugging Face Hub repository id used by load_model().
# NOTE(review): the hosting model card refers to this model as
# "sindri101/chest-drain-predictor" — confirm this id still resolves on the Hub.
REPO_ID = "sindri101/chest-drain-detector"


def load_model(repo_id=REPO_ID, device="auto"):
    """Download (if needed) a model snapshot from the Hub and load it.

    The snapshot directory returned by ``huggingface_hub.snapshot_download``
    is passed straight to ``ChestDrainDetector.from_pretrained``.

    Args:
        repo_id: Hub repository id; defaults to ``REPO_ID``.
        device: ``'auto'``, ``'cuda'`` or ``'cpu'``; forwarded unchanged to
            ``ChestDrainDetector.from_pretrained``.

    Returns:
        A ChestDrainDetector ready for inference.

    Example::

        from chest_drain_detector import load_model
        detector = load_model()
        detector("path/to/chest_xray.png")
        # {'prediction': 'chest_drain_present', 'probability': 0.9993, 'label': 1}
    """
    # Imported inside the function, so importing this module does not
    # require huggingface_hub; only Hub downloads need it.
    from huggingface_hub import snapshot_download as fetch_snapshot

    snapshot_dir = fetch_snapshot(repo_id)
    return ChestDrainDetector.from_pretrained(snapshot_dir, device=device)
class ChestDrainDetector(nn.Module):
    """Binary chest-drain classifier wrapping a timm backbone.

    Everything is driven by a config dict (``config.json`` when built via
    ``from_pretrained``). The keys read are ``backbone``, ``num_classes``,
    ``threshold``, ``labels``, ``image_size``, ``normalize_mean`` and
    ``normalize_std``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # pretrained=False: weights are loaded from the model directory's
        # pytorch_model.bin in from_pretrained(), not from timm's defaults.
        self.model = timm.create_model(
            config["backbone"], pretrained=False, num_classes=config["num_classes"]
        )
        self.threshold = config["threshold"]
        # Looked up with str(label), i.e. keys "0" / "1" as JSON produces them.
        self.labels = config["labels"]
        self.device = torch.device("cpu")
        size = config["image_size"]
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=config["normalize_mean"],
                std=config["normalize_std"],
            ),
        ])

    @classmethod
    def from_pretrained(cls, model_dir, device="auto"):
        """Load a detector from a directory holding config.json and pytorch_model.bin.

        Args:
            model_dir: directory containing ``config.json`` and
                ``pytorch_model.bin``.
            device: ``'cuda'``, ``'cpu'``, or ``'auto'`` (CUDA when
                ``torch.cuda.is_available()``, otherwise CPU).

        Returns:
            A ChestDrainDetector whose backbone is on ``device`` in eval mode.
        """
        config_path = os.path.join(model_dir, "config.json")
        weights_path = os.path.join(model_dir, "pytorch_model.bin")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        detector = cls(config)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run arbitrary code.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        detector.model.load_state_dict(state_dict)
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        detector.device = torch.device(device)
        detector.model.to(detector.device)
        detector.model.eval()
        return detector

    def _load_image(self, image_path):
        """Open ``image_path`` as RGB and apply the preprocessing transform."""
        # The context manager closes the file handle; convert() returns a
        # fully loaded copy, so it stays usable after the file is closed.
        with Image.open(image_path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

    @torch.no_grad()
    def __call__(self, image_path):
        """Predict chest drain presence for a single image.

        Runs without autograd: inference needs no gradients, and tracking
        them only costs memory.

        Args:
            image_path: path to a chest X-ray image (PNG, JPEG, DICOM not supported)

        Returns:
            dict with keys:
                - prediction: 'chest_drain_present' or 'no_chest_drain'
                - probability: float, probability of chest drain presence
                - label: int, 1 if present, 0 if absent
        """
        img = self._load_image(image_path).unsqueeze(0).to(self.device)
        logit = self.model(img).squeeze()
        prob = torch.sigmoid(logit).item()
        label = int(prob >= self.threshold)
        return {
            "prediction": self.labels[str(label)],
            "probability": round(prob, 4),
            "label": label,
        }

    @torch.no_grad()
    def predict_batch(self, image_paths, batch_size=32):
        """Predict chest drain presence for many images, in mini-batches.

        Runs without autograd. This is required here: calling ``.numpy()`` on
        logits that track gradients raises ``RuntimeError``.

        Args:
            image_paths: iterable of paths to chest X-ray images
            batch_size: number of images to process at once (must be >= 1)

        Returns:
            list of dicts in the same format as ``__call__``, plus an
            ``image_path`` key, in input order.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        # Materialize so generators/iterators work with len() and slicing.
        image_paths = list(image_paths)
        results = []
        for start in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[start:start + batch_size]
            imgs = torch.stack([self._load_image(p) for p in batch_paths])
            imgs = imgs.to(self.device)
            logits = self.model(imgs).squeeze(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
            for path, prob in zip(batch_paths, probs):
                label = int(prob >= self.threshold)
                results.append({
                    "image_path": path,
                    "prediction": self.labels[str(label)],
                    "probability": round(float(prob), 4),
                    "label": label,
                })
        return results
def main(argv=None):
    """Command-line entry point: classify each image path given as an argument.

    The model is loaded from the directory containing this file. With no
    arguments, prints usage plus the model directory, threshold and device.

    Args:
        argv: arguments excluding the program name; defaults to ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if any image could not be read.
    """
    if argv is None:
        argv = sys.argv[1:]
    model_dir = os.path.dirname(os.path.abspath(__file__))
    detector = ChestDrainDetector.from_pretrained(model_dir)
    if not argv:
        print("Usage: python model.py <image_path> [image_path2 ...]")
        print(f"Model loaded from {model_dir}")
        print(f"Threshold: {detector.threshold}")
        print(f"Device: {detector.device}")
        return 0
    status = 0
    for path in argv:
        try:
            result = detector(path)
        except OSError as err:
            # Missing or unreadable file (PIL's UnidentifiedImageError is an
            # OSError subclass): report it and continue with the other paths.
            print(f"{path}: error: {err}", file=sys.stderr)
            status = 1
            continue
        print(f"{path}: {result['prediction']} (prob={result['probability']:.4f})")
    return status


if __name__ == "__main__":
    sys.exit(main())