MedPMC Multi-Figure Detection Model: ViT

This repository provides the Vision Transformer (ViT)-based multi-figure detection model used in the MedPMC data curation pipeline.

The model is a binary image classifier trained to predict whether a biomedical figure is a multi-panel / compound figure or a single-panel figure. It is intended for processing figures from biomedical literature, especially figures from PubMed Central (PMC) articles.

Task

The model performs binary image classification.

0: single-panel figure
1: multi-panel / compound figure

Usage

import torch
import timm
from PIL import Image
from torchvision import transforms

checkpoint_path = "model.pth.tar"
image_path = "example.jpg"

# Run on the GPU when one is available; the checkpoint itself is always
# deserialized onto the CPU first (map_location="cpu").
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: torch.load unpickles the checkpoint file -- only load checkpoints
# obtained from a source you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
arch, raw_state_dict = checkpoint["arch"], checkpoint["state_dict"]

# Checkpoints saved from a DataParallel/DDP-wrapped model prefix every key
# with "module."; strip it so the keys match a bare timm model.
prefix = "module."
state_dict = {}
for key, tensor in raw_state_dict.items():
    if key.startswith(prefix):
        key = key[len(prefix):]
    state_dict[key] = tensor

# Two output logits: index 0 = single-panel, index 1 = multi-panel / compound.
model = timm.create_model(arch, pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=True)
model = model.to(device).eval()

# Resize to 224x224 and normalize with the ImageNet per-channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# The context manager closes the underlying file once the pixels are decoded.
with Image.open(image_path) as source:
    image = source.convert("RGB")
batch = preprocess(image).unsqueeze(0).to(device)  # add a batch dimension of 1

with torch.no_grad():
    logits = model(batch)
    probs = torch.softmax(logits, dim=-1)
    pred = probs.argmax(dim=-1).item()

print("Prediction:", pred)
print("Probabilities:", probs.cpu().tolist())

Example output:

Prediction: 1
Probabilities: [[0.08, 0.92]]

This means the model classifies the input image as a multi-panel / compound figure (class 1, with probability 0.92).

Batch Inference

import torch
import timm
from PIL import Image
from pathlib import Path
from torchvision import transforms

checkpoint_path = "model.pth.tar"
image_dir = "sample"
batch_size = 32  # images per forward pass; lower this if you run out of GPU memory
image_suffixes = {".jpg", ".jpeg", ".png"}  # compared case-insensitively

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: torch.load unpickles the checkpoint file -- only load checkpoints
# obtained from a source you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
arch = checkpoint["arch"]
state_dict = checkpoint["state_dict"]

# Remove DataParallel/DDP prefix if present.
state_dict = {
    k.replace("module.", "", 1) if k.startswith("module.") else k: v
    for k, v in state_dict.items()
}

# Binary classifier: 0 = single-panel, 1 = multi-panel / compound.
model = timm.create_model(arch, pretrained=False, num_classes=2)
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# Match extensions case-insensitively so files such as "fig1.PNG" or
# "scan.JPG" are not silently skipped (glob patterns are case-sensitive
# on most platforms).
image_paths = sorted(
    p for p in Path(image_dir).iterdir()
    if p.is_file() and p.suffix.lower() in image_suffixes
)
if not image_paths:
    print("No images found in", image_dir)

with torch.no_grad():
    for start in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[start:start + batch_size]

        # Decode and preprocess each image; the context manager closes the
        # file handle as soon as the pixels have been read.
        tensors = []
        for image_path in batch_paths:
            with Image.open(image_path) as image:
                tensors.append(preprocess(image.convert("RGB")))
        inputs = torch.stack(tensors).to(device)  # (batch, 3, 224, 224)

        logits = model(inputs)
        probs = torch.softmax(logits, dim=-1).cpu()
        preds = probs.argmax(dim=-1).tolist()

        for i, (image_path, pred) in enumerate(zip(batch_paths, preds)):
            print("Image:", image_path)
            print("Prediction:", pred)
            # Slice (not index) so each line keeps the [[p0, p1]] format of
            # the single-image example.
            print("Probabilities:", probs[i:i + 1].tolist())

License

The model is released for non-commercial research use under CC BY-NC-SA 4.0.

Citation

Citation information will be updated soon.

Downloads last month
-
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Collection including Yale-BIDS-Chen/medpmc-multi-fig-detection-vit