| import os |
| import pandas as pd |
| import numpy as np |
| from torch.utils.data import Dataset |
| from PIL import Image |
|
|
class ImageData(Dataset):
    """Image classification dataset driven by a CSV annotation file.

    Three CSV columns are read: ``file_name`` (joined onto ``img_dir`` to
    locate each image), ``category_id`` (the label) and ``validation_set``
    (0 = training set, 1 = validation set).
    """

    def __init__(self, img_dir, annotation_file, validation_set, transform=None):
        """Load the annotations and keep only the rows of the requested split.

        Args:
            img_dir: Directory that the ``file_name`` entries are relative to.
            annotation_file: Path of the CSV annotation file.
            validation_set: Truthy selects rows whose ``validation_set``
                column equals 1; falsy selects rows where it equals 0.
            transform: Optional callable applied to every loaded image.

        Raises:
            KeyError: If the CSV was read but lacks one of the columns used
                here (``validation_set``, ``file_name``, ``category_id``).

        If the CSV cannot be read at all, the error is printed and the
        dataset is left empty (length 0) instead of raising.
        """
        super().__init__()
        self.img_dir = img_dir
        self.transform = transform

        # Only the read itself is guarded, so problems in the filtering
        # below are never mistaken for a CSV read failure.
        try:
            gt = pd.read_csv(annotation_file)
        except Exception as e:
            # Deliberate best-effort: report the problem and expose an empty
            # dataset so the caller can still construct the object.
            print(f"Error reading CSV {annotation_file}: {e}")
            self.img_labels = pd.DataFrame()
            self.images = []
            self.labels = []
            return

        split_flag = 1 if validation_set else 0
        selected = gt[gt["validation_set"] == split_flag]
        self.img_labels = selected.reset_index(drop=True)

        # Plain arrays make the per-item lookups in __getitem__ cheap.
        self.images = self.img_labels["file_name"].values
        self.labels = self.img_labels["category_id"].values

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is opened from ``img_dir``/``file_name`` and converted to
        RGB; ``transform`` is applied when one was supplied. The label is
        returned as a built-in ``int``.
        """
        img_path = os.path.join(self.img_dir, self.images[idx])

        # The context manager closes the underlying file deterministically;
        # convert() returns a new, fully loaded image that stays usable
        # after the source file has been closed.
        with Image.open(img_path) as source:
            image = source.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, int(self.labels[idx])