File size: 2,889 Bytes
af59080
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms

class NudeMultiLabelDataset(Dataset):
    """Multi-label image dataset backed by a JSON label file.

    The label file is loaded with ``json.load``; its top-level keys are used as
    image file names (joined onto ``data_dir``) and each value is iterated as a
    collection of tags. Presumably this is an object mapping file names to
    lists of tag strings — TODO confirm against the real labels.json.

    The tag vocabulary (``self.classes``) is the sorted set of every tag that
    appears in the file, and each sample's target is a multi-hot vector over
    that vocabulary.

    Args:
        data_dir: Directory that image names from the label file are joined to.
        label_file: Path to the JSON label file (read as UTF-8).
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned (e.g. a ``torchvision.transforms.Compose``).
        verbose: When true (the default), print a dataset summary and the first
            sample on construction. This loads one image from disk.
    """

    def __init__(self, data_dir, label_file, transform=None, verbose=True):
        self.data_dir = data_dir
        self.transform = transform
        self.label_file = label_file

        # Load labels. JSON text is UTF-8 (RFC 8259), so don't rely on the
        # platform's locale-dependent default encoding for non-ASCII tags.
        with open(label_file, "r", encoding="utf-8") as f:
            self.labels = json.load(f)

        self.image_paths = list(self.labels.keys())
        # Sorted so the tag -> index mapping is deterministic across runs.
        self.classes = sorted({tag for tags in self.labels.values() for tag in tags})
        self.class_to_idx = {tag: idx for idx, tag in enumerate(self.classes)}

        if verbose:
            self._print_summary()

    def _print_summary(self):
        """Print dataset statistics and the first sample (loads one image)."""
        print(f"📂 Dataset loaded from: {self.data_dir}")
        print(f"📄 Labels loaded from: {self.label_file}")
        print(f"🖼️ Total images: {len(self.image_paths)}")
        print(f"🏷️ Unique labels: {len(self.classes)}")
        print(f"🔹 Label-to-Index Mapping: {self.class_to_idx}")

        if self.image_paths:
            example_img, example_label = self[0]
            # Without a transform the sample is a PIL Image, which exposes
            # `.size` (width, height) but no `.shape`; fall back instead of
            # raising AttributeError.
            example_shape = getattr(example_img, "shape", None)
            if example_shape is None:
                example_shape = getattr(example_img, "size", None)
            print(f"✅ Example Image Shape: {example_shape}")
            print(f"✅ Example Label: {example_label}")

    def __len__(self):
        return len(self.image_paths)

    def _encode_labels(self, tags):
        """Return a multi-hot vector of length ``len(self.classes)`` for *tags*.

        Tags not present in ``self.class_to_idx`` are ignored. The vector uses
        ``torch.zeros``' default (floating-point) dtype.
        """
        label_tensor = torch.zeros(len(self.classes))
        for tag in tags:
            idx = self.class_to_idx.get(tag)
            if idx is not None:
                label_tensor[idx] = 1  # Multi-label: several entries may be 1
        return label_tensor

    def __getitem__(self, idx):
        """Return ``(image, label_tensor)`` for the sample at position *idx*.

        ``image`` is the RGB-converted PIL image, passed through
        ``self.transform`` when one was given. ``label_tensor`` is the
        multi-hot encoding of that image's tags.
        """
        img_name = self.image_paths[idx]
        img_path = os.path.join(self.data_dir, img_name)
        # Close the source file deterministically; convert() returns a new,
        # fully loaded image, so it stays valid after the `with` exits.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        label_tensor = self._encode_labels(self.labels[img_name])

        if self.transform is not None:
            image = self.transform(image)

        return image, label_tensor

# 🔹 Main function to test the dataset independently
def main():
    """Smoke-test NudeMultiLabelDataset: build it and print one batch."""
    # Set paths (resolved relative to the current working directory)
    data_dir = "../data/images"   # Change to actual path
    label_file = "../data/labels.json"   # Change to actual path

    # Define transformations: resize, convert to tensor, and normalize with the
    # ImageNet mean/std used by torchvision's pretrained models.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load dataset
    dataset = NudeMultiLabelDataset(data_dir, label_file, transform=transform)

    # DataLoader(shuffle=True) builds a RandomSampler, which raises ValueError
    # for an empty dataset, so bail out with a readable message first.
    if len(dataset) == 0:
        print("⚠️ Dataset is empty; nothing to batch.")
        return

    # Create DataLoader for testing
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Fetch one batch and print information
    images, labels = next(iter(dataloader))
    print(f"🖼️ Batch Image Shape: {images.shape}")  # Expect [min(4, len(dataset)), 3, 224, 224]
    print(f"🏷️ Batch Labels: {labels}")


if __name__ == "__main__":
    main()