# Uploaded by Mittalyash via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 9916246 (verified).
import os
import random

import albumentations as A
import cv2
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
class VinDrCXRClassificationDataset(Dataset):
    """Multi-label chest X-ray classification dataset for AI-CliniScan.

    The annotation CSV may hold several rows per image (one per annotated
    class). Labels are aggregated per ``image_id`` and returned as a
    multi-hot float target vector.

    Args:
        csv_file: Path to the annotation CSV. It must contain ``image_id`` and
            ``class_id`` columns.
        img_dir: Directory containing the images, named ``<image_id>.png``.
        transform: Callable invoked as ``transform(image=...)`` that returns a
            dict with an ``"image"`` key (albumentations style). When ``None``,
            a default resize / normalize / to-tensor pipeline is used.
        max_images: Upper bound on the number of images kept. When the CSV
            has more images, a reproducible random subset of this size is
            drawn. ``None`` keeps every image. Defaults to 2000 for quick
            training runs.
        seed: Seed for the subset draw. A private RNG is used, so the global
            ``random`` module state is left untouched.
        num_classes: Length of the target vector. Every ``class_id`` must be
            in ``[0, num_classes)``.
    """

    def __init__(self, csv_file, img_dir, transform=None, max_images=2000,
                 seed=42, num_classes=15):
        self.img_dir = img_dir
        self.df = pd.read_csv(csv_file)
        # One entry per image: the de-duplicated list of class ids annotated on
        # it (the same class id may appear in several rows for one image).
        self.image_labels = (
            self.df.groupby('image_id')['class_id']
            .apply(lambda ids: list(set(ids)))
            .to_dict()
        )
        self.image_ids = list(self.image_labels.keys())

        # Subsample with a private RNG. Calling random.seed() here would
        # silently reset the process-wide stream that other code draws from.
        # random.Random(seed).sample() yields the same subset that
        # random.seed(seed); random.sample() did.
        if max_images is not None and len(self.image_ids) > max_images:
            self.image_ids = random.Random(seed).sample(self.image_ids, max_images)

        # Default 15 = VinDr-CXR's 14 abnormalities + 1 'No finding'.
        self.num_classes = num_classes

        if transform is None:
            # NOTE(review): the single-value mean/std are broadcast over all three
            # channels of the RGB-converted image, unlike the per-channel ImageNet
            # statistics used by get_train_val_transforms(). Confirm this is intended.
            transform = A.Compose([
                A.Resize(256, 256),
                A.Normalize(mean=(0.485,), std=(0.229,)),
                ToTensorV2()
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of (possibly subsampled) images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and its multi-hot label vector.

        Returns:
            ``(image, target)``. ``image`` is whatever ``self.transform``
            produces (a tensor for the default pipeline). ``target`` is a
            float32 tensor of shape ``(num_classes,)`` holding 1.0 for each
            annotated class.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        img_id = self.image_ids[idx]
        # Append .png since the 256x256 kaggle dataset contains pngs
        img_path = os.path.join(self.img_dir, f"{img_id}.png")
        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # Replicate the single grayscale channel into 3 channels for ResNet input.
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if self.transform:
            image = self.transform(image=image)['image']

        # Multi-label one-hot encoding.
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        for label in self.image_labels[img_id]:
            # Skip missing (NaN) class ids.
            if not pd.isna(label):
                target[int(label)] = 1.0
        return image, target
def get_train_val_transforms():
    """Build the albumentations pipelines for training and validation.

    Both pipelines resize to 256x256, take a 224x224 crop, normalize with
    per-channel ImageNet statistics and convert the result to a tensor.
    Training draws a random crop and adds a random horizontal flip plus a
    random shift/scale/rotate. Validation uses a deterministic center crop.

    Returns:
        tuple: ``(train_transform, val_transform)``.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    def to_normalized_tensor():
        # Fresh instances per pipeline so the two Compose objects share nothing.
        return [A.Normalize(mean=imagenet_mean, std=imagenet_std), ToTensorV2()]

    train_steps = [
        A.Resize(256, 256),
        A.RandomCrop(224, 224),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=15, p=0.5),
        *to_normalized_tensor(),
    ]
    val_steps = [
        A.Resize(256, 256),
        A.CenterCrop(224, 224),
        *to_normalized_tensor(),
    ]
    return A.Compose(train_steps), A.Compose(val_steps)