"""Sanity-check the car-damage train/val splits and the DataLoader pipeline.

Loads the pre-split dataset from ``SPLIT_DIR``, builds train/val
``DataLoader``s that decode, resize and normalize images in the collate
step, prints the shapes of one training batch, and prints inverse-frequency
class weights computed from the training labels.
"""
from collections import Counter
from functools import partial

import torch
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torchvision import transforms

SPLIT_DIR = "data/splits/comprehensive-car-damage_seed42_test0p2"
BATCH_SIZE = 16
# 0 loads data in the main process (safest default, e.g. on Windows).
# Values > 0 also work: the collate function is a picklable module-level
# callable, as required when workers are started with "spawn".
NUM_WORKERS = 0
IMG_SIZE = 224

# Standard ImageNet per-channel mean/std used by transforms.Normalize below.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def compute_class_weights(labels, num_classes):
    """Return inverse-frequency class weights normalized to mean 1.

    Args:
        labels: Sized iterable of integer class ids (``len`` and iteration
            are both used).
        num_classes: Number of classes; weights are produced for ids
            ``0 .. num_classes - 1``.

    Returns:
        Float tensor of length ``num_classes``. The raw weight for class ``k``
        is ``total / count_k``; the vector is then divided by its mean.
        A class with no samples is given a pseudo-count of 1, i.e. the
        largest possible raw weight (``total``), instead of dividing by zero.

    Raises:
        ValueError: if ``labels`` is empty (previously a ZeroDivisionError).
    """
    total = len(labels)
    if total == 0:
        raise ValueError("labels must be non-empty to compute class weights")

    counts = Counter(labels)
    # NOTE: the pseudo-count for absent classes makes their weight large, which
    # also inflates the mean used for normalization below and therefore
    # shrinks the normalized weights of the classes that are present.
    weights = torch.tensor(
        [total / counts.get(k, 1) for k in range(num_classes)],
        dtype=torch.float,
    )
    return weights / weights.mean()


def _build_transforms():
    """Return ``(train_tf, val_tf)`` image transforms.

    Both resize to ``IMG_SIZE x IMG_SIZE``, convert to a tensor and apply
    ImageNet normalization; the training pipeline additionally applies a
    random horizontal flip (p=0.5) before tensor conversion.
    """
    normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
    train_tf = transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ]
    )
    val_tf = transforms.Compose(
        [
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            normalize,
        ]
    )
    return train_tf, val_tf


def _collate(batch, tf):
    """Collate a list of dataset rows into a model-ready batch dict.

    Each row must be a mapping with an ``"image"`` (an object with a
    ``.convert("RGB")`` method, presumably a PIL image from the HF Image
    feature) and an integer ``"label"``. Images are converted to RGB, passed
    through ``tf`` and stacked along a new leading batch dimension.

    Defined at module level (and bound with ``functools.partial``) so that it
    can be pickled when DataLoader workers use the "spawn" start method.

    Returns:
        ``{"pixel_values": stacked image tensor, "labels": int64 tensor}``.
    """
    pixel_values = torch.stack([tf(row["image"].convert("RGB")) for row in batch])
    labels = torch.tensor([row["label"] for row in batch], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}


def main():
    """Load the splits, build loaders, and print a one-batch sanity check."""
    splits = load_from_disk(SPLIT_DIR)
    train_ds = splits["train"]
    val_ds = splits["val"]

    label_names = train_ds.features["label"].names
    num_classes = len(label_names)
    print("Classes:", label_names)

    train_tf, val_tf = _build_transforms()

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=partial(_collate, tf=train_tf),
    )
    # Built to validate the val pipeline's construction; not iterated here.
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=partial(_collate, tf=val_tf),
    )

    # Sanity check: pull one training batch.
    batch = next(iter(train_loader))
    print("Batch keys:", batch.keys())
    print("pixel_values shape:", batch["pixel_values"].shape)  # (B, C, H, W)
    print("labels shape:", batch["labels"].shape)
    sample = batch["labels"][:8].tolist()
    print("labels sample:", sample)
    print("labels sample names:", [label_names[i] for i in sample])

    # Class weights from the training split.
    w = compute_class_weights(train_ds["label"], num_classes)
    print("Class weights:", {label_names[i]: float(w[i]) for i in range(num_classes)})


if __name__ == "__main__":
    main()