# File size: 2,889 Bytes (revision af59080)
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
class NudeMultiLabelDataset(Dataset):
    """Multi-label image dataset driven by a JSON label file.

    The label file must contain a JSON object whose keys are image file
    names (joined onto ``data_dir`` when a sample is loaded) and whose
    values are iterables of tags. The class list is the sorted union of
    all tags found in the file.

    Args:
        data_dir: Directory that image file names are resolved against.
        label_file: Path to the JSON label file.
        transform: Optional callable applied to the RGB ``PIL.Image`` of a
            sample before it is returned.

    Attributes:
        labels: The parsed JSON mapping.
        image_paths: Image file names, in label-file order.
        classes: Sorted list of unique tags.
        class_to_idx: Mapping from tag to its position in ``classes``.
    """

    def __init__(self, data_dir, label_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.label_file = label_file

        # Load labels. The encoding is given explicitly so the result does
        # not depend on the platform's locale default.
        with open(label_file, "r", encoding="utf-8") as f:
            self.labels = json.load(f)

        self.image_paths = list(self.labels)
        self.classes = sorted({tag for tags in self.labels.values() for tag in tags})
        self.class_to_idx = {tag: idx for idx, tag in enumerate(self.classes)}

        # Print dataset info
        # NOTE(review): the prefixes of these messages look like mis-decoded
        # emoji (mojibake) - confirm the encoding the file was saved with.
        print(f"๐ Dataset loaded from: {data_dir}")
        print(f"๐ Labels loaded from: {label_file}")
        print(f"๐ผ๏ธ Total images: {len(self.image_paths)}")
        print(f"๐ท๏ธ Unique labels: {len(self.classes)}")
        print(f"๐น Label-to-Index Mapping: {self.class_to_idx}")

        # Print example data
        if self.image_paths:
            example_img, example_label = self[0]
            # Without a transform the sample is a PIL image, which exposes
            # ``size`` but not ``shape``; fall back instead of raising
            # AttributeError while constructing the dataset.
            example_shape = getattr(example_img, "shape", None)
            if example_shape is None:
                example_shape = getattr(example_img, "size", None)
            print(f"โ Example Image Shape: {example_shape}")
            print(f"โ Example Label: {example_label}")

    def __len__(self):
        """Return the number of images listed in the label file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label_tensor)`` for the sample at position ``idx``.

        The image is opened from ``data_dir``, converted to RGB and passed
        through ``transform`` when one was given. ``label_tensor`` is the
        multi-hot encoding of the sample's tags.
        """
        img_name = self.image_paths[idx]
        img_path = os.path.join(self.data_dir, img_name)

        # Image.open is lazy and keeps the file open. convert() yields an
        # independent in-memory image, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label_tensor = self._encode_labels(self.labels[img_name])

        if self.transform:
            image = self.transform(image)
        return image, label_tensor

    def _encode_labels(self, tags):
        """Return a multi-hot float tensor of ``tags`` over ``self.classes``.

        Tags missing from ``class_to_idx`` are ignored.
        """
        label_tensor = torch.zeros(len(self.classes))
        for tag in tags:
            idx = self.class_to_idx.get(tag)
            if idx is not None:
                label_tensor[idx] = 1  # Multi-label
        return label_tensor
# Stand-alone smoke test: build the dataset and show a single batch
if __name__ == "__main__":
    # Input locations
    DATA_DIR = "../data/images"  # Change to actual path
    LABEL_FILE = "../data/labels.json"  # Change to actual path

    # Preprocessing steps, applied in this order to every image
    resize_step = transforms.Resize((224, 224))
    tensor_step = transforms.ToTensor()
    # NOTE(review): presumably the usual ImageNet channel statistics - confirm
    normalize_step = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    transform = transforms.Compose([resize_step, tensor_step, normalize_step])

    # Dataset and a small shuffled loader over it
    dataset = NudeMultiLabelDataset(DATA_DIR, LABEL_FILE, transform=transform)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Pull only the first batch (if there is one) and report on it
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        images, labels = first_batch
        print(f"๐ผ๏ธ Batch Image Shape: {images.shape}")  # Should be [batch_size, 3, 224, 224]
        print(f"๐ท๏ธ Batch Labels: {labels}")