| import os
|
| import json
|
| import torch
|
| from torch.utils.data import Dataset, DataLoader
|
| from PIL import Image
|
| import torchvision.transforms as transforms
|
|
|
class NudeMultiLabelDataset(Dataset):
    """Multi-label image dataset driven by a JSON label file.

    The label file is read with ``json.load`` and is expected (not validated)
    to hold a JSON object mapping image file names, resolved against
    ``data_dir``, to an iterable of tag strings, e.g.
    ``{"img1.jpg": ["tag_a", "tag_b"]}``. Every distinct tag becomes one
    column of the target vector, with columns in sorted tag order.

    Args:
        data_dir: Directory the image file names are joined onto.
        label_file: Path to the JSON label file (read as UTF-8).
        transform: Optional callable applied to the RGB ``PIL.Image``. When
            ``None``, ``__getitem__`` returns the PIL image itself.
        verbose: When ``True`` (the default, preserving the original
            behaviour), print a summary and load the first sample as a
            sanity check during construction.
    """

    def __init__(self, data_dir, label_file, transform=None, *, verbose=True):
        self.data_dir = data_dir
        self.transform = transform
        self.label_file = label_file

        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(label_file, "r", encoding="utf-8") as f:
            self.labels = json.load(f)

        # Key order of the loaded JSON object defines the sample order.
        self.image_paths = list(self.labels.keys())
        # Sorting makes the tag -> column mapping deterministic across runs.
        self.classes = sorted({tag for tags in self.labels.values() for tag in tags})
        self.class_to_idx = {tag: idx for idx, tag in enumerate(self.classes)}

        if verbose:
            self._print_summary()

    def _print_summary(self):
        """Print dataset statistics and the first sample as a sanity check."""
        # Emoji are spelled as \N{...} escapes so the messages survive any
        # re-encoding of this source file (a previous one produced mojibake).
        print(f"\N{OPEN FILE FOLDER} Dataset loaded from: {self.data_dir}")
        print(f"\N{PAGE FACING UP} Labels loaded from: {self.label_file}")
        print(f"\N{FRAME WITH PICTURE}\N{VARIATION SELECTOR-16} Total images: {len(self.image_paths)}")
        print(f"\N{LABEL}\N{VARIATION SELECTOR-16} Unique labels: {len(self.classes)}")
        print(f"\N{SMALL BLUE DIAMOND} Label-to-Index Mapping: {self.class_to_idx}")

        if self.image_paths:
            example_img, example_label = self[0]
            print(f"\N{WHITE HEAVY CHECK MARK} Example Image Shape: {self._shape_of(example_img)}")
            print(f"\N{WHITE HEAVY CHECK MARK} Example Label: {example_label}")

    @staticmethod
    def _shape_of(sample):
        """Return ``sample.shape`` when it exists, otherwise ``sample.size``.

        Without a transform the sample is a ``PIL.Image``, which has no
        ``shape`` attribute. Its ``size`` attribute is reported instead.
        """
        shape = getattr(sample, "shape", None)
        return shape if shape is not None else getattr(sample, "size", None)

    def _encode_labels(self, tags):
        """Return a multi-hot tensor with 1 at the column of every known tag."""
        label_tensor = torch.zeros(len(self.classes))
        for tag in tags:
            idx = self.class_to_idx.get(tag)
            # Tags outside the vocabulary are skipped rather than raising.
            if idx is not None:
                label_tensor[idx] = 1
        return label_tensor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label_tensor)`` for the sample at position ``idx``.

        ``image`` is the file converted to RGB, then passed through
        ``transform`` if one was given. ``label_tensor`` is a multi-hot vector
        of length ``len(self.classes)``.
        """
        img_name = self.image_paths[idx]
        img_path = os.path.join(self.data_dir, img_name)
        # The context manager releases the file handle. convert() returns a
        # new, fully loaded image that no longer needs the file.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label_tensor = self._encode_labels(self.labels[img_name])

        if self.transform is not None:
            image = self.transform(image)

        return image, label_tensor
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset and print the shapes of a single batch.
    # These paths are resolved relative to the current working directory.
    DATA_DIR = "../data/images"
    LABEL_FILE = "../data/labels.json"

    # Resize to 224x224, convert to a tensor, and normalize with the standard
    # ImageNet per-channel mean/std used by torchvision's pretrained models.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = NudeMultiLabelDataset(DATA_DIR, LABEL_FILE, transform=transform)

    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Inspect only the first batch. A for/break (rather than next(iter(...)))
    # simply prints nothing when the dataset is empty.
    # Emoji use \N{...} escapes so they survive re-encoding of this file.
    for images, labels in dataloader:
        print(f"\N{FRAME WITH PICTURE}\N{VARIATION SELECTOR-16} Batch Image Shape: {images.shape}")
        print(f"\N{LABEL}\N{VARIATION SELECTOR-16} Batch Labels: {labels}")
        break