Decoder24's picture
Upload folder using huggingface_hub
a6eed2b verified
import sys
from pathlib import Path
# tambahkan parent project ke sys.path sehingga 'src' dapat diimport saat menjalankan skrip langsung
sys.path.append(str(Path(__file__).resolve().parents[1]))
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler
from torchvision import datasets, transforms
from src import config # Mengimpor dari file config.py Anda
import matplotlib.pyplot as plt
import warnings
from pathlib import Path
# --- 1. Mendefinisikan Transformasi (Augmentasi) ---
# Statistik ImageNet untuk normalisasi (penting untuk model pre-trained)
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# Transformasi untuk data TRAINING
# Tujuannya: "menyiksa" data agar model bisa generalisasi dengan teknik terbaru
train_transform = transforms.Compose([
transforms.Resize((config.IMAGE_SIZE + 32, config.IMAGE_SIZE + 32)), # Resize lebih besar dulu
transforms.RandomCrop((config.IMAGE_SIZE, config.IMAGE_SIZE), padding=4), # Random crop dengan padding
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.2), # Tambah vertical flip
transforms.RandomRotation(degrees=15), # Moderate rotation
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05), # Moderate color augmentation
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=5), # Enhanced geometric augmentation
# Advanced augmentations
transforms.RandomPerspective(distortion_scale=0.2, p=0.3), # Perspective distortion
transforms.RandomErasing(p=0.2, scale=(0.02, 0.33), ratio=(0.3, 3.3)), # Random erasing
# --- TAMBAHKAN INI ---
# Ini akan menerapkan augmentasi acak yang kuat
transforms.TrivialAugmentWide(num_magnitude_bins=31),
# ---------------------
transforms.ToTensor(), # ToTensor() HARUS setelah augmentasi
transforms.Normalize(mean=MEAN, std=STD)
])
# Transformasi untuk data VALIDASI
# Tujuannya: Hanya membersihkan data untuk evaluasi, TANPA augmentasi acak
val_transform = transforms.Compose([
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)), # Ukuran seragam
transforms.ToTensor(), # Konversi ke tensor PyTorch
transforms.Normalize(mean=MEAN, std=STD) # Normalisasi
])
# --- 2. Helper Class untuk Menerapkan Transformasi Berbeda ---
# INI PENTING:
# Kita perlu membagi dataset (split) SEBELUM menerapkan augmentasi.
# Helper class ini memungkinkan kita menerapkan transform yang berbeda (train/val)
# pada dataset subset yang sudah dibagi.
class TransformedDataset(Dataset):
"""Wrapper Dataset untuk menerapkan transformasi ke Subset."""
def __init__(self, subset, transform=None):
self.subset = subset
self.transform = transform
def __getitem__(self, index):
# Ambil data asli (gambar, label) dari subset
try:
x, y = self.subset[index]
# Terapkan transformasi jika ada
if self.transform:
x = self.transform(x)
return x, y
except Exception as e:
# Jika ada error (file rusak), coba index berikutnya
print(f"[Warning] Error pada index {index}: {e}")
# Coba index berikutnya (dengan wraparound)
next_index = (index + 1) % len(self.subset)
return self.__getitem__(next_index)
def __len__(self):
return len(self.subset)
# --- 3. Fungsi Utama Pembuat DataLoader ---
def create_dataloaders():
"""
Fungsi utama untuk membuat dan mengembalikan data loader
untuk training dan validasi.
"""
# --- VALIDASI: Pastikan config.DATA_PATH ada, coba beberapa alternatif jika tidak ---
data_path = Path(config.DATA_PATH)
if not data_path.exists():
project_root = Path(__file__).resolve().parents[1]
alt_names = ["Batik_Indonesia_JPG", "Batik-Indonesia", "Batik_Indonesia", "data", "dataset"]
found = None
for name in alt_names:
candidate = project_root / name
if candidate.exists() and candidate.is_dir():
found = candidate
break
if found:
print(f"[Data] config.DATA_PATH '{config.DATA_PATH}' tidak ditemukan. Menggunakan alternatif: {found}")
# update atribut di module config agar konsisten
try:
config.DATA_PATH = str(found)
except Exception:
pass
data_path = found
else:
raise FileNotFoundError(
f"config.DATA_PATH='{config.DATA_PATH}' tidak ditemukan. "
f"Pastikan folder dataset ada atau set config.DATA_PATH ke path yang benar."
)
# --- LANGKAH A: Muat Dataset Induk ---
print(f"[Data] Memuat dataset induk dari: {data_path}")
full_dataset = datasets.ImageFolder(str(data_path))
# Simpan nama kelas
class_names = full_dataset.classes
num_classes = len(class_names)
print(f"[Data] Ditemukan {num_classes} kelas: {class_names}")
# --- LANGKAH B: Bagi Dataset 80:20 (Secara Hati-hati) ---
print(f"[Data] Membagi dataset 80:20 (seed: {config.RANDOM_SEED})...")
total_size = len(full_dataset)
val_size = int(total_size * config.TEST_SPLIT_SIZE)
train_size = total_size - val_size
# Bagi dataset menggunakan random_split dengan SEED yang tetap
# Ini memastikan pembagian data SELALU SAMA setiap kali skrip dijalankan
train_dataset_raw, val_dataset_raw = random_split(
full_dataset,
[train_size, val_size],
generator=torch.Generator().manual_seed(config.RANDOM_SEED)
)
print(f"[Data] Ukuran Train: {len(train_dataset_raw)} | Ukuran Validasi: {len(val_dataset_raw)}")
# --- LANGKAH C: Terapkan Transformasi yang Berbeda ---
train_dataset = TransformedDataset(train_dataset_raw, transform=train_transform)
val_dataset = TransformedDataset(val_dataset_raw, transform=val_transform)
# --- LANGKAH D: Mengatasi Ketidakseimbangan Kelas (Wajib!) ---
print("[Data] Menghitung bobot untuk mengatasi ketidakseimbangan kelas...")
# 1. Ambil semua label (target) HANYA dari set training
train_targets = [full_dataset.targets[i] for i in train_dataset_raw.indices]
# 2. Hitung jumlah gambar per kelas
# Kita gunakan bincount untuk efisiensi
class_counts = np.bincount(train_targets)
# 3. Hitung bobot kebalikan (inverse weight) untuk setiap kelas
# Kelas langka -> bobot tinggi
# Kelas umum -> bobot rendah
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
# 4. Buat daftar bobot untuk SETIAP sampel di set training
# Setiap sampel akan memiliki bobot sesuai kelasnya
sample_weights = class_weights[train_targets]
# 5. Buat Sampler
# WeightedRandomSampler akan mengambil data berdasarkan bobot ini
train_sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True # Izinkan pengambilan sampel berulang (oversampling)
)
print("[Data] WeightedRandomSampler berhasil dibuat.")
# --- LANGKAH E: Buat DataLoaders ---
# DataLoader untuk Training
# PENTING: Jika menggunakan 'sampler', 'shuffle' HARUS False.
train_loader = DataLoader(
train_dataset,
batch_size=config.BATCH_SIZE,
sampler=train_sampler,
num_workers=2, # Disable multiprocessing untuk Windows
pin_memory=False, # Disable untuk CPU training
shuffle=False
)
# DataLoader untuk Validasi
# Tidak perlu sampler, tidak perlu shuffle (evaluasi harus konsisten)
val_loader = DataLoader(
val_dataset,
batch_size=config.BATCH_SIZE,
num_workers=2, # Disable multiprocessing untuk Windows
pin_memory=False, # Disable untuk CPU training
shuffle=False
)
print("[Data] Data loader untuk Train dan Validasi siap.")
return train_loader, val_loader, class_names
# --- 5. Blok Pengujian (Opsional tapi Sangat Direkomendasikan) ---
# Kode ini HANYA akan berjalan jika Anda menjalankan file ini secara langsung
# (misal: `python src/data_loader.py`)
# Ini sangat berguna untuk memverifikasi bahwa loader Anda berfungsi.
if __name__ == "__main__":
print("Menjalankan pengujian data_loader.py...")
# Coba buat data loader
train_loader, val_loader, class_names = create_dataloaders()
print(f"\nTotal kelas: {len(class_names)}")
# Ambil satu batch dari train_loader
print("\nMengambil 1 batch dari train_loader (untuk tes)...")
with warnings.catch_warnings():
warnings.simplefilter("ignore") # Abaikan peringatan UserWarning dari matplotlib
try:
images, labels = next(iter(train_loader))
print(f" > Ukuran batch gambar: {images.shape}") # [Batch, Channel, H, W]
print(f" > Ukuran batch label: {labels.shape}")
print(f" > Contoh 5 label di batch ini: {labels[:5]}")
# Coba visualisasikan 1 gambar (untuk cek normalisasi)
img_to_show = images[0].permute(1, 2, 0).numpy() # Ubah (C, H, W) -> (H, W, C)
# Denormalisasi (penting untuk visualisasi)
img_to_show = STD * img_to_show + MEAN
img_to_show = np.clip(img_to_show, 0, 1) # Pastikan nilai antara 0 dan 1
plt.imshow(img_to_show)
plt.title(f"Contoh Gambar (Label: {class_names[labels[0]]})")
plt.axis('off')
plt.show()
print("\n[Sukses] data_loader.py berfungsi dengan baik!")
except Exception as e:
print(f"\n[Error] Gagal menguji data loader: {e}")