|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
from torchvision import transforms |
|
|
from datasets import load_from_disk |
|
|
from collections import Counter |
|
|
|
|
|
# Directory passed to `datasets.load_from_disk` in `main()`; it must contain
# "train" and "val" splits.
SPLIT_DIR: str = "data/splits/comprehensive-car-damage_seed42_test0p2"

# DataLoader settings. NUM_WORKERS = 0 means batches are loaded in the main
# process (no worker subprocesses).
BATCH_SIZE: int = 16
NUM_WORKERS: int = 0

# Images are resized to IMG_SIZE x IMG_SIZE pixels before tensor conversion.
IMG_SIZE: int = 224
|
|
|
|
|
def compute_class_weights(labels, num_classes):
    """Return inverse-frequency class weights normalised to mean 1.

    Each class ``k`` in ``range(num_classes)`` gets a weight proportional to
    ``1 / count(k)``. The resulting vector is divided by its mean, so the
    weights average to 1.0.

    Args:
        labels: iterable of integer class ids. Plain ints, numpy integer
            scalars and elements of a 1-D torch tensor are all accepted;
            every element is coerced with ``int()``.
        num_classes: number of classes; the output has this many entries.

    Returns:
        ``torch.float32`` tensor of shape ``(num_classes,)``.

    Raises:
        ValueError: if ``labels`` is empty.

    Notes:
        * A class that never occurs is treated as having count 1 (instead of
          dividing by zero), so it receives the largest weight.
        * Labels outside ``[0, num_classes)`` are not assigned to any class.
          They only affect the shared ``total``, which cancels out in the
          mean normalisation.
    """
    # Coerce to plain ints: 0-d torch tensors hash by identity, so counting
    # them directly would never match the integer keys looked up below.
    int_labels = [int(y) for y in labels]
    if not int_labels:
        raise ValueError("compute_class_weights: `labels` is empty")

    counts = Counter(int_labels)
    total = len(int_labels)

    # 1 / (count / total) == total / count; missing classes are floored at 1.
    inv_freq = [total / counts.get(k, 1) for k in range(num_classes)]

    w = torch.tensor(inv_freq, dtype=torch.float)
    return w / w.mean()
|
|
|
|
|
def _build_transforms(img_size):
    """Return ``(train_tf, val_tf)`` torchvision pipelines for square ``img_size`` inputs.

    Both pipelines resize to ``(img_size, img_size)``, convert to a tensor and
    normalise with the commonly used ImageNet channel mean/std. The training
    pipeline additionally applies a random horizontal flip (p=0.5) before the
    tensor conversion.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_tf = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        normalize,
    ])
    val_tf = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_tf, val_tf


def _collate(batch, tf):
    """Collate dataset rows into a model batch, applying ``tf`` to every image.

    This is a module-level function rather than a closure inside ``main`` so
    that ``functools.partial(_collate, tf=...)`` can be pickled. DataLoader
    needs that when it uses worker processes (``num_workers > 0``) under the
    spawn start method.

    Args:
        batch: list of row dicts with an ``"image"`` entry (anything with a
            ``.convert("RGB")`` method, presumably a PIL image) and a
            ``"label"`` entry (integer class id).
        tf: callable mapping an RGB image to a ``(C, H, W)`` tensor.

    Returns:
        ``{"pixel_values": stacked image tensor, "labels": int64 tensor}``.
    """
    pixel_values = torch.stack([tf(row["image"].convert("RGB")) for row in batch])
    labels = torch.tensor([row["label"] for row in batch], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}


def main():
    """Smoke-test the saved data splits and input pipeline.

    Loads the splits stored at ``SPLIT_DIR``, which must provide "train" and
    "val" entries, and prints the label feature's class names. It then builds
    train/val DataLoaders, pulls one training batch and prints its keys,
    shapes and first labels. Finally it prints the inverse-frequency class
    weights computed from the training labels.
    """
    from functools import partial  # local import keeps the module header untouched

    splits = load_from_disk(SPLIT_DIR)
    train_ds = splits["train"]
    val_ds = splits["val"]

    label_names = train_ds.features["label"].names
    num_classes = len(label_names)
    print("Classes:", label_names)

    train_tf, val_tf = _build_transforms(IMG_SIZE)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS,
                              collate_fn=partial(_collate, tf=train_tf))
    # Built alongside the train loader to validate its configuration; it is
    # not iterated in this smoke test.
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,  # noqa: F841
                            num_workers=NUM_WORKERS,
                            collate_fn=partial(_collate, tf=val_tf))

    # Pull a single training batch to sanity-check the collated output.
    batch = next(iter(train_loader))
    sample_ids = batch["labels"][:8].tolist()
    print("Batch keys:", batch.keys())
    print("pixel_values shape:", batch["pixel_values"].shape)
    print("labels shape:", batch["labels"].shape)
    print("labels sample:", sample_ids)
    print("labels sample names:", [label_names[i] for i in sample_ids])

    w = compute_class_weights(train_ds["label"], num_classes)
    print("Class weights:", {label_names[i]: float(w[i]) for i in range(num_classes)})
|
|
# Run the smoke check only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|
|
|