import os

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class ImageData(Dataset):
    """Image-classification dataset driven by an annotation CSV.

    The CSV is expected to contain the columns ``file_name`` (image path
    relative to ``img_dir``), ``category_id`` (class label, converted with
    ``int()`` on access) and ``validation_set`` (0 = training row,
    1 = validation row).
    """

    def __init__(self, img_dir, annotation_file, validation_set, transform=None):
        """Load the annotation CSV and keep only the requested split.

        Args:
            img_dir: Directory that ``file_name`` entries are joined onto.
            annotation_file: Path to the annotation CSV.
            validation_set: Truthy keeps rows with ``validation_set == 1``
                (validation split); falsy keeps rows with
                ``validation_set == 0`` (training split).
            transform: Optional callable applied to each RGB PIL image in
                ``__getitem__``.

        If the CSV cannot be read, an error message is printed and the
        dataset is left empty (length 0) instead of raising, so construction
        does not crash on a bad annotation path.
        """
        self.img_dir = img_dir
        self.transform = transform

        try:
            gt = pd.read_csv(annotation_file)
        except Exception as e:
            # Deliberate best-effort: report the problem and fall back to an
            # empty dataset rather than aborting construction.
            print(f"Error reading CSV {annotation_file}: {e}")
            self.img_labels = pd.DataFrame()
            self.images = []
            self.labels = []
            return

        # 0 = training row, 1 = validation row.
        split_flag = 1 if validation_set else 0
        self.img_labels = gt[gt["validation_set"] == split_flag].reset_index(drop=True)

        # Cache the columns as arrays so __getitem__ does plain positional
        # lookups instead of DataFrame indexing.
        self.images = self.img_labels["file_name"].values
        self.labels = self.img_labels["category_id"].values

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` in the split.

        The image is opened with PIL and converted to RGB so that grayscale,
        palette and RGBA files all come out with the same 3-channel layout.
        If ``transform`` is set, it is then applied. The label is returned as
        a Python ``int``, intended as a class index (e.g. for
        CrossEntropyLoss).
        """
        img_path = os.path.join(self.img_dir, self.images[idx])
        # The context manager guarantees the file handle is released.
        # convert() loads the pixel data into a new image first, so `image`
        # stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, int(label)