# Uploaded by sindri101 to the HuggingFace Hub ("Upload folder using huggingface_hub").
# Commit d5b33af (verified).
"""
Chest Drain Detector — detects chest drain presence on chest X-rays.
Usage:
# Option 1: one-liner from HuggingFace (recommended)
from chest_drain_detector import load_model
detector = load_model()
# Option 2: from a local directory
from chest_drain_detector import ChestDrainDetector
detector = ChestDrainDetector.from_pretrained("/path/to/model_dir")
# Predict
result = detector("path/to/cxr.png")
# {'prediction': 'chest_drain_present', 'probability': 0.987, 'label': 1}
results = detector.predict_batch(["img1.png", "img2.png"])
# list of dicts in the same format, each with an extra leading 'image_path' key
"""
import os
import sys
import json
import importlib
import torch
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
from torchvision import transforms
# Default HuggingFace Hub repository id that load_model() downloads from.
REPO_ID: str = "sindri101/chest-drain-detector"
def load_model(repo_id=REPO_ID, device="auto"):
    """Return a ready-to-use ChestDrainDetector fetched from the HuggingFace Hub.

    The repository snapshot is fetched with ``huggingface_hub.snapshot_download``
    (downloaded only if needed) and then loaded through
    ``ChestDrainDetector.from_pretrained``.

    Args:
        repo_id: Hub repository id; defaults to ``REPO_ID``
            ("sindri101/chest-drain-detector").
        device: "auto" (CUDA when available, otherwise CPU), "cuda" or "cpu".

    Returns:
        ChestDrainDetector placed on the requested device and ready for inference.

    Example:
        >>> from chest_drain_detector import load_model
        >>> detector = load_model()
        >>> detector("path/to/chest_xray.png")
        {'prediction': 'chest_drain_present', 'probability': 0.9993, 'label': 1}
    """
    # Imported here rather than at module top, so huggingface_hub is only
    # required when this helper is actually called.
    from huggingface_hub import snapshot_download

    snapshot_dir = snapshot_download(repo_id)
    return ChestDrainDetector.from_pretrained(snapshot_dir, device=device)
class ChestDrainDetector(nn.Module):
    """Binary chest-drain classifier wrapping a timm backbone.

    Backbone architecture, number of outputs, decision threshold, label names
    and preprocessing parameters all come from ``config`` (the parsed
    ``config.json``). The inference helpers apply a sigmoid to the backbone
    output and compare the resulting probability against the threshold.

    Note: ``__call__`` is overridden to take an image *path*, so calling the
    detector does not go through ``nn.Module.forward`` or module hooks.
    """

    def __init__(self, config):
        """Build the backbone (without pretrained weights) and the preprocessing pipeline.

        Args:
            config: dict with keys ``backbone``, ``num_classes``, ``threshold``,
                ``labels``, ``image_size``, ``normalize_mean`` and
                ``normalize_std``.

        Use :meth:`from_pretrained` to load trained weights and pick a device.
        """
        super().__init__()
        self.config = config
        self.model = timm.create_model(
            config["backbone"], pretrained=False, num_classes=config["num_classes"]
        )
        self.threshold = config["threshold"]
        # Looked up with str(label): JSON object keys are always strings.
        self.labels = config["labels"]
        # Device that input tensors are moved to; set by from_pretrained.
        self.device = torch.device("cpu")
        self.transform = transforms.Compose([
            transforms.Resize((config["image_size"], config["image_size"])),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=config["normalize_mean"],
                std=config["normalize_std"],
            ),
        ])

    @classmethod
    def from_pretrained(cls, model_dir, device="auto"):
        """Load a detector from a directory containing config.json and pytorch_model.bin.

        Args:
            model_dir: directory holding ``config.json`` and ``pytorch_model.bin``.
            device: ``'auto'`` (CUDA when available, else CPU), ``'cuda'``,
                ``'cpu'`` or any other string accepted by ``torch.device``.

        Returns:
            A ChestDrainDetector on ``device``, in eval mode.
        """
        config_path = os.path.join(model_dir, "config.json")
        weights_path = os.path.join(model_dir, "pytorch_model.bin")
        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        detector = cls(config)
        # weights_only=True refuses arbitrary pickled objects in the checkpoint.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        detector.model.load_state_dict(state_dict)
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        detector.device = torch.device(device)
        detector.model.to(detector.device)
        # eval() on the wrapper also switches the backbone to eval mode and
        # keeps the wrapper's own ``training`` flag consistent.
        detector.eval()
        return detector

    def _load_image(self, image_path):
        """Open an image file, convert it to RGB and apply the preprocessing transform."""
        # Pillow opens files lazily; the with-block releases the handle
        # deterministically instead of leaving it to garbage collection.
        with Image.open(image_path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)

    def _format_result(self, prob):
        """Turn a chest-drain probability into the per-image result dict."""
        prob = float(prob)
        label = int(prob >= self.threshold)
        return {
            "prediction": self.labels[str(label)],
            "probability": round(prob, 4),
            "label": label,
        }

    @torch.no_grad()
    def __call__(self, image_path):
        """Predict chest drain presence for a single image.

        Args:
            image_path: path to a chest X-ray image file readable by Pillow
                (e.g. PNG or JPEG). DICOM files are not supported.

        Returns:
            dict with keys:
                - prediction: 'chest_drain_present' or 'no_chest_drain'
                  (the configured label name)
                - probability: float, probability of chest drain presence,
                  rounded to 4 decimals
                - label: int, 1 if probability >= threshold, else 0
        """
        img = self._load_image(image_path).unsqueeze(0).to(self.device)  # add batch dim
        # Assumes a single-logit head (num_classes == 1): .item() needs exactly
        # one element.
        logit = self.model(img).squeeze()
        return self._format_result(torch.sigmoid(logit).item())

    @torch.no_grad()
    def predict_batch(self, image_paths, batch_size=32):
        """Predict chest drain presence for many images, ``batch_size`` at a time.

        Args:
            image_paths: iterable of image paths (list, tuple, generator, ...).
            batch_size: number of images per forward pass; must be >= 1.

        Returns:
            list of dicts in input order. Each dict has an ``image_path`` key
            followed by the same keys that ``__call__`` returns.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            # range() raises on a step of 0 and yields nothing for a negative
            # step, which would silently return [] for any input.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        paths = list(image_paths)  # accept generators / one-shot iterables
        results = []
        for start in range(0, len(paths), batch_size):
            batch_paths = paths[start:start + batch_size]
            imgs = torch.stack([self._load_image(p) for p in batch_paths]).to(self.device)
            # squeeze(-1) (not squeeze()) keeps a 1-D tensor even for a batch of one.
            logits = self.model(imgs).squeeze(-1)
            probs = torch.sigmoid(logits).cpu().numpy()
            results.extend(
                {"image_path": path, **self._format_result(prob)}
                for path, prob in zip(batch_paths, probs)
            )
        return results
if __name__ == "__main__":
    # CLI entry point: load the model stored alongside this file and classify
    # every image path given on the command line. With no arguments, print
    # usage together with a short summary of the loaded model.
    here = os.path.dirname(os.path.abspath(__file__))
    detector = ChestDrainDetector.from_pretrained(here)
    cli_paths = sys.argv[1:]
    if not cli_paths:
        print("Usage: python model.py <image_path> [image_path2 ...]")
        print(f"Model loaded from {here}")
        print(f"Threshold: {detector.threshold}")
        print(f"Device: {detector.device}")
    else:
        for image_path in cli_paths:
            outcome = detector(image_path)
            print(f"{image_path}: {outcome['prediction']} (prob={outcome['probability']:.4f})")