Spaces:
Sleeping
Sleeping
| """ | |
| PyTorch Dataset and DataLoader utilities for Chest X-Ray classification. | |
| """ | |
| from pathlib import Path | |
| from typing import Tuple, Optional, List | |
| import random | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler | |
| from torchvision import transforms | |
| from PIL import Image | |
| from sklearn.model_selection import train_test_split | |
| from .config import ( | |
| DATA_DIR, IMAGE_SIZE, BATCH_SIZE, NUM_WORKERS, | |
| IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES, SEED | |
| ) | |
class ChestXRayDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs for X-ray files.

    Args:
        image_paths: Paths of the image files, one entry per sample.
        labels: Integer class label for each entry of ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned. When omitted, the PIL image itself is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(
        self,
        image_paths: List[Path],
        labels: List[int],
        transform: Optional[transforms.Compose] = None
    ):
        # Fail fast: a length mismatch would otherwise only surface as an
        # IndexError (or silently ignored labels) in the middle of an epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load sample ``idx`` and return ``(image, label)``."""
        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded RGB copy.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')

        # Compare against None so that a falsy-but-valid callable still runs.
        if self.transform is not None:
            image = self.transform(image)

        return image, self.labels[idx]
def get_transforms(is_training: bool = True) -> transforms.Compose:
    """Build the image preprocessing pipeline for one data split.

    Both variants resize to ``IMAGE_SIZE`` x ``IMAGE_SIZE``, convert to a
    tensor and normalize with ``IMAGENET_MEAN`` / ``IMAGENET_STD``. The
    training variant additionally inserts random flip, rotation and colour
    jitter between the resize and the tensor conversion.
    """
    pipeline = [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))]

    if is_training:
        # Augmentations sit after the resize and before ToTensor.
        pipeline += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]

    pipeline += [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)
def load_image_paths_and_labels(
    data_dir: Path,
    split: str,
    class_names: Optional[List[str]] = None
) -> Tuple[List[Path], List[int]]:
    """Collect ``*.jpeg`` files and their class indices for one split.

    Files are read from ``data_dir / split / <class_name>``; the label of a
    file is the position of its class name in ``class_names``.

    Args:
        data_dir: Root directory containing the split sub-directories.
        split: Name of the split sub-directory (e.g. ``'train'``).
        class_names: Class sub-directory names in label order. Defaults to
            the module-level ``CLASS_NAMES``.

    Returns:
        ``(image_paths, labels)`` as parallel lists. Class directories that
        do not exist are skipped, so both lists may be empty.
    """
    names = CLASS_NAMES if class_names is None else class_names

    image_paths: List[Path] = []
    labels: List[int] = []
    for class_idx, class_name in enumerate(names):
        class_dir = Path(data_dir) / split / class_name
        if not class_dir.is_dir():
            continue
        # glob() yields files in filesystem order; sorting makes the listing
        # deterministic so a seeded split downstream is reproducible.
        class_files = sorted(class_dir.glob('*.jpeg'))
        image_paths.extend(class_files)
        labels.extend([class_idx] * len(class_files))

    return image_paths, labels
def create_train_val_split(
    data_dir: Path = DATA_DIR,
    val_ratio: float = 0.15,
    seed: int = SEED
) -> Tuple[List[Path], List[int], List[Path], List[int]]:
    """Split the ``'train'`` images into stratified train and validation parts.

    Args:
        data_dir: Root directory passed to ``load_image_paths_and_labels``.
        val_ratio: Value forwarded as ``test_size`` to ``train_test_split``.
        seed: Value forwarded as ``random_state`` to ``train_test_split``.

    Returns:
        ``(train_paths, train_labels, val_paths, val_labels)``.
    """
    all_paths, all_labels = load_image_paths_and_labels(data_dir, 'train')

    # Stratify on the labels so both parts keep the class proportions.
    split_parts = train_test_split(
        all_paths,
        all_labels,
        test_size=val_ratio,
        random_state=seed,
        stratify=all_labels,
    )
    fit_paths, holdout_paths, fit_labels, holdout_labels = split_parts

    return fit_paths, fit_labels, holdout_paths, holdout_labels
def get_class_weights(labels: List[int]) -> torch.Tensor:
    """Compute inverse-frequency class weights for an imbalanced label set.

    The weight of class ``c`` is ``total / (num_classes * count_c)``, where
    ``num_classes`` is ``max(labels) + 1``.

    Args:
        labels: Non-negative integer class labels, one per sample.

    Returns:
        1-D float tensor indexed by class id. A class id that has no samples
        gets weight ``0`` rather than ``inf``.

    Raises:
        ValueError: If ``labels`` is empty.
    """
    if len(labels) == 0:
        raise ValueError("Cannot compute class weights from an empty label list")

    # An explicit integer dtype keeps bincount happy for any int-like input.
    class_counts = torch.bincount(
        torch.as_tensor(labels, dtype=torch.long)
    ).float()
    total = len(labels)
    weights = total / (len(class_counts) * class_counts)

    # Dividing by a zero count yields inf, which would poison a loss or a
    # sampler; absent classes are given zero weight instead.
    return torch.where(class_counts > 0, weights, torch.zeros_like(weights))
def get_sampler(labels: List[int]) -> WeightedRandomSampler:
    """Create a sampler that draws samples in proportion to their class weight.

    Each sample is weighted with the ``get_class_weights`` value of its
    class, and ``len(labels)`` samples are drawn with replacement.

    Args:
        labels: Integer class label of every sample in the dataset.

    Returns:
        A ``WeightedRandomSampler`` over ``len(labels)`` samples.
    """
    class_weights = get_class_weights(labels)

    # One vectorized lookup instead of a Python loop that builds a list of
    # 0-dim tensors; the resulting per-sample weights are identical.
    label_index = torch.as_tensor(labels, dtype=torch.long)
    sample_weights = class_weights[label_index].tolist()

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(labels),
        replacement=True
    )
def get_dataloaders(
    data_dir: Path = DATA_DIR,
    batch_size: int = BATCH_SIZE,
    num_workers: int = NUM_WORKERS,
    val_ratio: float = 0.15,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, validation and test ``DataLoader`` objects.

    Args:
        data_dir: Root directory holding the ``train`` and ``test`` splits.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker count used by all three loaders.
        val_ratio: Share of the training data held out for validation.
        use_weighted_sampling: When true, the training loader draws through
            the sampler from ``get_sampler``; otherwise it shuffles.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    # Gather file lists: train/val come from one stratified split of 'train'.
    train_paths, train_labels, val_paths, val_labels = create_train_val_split(
        data_dir, val_ratio
    )
    test_paths, test_labels = load_image_paths_and_labels(data_dir, 'test')

    # Only the training dataset gets the augmenting transform.
    train_dataset = ChestXRayDataset(
        train_paths, train_labels, transform=get_transforms(is_training=True)
    )
    val_dataset = ChestXRayDataset(
        val_paths, val_labels, transform=get_transforms(is_training=False)
    )
    test_dataset = ChestXRayDataset(
        test_paths, test_labels, transform=get_transforms(is_training=False)
    )

    if use_weighted_sampling:
        train_sampler = get_sampler(train_labels)
    else:
        train_sampler = None

    # Settings shared by every loader. pin_memory is enabled only when CUDA
    # is available (the original notes it is not supported on MPS).
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    # A sampler and shuffle are mutually exclusive, so shuffle only applies
    # when no sampler is used.
    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        shuffle=train_sampler is None,
        **shared_kwargs
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_kwargs)

    # Report split sizes.
    print(f"Train: {len(train_dataset)} images")
    print(f"Val: {len(val_dataset)} images")
    print(f"Test: {len(test_dataset)} images")

    return train_loader, val_loader, test_loader
def get_pos_weight(labels: List[int]) -> torch.Tensor:
    """Compute ``pos_weight`` for ``BCEWithLogitsLoss`` on imbalanced data.

    Args:
        labels: Binary labels, where ``0`` is the negative class (NORMAL)
            and ``1`` is the positive class (PNEUMONIA).

    Returns:
        0-dim float tensor holding ``negative_count / positive_count``.

    Raises:
        ValueError: If ``labels`` contains no positive (``1``) samples,
            including when it is empty.
    """
    labels_tensor = torch.as_tensor(labels)
    neg_count = (labels_tensor == 0).sum().float()  # NORMAL
    pos_count = (labels_tensor == 1).sum().float()  # PNEUMONIA

    # Without positives the ratio would be inf (or nan for empty input),
    # which silently corrupts the loss; report the problem instead.
    if pos_count.item() == 0:
        raise ValueError(
            "Cannot compute pos_weight: labels contain no positive samples"
        )

    return neg_count / pos_count