metadata
license: cc-by-4.0
datasets:
  - InfoBayAI/xray_clinical_reports_without_findings_medical_nlp
  - InfoBayAI/xray_clinical_reports_with_findings_medical_nlp
language:
  - en
metrics:
  - accuracy
  - confusion_matrix
base_model:
  - microsoft/resnet-50
library_name: transformers
tags:
  - computer-vision
  - medical-imaging
  - multiclass-classification
  - diagnostic-imaging
  - anatomy-classification
  - xray
  - pytorch
pipeline_tag: image-classification

Model Description

MedXray-ResNet50-BodyPart-Classifier is a deep learning–based medical imaging model designed for automated X-ray anatomy classification using a fine-tuned ResNet-50 architecture.

The model was trained on a curated medical X-ray imaging dataset provided by InfoBay.AI, containing scans from seven anatomical regions: chest, skull, spine, pelvis, joints (including the knee), hands, and feet. Starting from an ImageNet-pretrained ResNet-50 backbone, the network was adapted through transfer learning and fine-tuning to recognize which anatomical region a radiograph shows.

This project demonstrates how ImageNet-pretrained computer vision models can be fine-tuned for healthcare AI tasks such as radiology automation, medical image understanding, and diagnostic support.

Medical X-ray Anatomy Classification

The model classifies radiographic scans into the following body regions:

Chest X-ray
Skull X-ray
Spine X-ray
Pelvis X-ray
Knee and Joint X-ray
Hand X-ray
Foot X-ray

The system is designed for medical imaging research, radiology AI experimentation, healthcare computer vision workflows, and automated dataset structuring applications.


X-ray Classification Pipeline

The complete deep learning workflow used for training is as follows:

Raw X-ray Images → Image Preprocessing → Image Normalization → Dataset Labeling → ResNet50 Fine-Tuning → Multi-Class Classification

Data Processing Steps
Medical X-ray image ingestion
Resizing to 224 × 224 pixels
Grayscale to 3-channel RGB conversion
Conversion to PyTorch tensors
Per-channel mean/std pixel normalization
Multi-class anatomical labeling
Fine-tuning of pretrained ResNet50 CNN backbone

Deep Learning Architecture

Architecture: ResNet50
Framework: PyTorch
Backbone: ImageNet-pretrained ResNet50
Training Strategy: Transfer Learning + Fine-Tuning
Task Type: Multi-Class Image Classification
Domain: Medical Imaging / Radiology AI

The model uses a fine-tuned ResNet50 convolutional neural network in which the original 1,000-class ImageNet fully connected layer was replaced with a new layer producing seven outputs, one per anatomical X-ray category.


Key Features

Automated X-ray body-part classification
Fine-tuned ResNet50 medical imaging model
Radiology-focused computer vision pipeline
Multi-class anatomical recognition
Transfer learning for healthcare AI
PyTorch-based medical AI workflow
Medical image preprocessing and normalization
Deep learning–based radiographic analysis

Dataset Split

Training Set: 70%
Validation Set: 15%
Test Set: 15%
Split Strategy: Random sampling
Number of Classes: 7
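A 70/15/15 random split like the one above could be produced with `torch.utils.data.random_split`. The fixed seed and the helper name `split_70_15_15` are assumptions added for reproducibility; the card does not say whether the split was seeded or stratified by class. A random `TensorDataset` stands in for the X-ray images.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_70_15_15(dataset, seed: int = 42):
    """Hypothetical helper: random 70/15/15 train/val/test split (seed is an assumption)."""
    n = len(dataset)
    n_train = n * 70 // 100
    n_val = n * 15 // 100
    n_test = n - n_train - n_val  # remainder goes to test so the three sizes always sum to n
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val, n_test], generator=generator)


# Stand-in dataset: 1,000 tiny random "images" with labels in 0..6.
data = TensorDataset(torch.randn(1000, 1, 8, 8), torch.randint(0, 7, (1000,)))
train_set, val_set, test_set = split_70_15_15(data)
print(len(train_set), len(val_set), len(test_set))  # 700 150 150
```

Because the split is purely random, small classes can end up unevenly represented across the three subsets; stratifying by label is a common alternative when classes are imbalanced.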

Training Hyperparameters

Number of Epochs: 5
Batch Size: 16
Learning Rate: 0.0001
Optimizer: Adam
Loss Function: Cross-Entropy Loss
Input Image Size: 224 × 224
Device Support: CPU / GPU (CUDA)
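These hyperparameters fit a conventional supervised fine-tuning loop. The sketch below assumes every layer is trainable and that no learning-rate schedule or data augmentation is used, since the card does not say otherwise; `fine_tune` is a hypothetical helper name. The smoke test at the bottom uses a tiny linear model and random data so it runs in seconds on CPU; it is not the real ResNet-50 or X-ray dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset

# Hyperparameters as listed in the card.
EPOCHS, BATCH_SIZE, LR = 5, 16, 1e-4


def fine_tune(model: nn.Module, train_set: Dataset, epochs: int = EPOCHS,
              batch_size: int = BATCH_SIZE, lr: float = LR, device: str = "cpu") -> list[float]:
    """Train all parameters with Adam + cross-entropy; returns the mean training loss per epoch."""
    model.to(device)
    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()  # expects raw logits (B, C) and integer class labels (B,)
    history = []
    for epoch in range(epochs):
        model.train()
        total_loss, seen = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * images.size(0)  # weight each batch by its size
            seen += images.size(0)
        history.append(total_loss / seen)
        print(f"epoch {epoch + 1}/{epochs}: train loss {history[-1]:.4f}")
    return history


# Smoke test with stand-ins (not the real model or data).
tiny_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 7))
toy_data = TensorDataset(torch.randn(48, 3, 8, 8), torch.randint(0, 7, (48,)))
history = fine_tune(tiny_model, toy_data)
```

For the actual model, one would pass the ResNet-50 with its seven-way head and the training split, for example `fine_tune(model, train_set, device="cuda")`.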

Model Performance

The fine-tuned ResNet50 model performed strongly on both the held-out validation split and the unseen test split.

Before fine-tuning, the ImageNet-pretrained baseline scored at or below the 1/7 ≈ 14.3% expected from uniform random guessing over seven classes:

Validation Accuracy: 12.2%
Test Accuracy: 10.1%

After transfer learning and fine-tuning on the multi-region X-ray anatomy dataset, the model achieved:

Validation Accuracy: 97.0%
Test Accuracy: 94.0%

Performance may vary depending on:

image quality
dataset diversity
class balance
preprocessing consistency
domain distribution
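The accuracy figures above, together with the confusion_matrix metric listed in the card metadata, could be computed from model predictions with plain PyTorch; a confusion matrix also shows which regions get confused when class balance or image quality is poor. This is a minimal sketch: `evaluate` and `accuracy_and_confusion` are hypothetical helper names, and the toy labels at the bottom are made up for illustration, not results from this model.

```python
import torch


def accuracy_and_confusion(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int = 7):
    """Return (accuracy, confusion matrix); rows are true classes, columns are predicted classes."""
    accuracy = (y_true == y_pred).float().mean().item()
    # Encode each (true, pred) pair as one index, count occurrences, then reshape to a grid.
    flat = y_true * num_classes + y_pred
    cm = torch.bincount(flat, minlength=num_classes * num_classes).reshape(num_classes, num_classes)
    return accuracy, cm


@torch.no_grad()
def evaluate(model, loader, device="cpu", num_classes=7):
    """Run a model over a DataLoader of (images, labels) and score its argmax predictions."""
    model.eval()
    trues, preds = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        preds.append(logits.argmax(dim=1).cpu())
        trues.append(labels.cpu())
    return accuracy_and_confusion(torch.cat(trues), torch.cat(preds), num_classes)


# Toy example with made-up labels for three classes.
y_true = torch.tensor([0, 0, 1, 2, 2, 2])
y_pred = torch.tensor([0, 1, 1, 2, 2, 0])
acc, cm = accuracy_and_confusion(y_true, y_pred, num_classes=3)
print(f"accuracy = {acc:.3f}")  # accuracy = 0.667
print(cm)
```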

Classification Labels

Class ID | Label
-------- | ------
0 | Chest
1 | Skull
2 | Spine
3 | Pelvis
4 | Joints
5 | Hands
6 | Foot
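The class IDs above index the model's seven outputs. The usage code reads the mapping from `labels.json` and looks it up with `str(pred.item())`, which suggests a JSON object keyed by string IDs; the inline JSON below assumes that format and mirrors the table, since the file's actual contents are not shown in the card. `top_k` is a hypothetical helper for listing the most likely regions instead of only the best one.

```python
import json

import torch

# Assumed labels.json format: string class ID -> label name (mirrors the table above).
LABELS_JSON = ('{"0": "Chest", "1": "Skull", "2": "Spine", "3": "Pelvis", '
               '"4": "Joints", "5": "Hands", "6": "Foot"}')
idx_to_class = json.loads(LABELS_JSON)


def top_k(logits: torch.Tensor, k: int = 3) -> list[tuple[str, float]]:
    """Turn one image's 7 logits into the k most likely (label, probability) pairs."""
    probs = torch.softmax(logits, dim=0)
    values, indices = torch.topk(probs, k)
    # JSON object keys are always strings, hence str(i) for the lookup.
    return [(idx_to_class[str(i)], p) for i, p in zip(indices.tolist(), values.tolist())]


example = top_k(torch.tensor([2.0, 0.1, 0.1, 0.1, 0.1, 0.1, 3.0]))
print(example)  # "Foot" first, then "Chest"
```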

Usage

Install Dependencies

pip install torch torchvision pillow numpy huggingface_hub

Load Trained Model

import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_download
from PIL import Image
import json

# ==============================
# DEVICE
# ==============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==============================
# DOWNLOAD MODEL FILES
# ==============================
model_path = hf_hub_download(
    repo_id="InfoBayAI/ResNet50-Xray-Anatomy-Classifier",
    filename="xray_resnet50.pth"
)

labels_path = hf_hub_download(
    repo_id="InfoBayAI/ResNet50-Xray-Anatomy-Classifier",
    filename="labels.json"
)

# ==============================
# LOAD LABELS
# ==============================
with open(labels_path, "r") as f:
    idx_to_class = json.load(f)

# ==============================
# CREATE MODEL
# ==============================
model = models.resnet50(weights=None)  # architecture only; the fine-tuned weights are loaded below

model.fc = nn.Linear(
    model.fc.in_features,
    len(idx_to_class)
)

# ==============================
# LOAD TRAINED WEIGHTS
# ==============================
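model.load_state_dict(
    # weights_only=True restricts unpickling to tensors and plain containers (safe for a state_dict)
    torch.load(model_path, map_location=device, weights_only=True)
)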

model.to(device)
model.eval()

print("Model loaded successfully!")

# ==============================
# IMAGE TRANSFORM
# ==============================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5, 0.5, 0.5],
        [0.5, 0.5, 0.5]
    )
])

# ==============================
# PREDICTION FUNCTION
# ==============================
def predict_xray(image_path):
    """Classify one X-ray image; prints and returns (label, confidence)."""

    img = Image.open(image_path).convert("RGB")
    img = transform(img).unsqueeze(0).to(device)  # add a batch dimension: (1, 3, 224, 224)

    with torch.no_grad():
        outputs = model(img)

        probs = torch.softmax(outputs, dim=1)

        confidence, pred = torch.max(probs, 1)

    predicted_label = idx_to_class[str(pred.item())]

    print(f"Prediction: {predicted_label}")
    print(f"Confidence: {confidence.item()*100:.2f}%")

    return predicted_label, confidence.item()

# ==============================
# EXAMPLE
# ==============================
predict_xray("test_xray.png")

Considerations

This model was trained on an X-ray image dataset from InfoBay.AI and is intended for research and evaluation purposes only.

For access to the full dataset or enterprise licensing inquiries, please contact InfoBay.AI.

Ph: +91 8303174762
Email: datareq@infobay.ai