"""DataLoader'ы с агрессивной аугментацией под малый датасет.""" from __future__ import annotations from pathlib import Path import cv2 import numpy as np from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler import albumentations as A from albumentations.pytorch import ToTensorV2 from . import config as C from .prepare_data import imread_unicode CLASS_TO_IDX = {"clean": 0, "defect": 1} def build_transforms(train: bool) -> A.Compose: if train: return A.Compose([ A.LongestMaxSize(max_size=C.IMG_SIZE + 32), A.PadIfNeeded(min_height=C.IMG_SIZE + 32, min_width=C.IMG_SIZE + 32, border_mode=cv2.BORDER_REFLECT_101), A.RandomCrop(height=C.IMG_SIZE, width=C.IMG_SIZE), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5), A.OneOf([ A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=1.0), A.HueSaturationValue(hue_shift_limit=8, sat_shift_limit=20, val_shift_limit=20, p=1.0), A.CLAHE(clip_limit=2.0, p=1.0), ], p=0.7), A.OneOf([ A.GaussianBlur(blur_limit=(3, 5), p=1.0), A.MotionBlur(blur_limit=5, p=1.0), A.GaussNoise(var_limit=(5.0, 25.0), p=1.0), ], p=0.4), # имитируем блики/тени из реального цеха A.RandomShadow(p=0.2), A.RandomSunFlare(src_radius=80, num_flare_circles_lower=1, num_flare_circles_upper=2, p=0.15), A.CoarseDropout(max_holes=4, max_height=48, max_width=48, p=0.3), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ToTensorV2(), ]) return A.Compose([ A.LongestMaxSize(max_size=C.IMG_SIZE), A.PadIfNeeded(min_height=C.IMG_SIZE, min_width=C.IMG_SIZE, border_mode=cv2.BORDER_REFLECT_101), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ToTensorV2(), ]) class PatchDataset(Dataset): """Каталог: //*.jpg, метки по имени папки.""" def __init__(self, root: Path, train: bool): self.samples: list[tuple[Path, int]] = [] for cls, idx in CLASS_TO_IDX.items(): for f in (root / cls).glob("*.jpg"): self.samples.append((f, idx)) if not self.samples: raise RuntimeError(f"Нет патчей в {root}. Запустите prepare_data.py") self.transform = build_transforms(train) def __len__(self) -> int: return len(self.samples) def __getitem__(self, i: int): path, label = self.samples[i] img = imread_unicode(path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = self.transform(image=img)["image"] return img, label def make_loaders(batch_size: int = C.BATCH_SIZE, num_workers: int = C.NUM_WORKERS): train_ds = PatchDataset(C.DATA_PATCHES / "train", train=True) val_ds = PatchDataset(C.DATA_PATCHES / "val", train=False) # балансировка классов через WeightedRandomSampler labels = np.array([lbl for _, lbl in train_ds.samples]) class_counts = np.bincount(labels, minlength=2).astype(np.float32) class_weights = 1.0 / np.maximum(class_counts, 1.0) sample_weights = class_weights[labels] sampler = WeightedRandomSampler(weights=sample_weights.tolist(), num_samples=len(sample_weights), replacement=True) train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=True, drop_last=False) val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) print(f"train: {len(train_ds)} (классы={class_counts.tolist()}) val: {len(val_ds)}") return train_loader, val_loader