# vehicle-damage-classifier/src/step6_dataloaders.py
# Uploaded by efnanaladagg -- commit 6f6eb85 ("Clean push")
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_from_disk
from collections import Counter
# Directory of the saved dataset splits; main() passes it to load_from_disk().
SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
# Batch size used by both the train and validation DataLoaders.
BATCH_SIZE = 16
# Passed as num_workers to both DataLoaders.
NUM_WORKERS = 0 # For Windows compatibility; set higher for Linux/Mac
# Images are resized to IMG_SIZE x IMG_SIZE before being converted to tensors.
IMG_SIZE = 224
def compute_class_weights(labels, num_classes):
    """Return mean-normalized inverse-frequency weights, one per class.

    Args:
        labels: Iterable of integer class indices. Each element is converted
            with ``int``, so Python ints, NumPy integers and the 0-d tensors
            produced by iterating a 1-D tensor are all counted correctly.
        num_classes: Number of classes; one weight is produced for each index
            in ``range(num_classes)``.

    Returns:
        A 1-D ``torch.float`` tensor of length ``num_classes`` holding
        ``1 / frequency`` for every class, divided by the mean of those values.

    Raises:
        ValueError: If ``labels`` contains no elements.
    """
    # Plain ints guarantee that equal labels share a single Counter key;
    # tensor elements hash by identity and would each be counted separately.
    counts = Counter(int(label) for label in labels)
    total = sum(counts.values())
    if total == 0:
        raise ValueError("compute_class_weights() requires at least one label")

    # A class that never occurs is treated as if it had been seen once, which
    # keeps its frequency non-zero and avoids a division by zero.
    weights = [1.0 / (counts.get(k, 1) / total) for k in range(num_classes)]
    w = torch.tensor(weights, dtype=torch.float)

    # Normalize by the mean so the weights keep their ratios at a stable scale.
    return w / w.mean()
def main():
    """Build the train/val DataLoaders and print a few sanity checks.

    Loads the splits stored under ``SPLIT_DIR``, wraps the "train" and "val"
    splits in DataLoaders whose collate functions apply the image transforms,
    prints the contents of one training batch, and finally prints the class
    weights computed from the training labels.
    """
    dataset_splits = load_from_disk(SPLIT_DIR)
    train_ds = dataset_splits["train"]
    val_ds = dataset_splits["val"]

    label_names = train_ds.features["label"].names
    num_classes = len(label_names)
    print("Classes:", label_names)

    def build_pipeline(*augmentations):
        # Resize -> optional augmentations -> tensor -> per-channel normalize.
        return transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    # Only the training pipeline applies the random horizontal flip.
    train_tf = build_pipeline(transforms.RandomHorizontalFlip(p=0.5))
    val_tf = build_pipeline()

    def make_collate(pipeline):
        # Returns a collate_fn that turns a list of dataset rows (dicts with
        # "image" and "label") into one batch dict of stacked tensors.
        def collate(rows):
            pictures = [row["image"] for row in rows]
            targets = [row["label"] for row in rows]
            tensors = [pipeline(picture.convert("RGB")) for picture in pictures]
            target_tensor = torch.tensor(targets, dtype=torch.long)
            return {"pixel_values": torch.stack(tensors), "labels": target_tensor}

        return collate

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=make_collate(train_tf),
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=make_collate(val_tf),
    )

    # Sanity check: pull a single training batch and describe it.
    first_batch = next(iter(train_loader))
    print("Batch keys:", first_batch.keys())
    print("pixel_values shape:", first_batch["pixel_values"].shape)  # (B, C, H, W)
    print("labels shape:", first_batch["labels"].shape)
    preview = first_batch["labels"][:8].tolist()
    print("labels sample:", preview)
    print("labels sample names:", [label_names[idx] for idx in preview])

    # Class weights derived from the training labels.
    class_weights = compute_class_weights(train_ds["label"], num_classes)
    named_weights = {
        label_names[idx]: float(class_weights[idx]) for idx in range(num_classes)
    }
    print("Class weights:", named_weights)
# Run the sanity checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()