File size: 10,106 Bytes
a6eed2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import sys
from pathlib import Path
# tambahkan parent project ke sys.path sehingga 'src' dapat diimport saat menjalankan skrip langsung
sys.path.append(str(Path(__file__).resolve().parents[1]))

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler
from torchvision import datasets, transforms
from src import config  # Mengimpor dari file config.py Anda
import matplotlib.pyplot as plt
import warnings
from pathlib import Path

# --- 1. Mendefinisikan Transformasi (Augmentasi) ---

# Statistik ImageNet untuk normalisasi (penting untuk model pre-trained)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Transformasi untuk data TRAINING
# Tujuannya: "menyiksa" data agar model bisa generalisasi dengan teknik terbaru
train_transform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE + 32, config.IMAGE_SIZE + 32)),  # Resize lebih besar dulu
    transforms.RandomCrop((config.IMAGE_SIZE, config.IMAGE_SIZE), padding=4),  # Random crop dengan padding
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),  # Tambah vertical flip
    transforms.RandomRotation(degrees=15),  # Moderate rotation
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),  # Moderate color augmentation
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5),  # Enhanced geometric augmentation
    
    # Advanced augmentations
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),  # Perspective distortion
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3)),  # Random erasing

    # --- TAMBAHKAN INI ---
    # Ini akan menerapkan augmentasi acak yang kuat
    transforms.TrivialAugmentWide(num_magnitude_bins=31), 
    # ---------------------

    transforms.ToTensor(), # ToTensor() HARUS setelah augmentasi
    transforms.Normalize(mean=MEAN, std=STD)
])
# Transformasi untuk data VALIDASI
# Tujuannya: Hanya membersihkan data untuk evaluasi, TANPA augmentasi acak
val_transform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)), # Ukuran seragam
    transforms.ToTensor(), # Konversi ke tensor PyTorch
    transforms.Normalize(mean=MEAN, std=STD) # Normalisasi
])


# --- 2. Helper Class untuk Menerapkan Transformasi Berbeda ---
# INI PENTING:
# Kita perlu membagi dataset (split) SEBELUM menerapkan augmentasi.
# Helper class ini memungkinkan kita menerapkan transform yang berbeda (train/val)
# pada dataset subset yang sudah dibagi.

class TransformedDataset(Dataset):
    """Wrapper Dataset untuk menerapkan transformasi ke Subset."""
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        
    def __getitem__(self, index):
        # Ambil data asli (gambar, label) dari subset
        try:
            x, y = self.subset[index]
            
            # Terapkan transformasi jika ada
            if self.transform:
                x = self.transform(x)
                
            return x, y
        except Exception as e:
            # Jika ada error (file rusak), coba index berikutnya
            print(f"[Warning] Error pada index {index}: {e}")
            # Coba index berikutnya (dengan wraparound)
            next_index = (index + 1) % len(self.subset)
            return self.__getitem__(next_index)
    
    def __len__(self):
        return len(self.subset)

# --- 3. Fungsi Utama Pembuat DataLoader ---

def create_dataloaders():
    """

    Fungsi utama untuk membuat dan mengembalikan data loader

    untuk training dan validasi.

    """
    # --- VALIDASI: Pastikan config.DATA_PATH ada, coba beberapa alternatif jika tidak ---
    data_path = Path(config.DATA_PATH)
    if not data_path.exists():
        project_root = Path(__file__).resolve().parents[1]
        alt_names = ["Batik_Indonesia_JPG", "Batik-Indonesia", "Batik_Indonesia", "data", "dataset"]
        found = None
        for name in alt_names:
            candidate = project_root / name
            if candidate.exists() and candidate.is_dir():
                found = candidate
                break
        if found:
            print(f"[Data] config.DATA_PATH '{config.DATA_PATH}' tidak ditemukan. Menggunakan alternatif: {found}")
            # update atribut di module config agar konsisten
            try:
                config.DATA_PATH = str(found)
            except Exception:
                pass
            data_path = found
        else:
            raise FileNotFoundError(
                f"config.DATA_PATH='{config.DATA_PATH}' tidak ditemukan. "
                f"Pastikan folder dataset ada atau set config.DATA_PATH ke path yang benar."
            )

    # --- LANGKAH A: Muat Dataset Induk ---
    print(f"[Data] Memuat dataset induk dari: {data_path}")
    full_dataset = datasets.ImageFolder(str(data_path))
    
    # Simpan nama kelas
    class_names = full_dataset.classes
    num_classes = len(class_names)
    print(f"[Data] Ditemukan {num_classes} kelas: {class_names}")

    # --- LANGKAH B: Bagi Dataset 80:20 (Secara Hati-hati) ---
    print(f"[Data] Membagi dataset 80:20 (seed: {config.RANDOM_SEED})...")
    total_size = len(full_dataset)
    val_size = int(total_size * config.TEST_SPLIT_SIZE)
    train_size = total_size - val_size
    
    # Bagi dataset menggunakan random_split dengan SEED yang tetap
    # Ini memastikan pembagian data SELALU SAMA setiap kali skrip dijalankan
    train_dataset_raw, val_dataset_raw = random_split(
        full_dataset, 
        [train_size, val_size],
        generator=torch.Generator().manual_seed(config.RANDOM_SEED)
    )
    
    print(f"[Data] Ukuran Train: {len(train_dataset_raw)} | Ukuran Validasi: {len(val_dataset_raw)}")

    # --- LANGKAH C: Terapkan Transformasi yang Berbeda ---
    train_dataset = TransformedDataset(train_dataset_raw, transform=train_transform)
    val_dataset = TransformedDataset(val_dataset_raw, transform=val_transform)

    # --- LANGKAH D: Mengatasi Ketidakseimbangan Kelas (Wajib!) ---
    print("[Data] Menghitung bobot untuk mengatasi ketidakseimbangan kelas...")
    
    # 1. Ambil semua label (target) HANYA dari set training
    train_targets = [full_dataset.targets[i] for i in train_dataset_raw.indices]
    
    # 2. Hitung jumlah gambar per kelas
    #    Kita gunakan bincount untuk efisiensi
    class_counts = np.bincount(train_targets)
    
    # 3. Hitung bobot kebalikan (inverse weight) untuk setiap kelas
    #    Kelas langka -> bobot tinggi
    #    Kelas umum -> bobot rendah
    class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
    
    # 4. Buat daftar bobot untuk SETIAP sampel di set training
    #    Setiap sampel akan memiliki bobot sesuai kelasnya
    sample_weights = class_weights[train_targets]
    
    # 5. Buat Sampler
    #    WeightedRandomSampler akan mengambil data berdasarkan bobot ini
    train_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True # Izinkan pengambilan sampel berulang (oversampling)
    )
    
    print("[Data] WeightedRandomSampler berhasil dibuat.")

    # --- LANGKAH E: Buat DataLoaders ---
    
    # DataLoader untuk Training
    # PENTING: Jika menggunakan 'sampler', 'shuffle' HARUS False.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        sampler=train_sampler,
        num_workers=2, # Disable multiprocessing untuk Windows
        pin_memory=False, # Disable untuk CPU training
        shuffle=False 
    )
    
    # DataLoader untuk Validasi
    # Tidak perlu sampler, tidak perlu shuffle (evaluasi harus konsisten)
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.BATCH_SIZE,
        num_workers=2, # Disable multiprocessing untuk Windows
        pin_memory=False, # Disable untuk CPU training
        shuffle=False 
    )
    
    print("[Data] Data loader untuk Train dan Validasi siap.")
    
    return train_loader, val_loader, class_names


# --- 5. Blok Pengujian (Opsional tapi Sangat Direkomendasikan) ---
# Kode ini HANYA akan berjalan jika Anda menjalankan file ini secara langsung
# (misal: `python src/data_loader.py`)
# Ini sangat berguna untuk memverifikasi bahwa loader Anda berfungsi.

if __name__ == "__main__":
    print("Menjalankan pengujian data_loader.py...")
    
    # Coba buat data loader
    train_loader, val_loader, class_names = create_dataloaders()
    
    print(f"\nTotal kelas: {len(class_names)}")
    
    # Ambil satu batch dari train_loader
    print("\nMengambil 1 batch dari train_loader (untuk tes)...")
    with warnings.catch_warnings():
        warnings.simplefilter("ignore") # Abaikan peringatan UserWarning dari matplotlib
        
        try:
            images, labels = next(iter(train_loader))
            
            print(f"  > Ukuran batch gambar: {images.shape}") # [Batch, Channel, H, W]
            print(f"  > Ukuran batch label: {labels.shape}")
            print(f"  > Contoh 5 label di batch ini: {labels[:5]}")
            
            # Coba visualisasikan 1 gambar (untuk cek normalisasi)
            img_to_show = images[0].permute(1, 2, 0).numpy() # Ubah (C, H, W) -> (H, W, C)
            # Denormalisasi (penting untuk visualisasi)
            img_to_show = STD * img_to_show + MEAN 
            img_to_show = np.clip(img_to_show, 0, 1) # Pastikan nilai antara 0 dan 1
            
            plt.imshow(img_to_show)
            plt.title(f"Contoh Gambar (Label: {class_names[labels[0]]})")
            plt.axis('off')
            plt.show()
            
            print("\n[Sukses] data_loader.py berfungsi dengan baik!")
            
        except Exception as e:
            print(f"\n[Error] Gagal menguji data loader: {e}")