# Standard library
import os
import io
import zipfile
import pickle
from pathlib import Path

# Data handling
import pandas as pd
import numpy as np

# PyTorch
import torch
from torch.utils.data import Dataset

# Image processing
from PIL import Image
import cv2

# Augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Progress bar (for precompute_all_masks)
from tqdm import tqdm


class OptimizedZipReader:
    """
    Fast ZIP file reader with LRU caching
    """

    def __init__(self, zip_path, cache_size=1000):
        """
        Args:
            zip_path: Path to ZIP file
            cache_size: Number of images to cache in RAM
        """
        self.zip_path = zip_path
        self.cache_size = cache_size
        self._zip_file = None  # Will be lazily initialized
        self._name_to_info = None

        # Cache
        self._cache = {}
        self._cache_order = []
        self._hits = 0
        self._misses = 0

    @property
    def zip_file(self):
        """Lazy initialization of ZIP file handle"""
        if self._zip_file is None:
            print(f"Opening ZIP file: {self.zip_path}")
            self._zip_file = zipfile.ZipFile(self.zip_path, 'r', allowZip64=True)

            # Build index on first access
            print("Building ZIP index...")
            self._name_to_info = {
                info.filename: info for info in self._zip_file.infolist()
            }
            print(f"✓ Indexed {len(self._name_to_info)} files")
        return self._zip_file

    def read_image(self, path):
        """
        Read image data with automatic caching

        Returns:
            bytes (image file data)
        """
        # Check cache first
        if path in self._cache:
            self._hits += 1
            # Move the entry to the end of the eviction order so eviction
            # is true LRU rather than FIFO
            self._cache_order.remove(path)
            self._cache_order.append(path)
            return self._cache[path]

        # Cache miss - read from ZIP (this triggers lazy initialization)
        self._misses += 1
        img_data = self.zip_file.read(path)  # Uses property getter

        # Add to cache with LRU eviction
        if len(self._cache) >= self.cache_size:
            oldest = self._cache_order.pop(0)
            del self._cache[oldest]

        self._cache[path] = img_data
        self._cache_order.append(path)

        return img_data

    def get_cache_stats(self):
        """Return cache hit rate statistics"""
        total = self._hits + self._misses
        hit_rate = self._hits / total * 100 if total > 0 else 0
        return {
            'hits': self._hits,
            'misses': self._misses,
            'hit_rate': f"{hit_rate:.2f}%",
            'cache_size': len(self._cache)
        }

    def close(self):
        """Close ZIP file and clear cache"""
        if self._zip_file is not None:
            self._zip_file.close()
            self._zip_file = None
        self._cache.clear()
        self._cache_order.clear()
        self._name_to_info = None
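
# Illustrative usage of OptimizedZipReader (a sketch, not part of the pipeline;
# the archive name and member path below are hypothetical stand-ins):
def _demo_zip_reader(zip_path="train.zip"):
    """Read an image twice; the second read is served from the RAM cache."""
    reader = OptimizedZipReader(zip_path, cache_size=100)
    try:
        member = "train/patient00001/study1/view1_frontal.jpg"
        reader.read_image(member)          # miss: decompressed from the ZIP
        data = reader.read_image(member)   # hit: returned from the cache
        image = Image.open(io.BytesIO(data)).convert('L')
        print(image.size, reader.get_cache_stats())
    finally:
        reader.close()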


class CheXpertDataset(Dataset):
    """
    CheXpert Dataset class

    Intended to return 3-channel images (img, img*mask, mask):
    - Channel 0: Original grayscale image
    - Channel 1: Masked image (lung region only)
    - Channel 2: Binary lung mask
    NOTE: the mask channels are currently disabled in __getitem__, so each
    sample carries a single-channel image under the 'image' key.

    Args:
        csv_path (str): Path to the CSV file (train.csv or valid.csv)
        root_dir (str): Root directory of the CheXpert dataset
        image_size (int): Target image size (default: 384)
        augment (bool): Whether to apply augmentations (default: False)
        use_frontal_only (bool): If True, only use frontal view images (default: False)
        fill_uncertain (str): How to handle uncertain labels:
            'zeros', 'ones', 'ignore' (default: 'ignore')
        lmdb_path (str): Optional LMDB database of pre-serialized samples
            (currently disabled; see the commented-out lmdb.open call below)
        zip_path (str): Optional ZIP archive to read images from instead of disk
        zip_cache_size (int): Number of images OptimizedZipReader keeps in RAM
        mask_dir (str): Directory for pre-computed lung masks
        domask (bool): If True, precompute all masks into mask_dir at init
    """

    # 14 pathology classes in CheXpert
    PATHOLOGIES = [
        'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
        'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
        'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
        'Pleural Other', 'Fracture', 'Support Devices'
    ]

    def __init__(
        self,
        csv_path,
        root_dir,
        image_size=384,
        augment=False,
        use_frontal_only=False,
        fill_uncertain='ignore',
        lmdb_path=None,
        zip_path=None,
        zip_cache_size=1000,
        mask_dir=None,
        domask=False
    ):
        self.root_dir = root_dir
        self.image_size = image_size
        self.augment = augment
        self.fill_uncertain = fill_uncertain
        # LMDB support is currently disabled (would also require `import lmdb`)
        self.env = None  # lmdb.open(lmdb_path, readonly=True, lock=False) if lmdb_path else None
        self._zip_path = zip_path
        self._zip_cache_size = zip_cache_size
        self._zip_reader_instance = None

        # Read CSV file
        self.df = pd.read_csv(csv_path)
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                self.df[pathology] = pd.to_numeric(self.df[pathology], errors='coerce')

        # Filter for frontal views only if specified
        if use_frontal_only:
            self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'].reset_index(drop=True)

        # Handle uncertain labels (-1 values)
        self._process_uncertain_labels()

        # Setup augmentations
        self.train_transform = self._get_train_transforms()
        self.val_transform = self._get_val_transforms()

        print(f"Loaded {len(self.df)} images from {csv_path}")
        print(f"Image size: {image_size}x{image_size}")
        print(f"Augmentation: {augment}")
        print(f"Uncertain labels filled with: {fill_uncertain}")

        if mask_dir and domask:
            self.precompute_all_masks(mask_dir)

    # Run this ONCE before training
    def precompute_all_masks(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        for idx in tqdm(range(len(self))):
            img_path = os.path.join(self.root_dir, self.df.iloc[idx]['Path'])
            # Drop the leading dataset-root component to get the ZIP member path
            part_path = "/".join(self.df.iloc[idx]['Path'].split("/")[1:])
            if self.zip_reader:
                # Read image data from ZIP (no extraction!)
                img_data = self.zip_reader.read_image(part_path)
                # Open image from bytes in memory
                image = Image.open(io.BytesIO(img_data)).convert('L')
            else:
                image = Image.open(img_path).convert('L')
            image = np.array(image)
            # chexpert_medsam_mask is assumed to be provided elsewhere
            # (e.g., a MedSAM-based lung segmentation helper)
            mask = chexpert_medsam_mask(image)
            mask_path = os.path.join(
                save_dir,
                "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt')
            )
            os.makedirs(os.path.dirname(mask_path), exist_ok=True)
            torch.save(mask, mask_path)

    @property
    def zip_reader(self):
        """
        Lazy property getter for ZIP reader

        The ZIP file is only opened when first accessed, not during __init__.
        This is useful when:
        - Creating multiple dataset objects but only using some
        - Saving memory during dataset setup
        - Working with multiprocessing (each worker creates its own)
        """
        if self._zip_reader_instance is None and self._zip_path is not None:
            self._zip_reader_instance = OptimizedZipReader(
                self._zip_path,
                cache_size=self._zip_cache_size
            )
        return self._zip_reader_instance
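
    # Why the lazy property matters with multiprocessing (an illustrative
    # sketch; the file names are hypothetical): the dataset is pickled into
    # each DataLoader worker BEFORE any ZIP handle exists, so every worker
    # opens its own handle on first read instead of sharing one across
    # processes.
    #
    #   dataset = CheXpertDataset("train.csv", "/data/CheXpert", zip_path="train.zip")
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)
    #   for batch in loader:  # each worker lazily builds its own OptimizedZipReader
    #       ...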
    def _load_and_cache_image(self, img_path, idx):
        """
        Load image with automatic resizing and caching.
        If a resized version exists, load it. Otherwise resize, save, and load.

        Args:
            img_path (str): Original image path from CSV
            idx (int): Index for tracking

        Returns:
            np.ndarray: Loaded image (grayscale)
        """
        # Build the cache path next to the original file by prefixing the
        # filename with the target size, preserving the directory structure
        path_parts = list(Path(img_path).parts)
        path_parts[-1] = f"{self.image_size}_{path_parts[-1]}"
        cached_path = Path(*path_parts).with_suffix('.jpg')

        # Check if cached version exists
        if cached_path.exists():
            # Load cached image
            image = Image.open(cached_path).convert('L')
            image = np.array(image)

            # Verify it's the correct size
            if image.shape[0] == self.image_size and image.shape[1] == self.image_size:
                return image

        # Cache doesn't exist or wrong size - load original
        image = Image.open(img_path).convert('L')

        # Check if original is already target size
        width, height = image.size
        if width == self.image_size and height == self.image_size:
            # Already correct size, just convert to array
            return np.array(image)

        # Resize image
        image_resized = image.resize(
            (self.image_size, self.image_size),
            Image.LANCZOS
        )

        # Save to cache
        cached_path.parent.mkdir(parents=True, exist_ok=True)
        image_resized.save(cached_path, 'JPEG', quality=95, optimize=True)

        return np.array(image_resized)

    def _process_uncertain_labels(self):
        """Process uncertain labels (-1) based on the chosen strategy."""
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                if self.fill_uncertain == 'zeros':
                    # Map uncertain (-1) to negative (0)
                    self.df[pathology] = self.df[pathology].replace(-1, 0)
                elif self.fill_uncertain == 'ones':
                    # Map uncertain (-1) to positive (1)
                    self.df[pathology] = self.df[pathology].replace(-1, 1)
                elif self.fill_uncertain == 'ignore':
                    # Keep -1 as is; the loss function must skip these entries
                    # (see the sketch below)
                    pass

                # Fill NaN with 0 (negative)
                self.df[pathology] = self.df[pathology].fillna(0)
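
    # With fill_uncertain='ignore', labels keep the value -1 and the training
    # loss must skip those entries. A minimal sketch of such a masked loss
    # (illustrative; masked_bce_loss is not part of this class):
    #
    #   import torch.nn.functional as F
    #
    #   def masked_bce_loss(logits, labels):
    #       mask = (labels != -1).float()  # 1 where the label is known
    #       per_elem = F.binary_cross_entropy_with_logits(
    #           logits, labels.clamp(min=0), reduction='none')
    #       return (per_elem * mask).sum() / mask.sum().clamp(min=1)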
    def _get_train_transforms(self):
        """Get training augmentations suitable for chest X-rays."""
        return A.Compose([
            # Resize to target size
            A.LongestMaxSize(max_size=self.image_size),
            A.PadIfNeeded(self.image_size, self.image_size,
                          border_mode=cv2.BORDER_CONSTANT, position='center'),

            # Geometric augmentations (conservative for medical images)
            A.HorizontalFlip(p=0.5),
            A.Affine(
                translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
                scale=(0.9, 1.1),
                rotate=(-10, 10),
                fit_output=False,
                p=0.5
            ),

            # Intensity augmentations
            A.OneOf([
                A.RandomBrightnessContrast(
                    brightness_limit=0.2,
                    contrast_limit=0.2,
                    p=1.0
                ),
                A.RandomGamma(gamma_limit=(80, 120), p=1.0),
                A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0),
            ], p=0.5),

            # Add slight blur to simulate different imaging conditions
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 5), p=1.0),
                A.MedianBlur(blur_limit=3, p=1.0),
            ], p=0.2),

            # Add noise
            A.GaussNoise(p=0.2),

            # Normalize to [-1, 1]: (x/255 - 0.5) / 0.5
            A.Normalize(
                mean=[0.5],
                std=[0.5],
                max_pixel_value=255.0
            ),
            ToTensorV2()
        ])

    def _get_val_transforms(self):
        """Get validation/test transforms (no augmentation)."""
        return A.Compose([
            A.LongestMaxSize(max_size=self.image_size),
            A.PadIfNeeded(self.image_size, self.image_size,
                          border_mode=cv2.BORDER_CONSTANT, position='center'),
            A.Normalize(
                mean=[0.5],
                std=[0.5],
                max_pixel_value=255.0
            ),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.df)

    def __del__(self):
        """Close ZIP when done (without triggering lazy initialization)."""
        reader = getattr(self, '_zip_reader_instance', None)
        if reader is not None:
            reader.close()

    def __getitem__(self, idx):
        if self.env:
            # LMDB path: samples were serialized ahead of time
            with self.env.begin() as txn:
                # Retrieve serialized data
                data = txn.get(str(idx).encode())
                sample = pickle.loads(data)
            return sample

        # Get image path
        img_path = os.path.join(self.root_dir, self.df.iloc[idx]['Path'])
        # Drop the leading dataset-root component to get the ZIP member path
        part_path = "/".join(self.df.iloc[idx]['Path'].split("/")[1:])

        if self.zip_reader:
            # Read image data from ZIP (no extraction!)
            img_data = self.zip_reader.read_image(part_path)
            # Open image from bytes in memory
            image = Image.open(io.BytesIO(img_data)).convert('L')
        else:
            image = Image.open(img_path).convert('L')
        image = np.array(image)

        # Loading of pre-computed masks is currently disabled:
        # mask_path = os.path.join(self.mask_dir, "_".join(self.df.iloc[idx]['Path'].split("/")[-3:]).replace('.jpg', '_mask.pt'))
        # masked_img = torch.load(mask_path)

        # Apply transforms (these would apply to BOTH image and mask together)
        if self.augment:
            transformed = self.train_transform(image=image)
        else:
            transformed = self.val_transform(image=image)
        image_transformed = transformed['image']  # (1, H, W) tensor, normalized

        # Mask channel composition is disabled, so the sample stays single-channel
        image_1ch = image_transformed  # (1, H, W)

        # Get labels for all pathologies
        labels = []
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                label = self.df.iloc[idx][pathology]
                labels.append(float(label) if not pd.isna(label) else 0.0)
            else:
                labels.append(0.0)
        labels = torch.tensor(labels, dtype=torch.float32)

        # Get additional metadata
        path = self.df.iloc[idx]['Path']
        metadata = {
            'patient_id': path.split('/')[2],  # Extract patient ID from path
            'study_id': path.split('/')[3],    # Extract study ID from path
            'view': self.df.iloc[idx]['Frontal/Lateral'],
            'sex': self.df.iloc[idx]['Sex'] if 'Sex' in self.df.columns else 'Unknown',
            'age': self.df.iloc[idx]['Age'] if 'Age' in self.df.columns else -1,
            'path': path
        }

        return {
            'image': image_1ch,
            'labels': labels,
            'metadata': metadata
        }

    def get_label_names(self):
        """Return list of pathology label names."""
        return self.PATHOLOGIES

    def get_label_distribution(self):
        """Get distribution of positive labels for each pathology."""
        distribution = {}
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                positive_count = (self.df[pathology] == 1.0).sum()
                distribution[pathology] = {
                    'positive': int(positive_count),
                    'percentage': round(positive_count / len(self.df) * 100, 2)
                }
        return distribution

    def get_class_weights(self):
        """
        OPTIMIZED: Vectorized class weights calculation

        Returns one neg/pos ratio per pathology column present in the CSV.
        """
        weights = []
        for pathology in self.PATHOLOGIES:
            if pathology in self.df.columns:
                # Vectorized counting (much faster than iterating)
                values = self.df[pathology].values
                pos = np.sum(values == 1.0)
                neg = np.sum(values == 0.0)
                weight = neg / pos if pos > 0 else 1.0
                weights.append(weight)
        return torch.tensor(weights, dtype=torch.float32)
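
    # A common use of the neg/pos ratios above is as pos_weight for a
    # multi-label BCE loss (illustrative sketch; `model` and `batch` are
    # hypothetical, and it assumes all pathology columns exist in the CSV):
    #
    #   pos_weight = dataset.get_class_weights()
    #   criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    #   loss = criterion(model(batch['image']), batch['labels'])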
    def get_sample_weights(self):
        """
        OPTIMIZED: Vectorized sample weights calculation

        Performance: ~1000x faster than original
        Original: 15-30 seconds for 200k samples
        This: 0.01-0.05 seconds for 200k samples
        """
        # Get class weights as numpy array (one entry per pathology column
        # actually present in the CSV, matching get_class_weights)
        class_weights = self.get_class_weights().numpy()
        present = [p for p in self.PATHOLOGIES if p in self.df.columns]

        # Get all labels as numpy array in ONE vectorized operation
        labels_array = self.df[present].values.astype(np.float32)

        # Create weighted labels matrix: where label=1, use class_weight, else -inf
        # so that max will only consider positive labels
        # Shape: (n_samples, n_classes)
        weighted_labels = np.where(
            labels_array == 1.0,
            class_weights,
            -np.inf
        )

        # For each sample, find the maximum class weight of its positive labels.
        # If a sample has no positive labels, the max is -inf; replace it with 1.0
        sample_weights = np.max(weighted_labels, axis=1)
        sample_weights = np.where(
            np.isinf(sample_weights),
            1.0,  # Samples with no positive labels get weight 1.0
            sample_weights
        )

        return torch.tensor(sample_weights, dtype=torch.float32)
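

# Illustrative end-to-end wiring of the sampler helpers (a sketch; the CSV,
# root, and ZIP paths are hypothetical stand-ins):
def _demo_weighted_loader():
    """Build a DataLoader that oversamples images with rare pathologies."""
    from torch.utils.data import DataLoader, WeightedRandomSampler

    dataset = CheXpertDataset(
        csv_path="train.csv",
        root_dir="/data/CheXpert",
        image_size=384,
        augment=True,
        zip_path="train.zip"
    )
    weights = dataset.get_sample_weights()
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)
    batch = next(iter(loader))
    print(batch['image'].shape, batch['labels'].shape)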