| """DataLoader'ы с агрессивной аугментацией под малый датасет.""" |
| from __future__ import annotations |
| from pathlib import Path |
|
|
| import cv2 |
| import numpy as np |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
|
|
| from . import config as C |
| from .prepare_data import imread_unicode |
|
|
|
|
# Label mapping for the binary task: folder name -> integer class index.
CLASS_TO_IDX = {"clean": 0, "defect": 1}
|
|
|
|
def build_transforms(train: bool) -> A.Compose:
    """Build the albumentations pipeline.

    Train mode applies heavy geometric + photometric augmentation (the
    dataset is small); eval mode is deterministic resize/pad + normalize.
    """
    norm_and_tensor = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]

    if not train:
        return A.Compose([
            A.LongestMaxSize(max_size=C.IMG_SIZE),
            A.PadIfNeeded(min_height=C.IMG_SIZE, min_width=C.IMG_SIZE,
                          border_mode=cv2.BORDER_REFLECT_101),
            *norm_and_tensor,
        ])

    # Resize/pad slightly larger than the target, then random-crop back down
    # so the crop position itself acts as an augmentation.
    padded = C.IMG_SIZE + 32
    geometric = [
        A.LongestMaxSize(max_size=padded),
        A.PadIfNeeded(min_height=padded, min_width=padded,
                      border_mode=cv2.BORDER_REFLECT_101),
        A.RandomCrop(height=C.IMG_SIZE, width=C.IMG_SIZE),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
    photometric = [
        # Exactly one color/contrast tweak, 70% of the time.
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=1.0),
            A.HueSaturationValue(hue_shift_limit=8, sat_shift_limit=20, val_shift_limit=20, p=1.0),
            A.CLAHE(clip_limit=2.0, p=1.0),
        ], p=0.7),
        # Exactly one blur/noise corruption, 40% of the time.
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
            A.MotionBlur(blur_limit=5, p=1.0),
            A.GaussNoise(var_limit=(5.0, 25.0), p=1.0),
        ], p=0.4),
        A.RandomShadow(p=0.2),
        A.RandomSunFlare(src_radius=80, num_flare_circles_lower=1,
                         num_flare_circles_upper=2, p=0.15),
        A.CoarseDropout(max_holes=4, max_height=48, max_width=48, p=0.3),
    ]
    return A.Compose(geometric + photometric + norm_and_tensor)
|
|
|
|
class PatchDataset(Dataset):
    """Patch dataset over ``<root>/<class>/*.jpg``; label comes from the folder name.

    Folder names must match ``CLASS_TO_IDX`` ("clean" -> 0, "defect" -> 1).
    """

    def __init__(self, root: Path, train: bool):
        """Collect (path, label) pairs and build the transform pipeline.

        Raises:
            RuntimeError: if no patches are found under *root*.
        """
        self.samples: list[tuple[Path, int]] = []
        for cls, idx in CLASS_TO_IDX.items():
            # sorted(): Path.glob order is filesystem-dependent, so sorting
            # makes sample indices deterministic across machines/runs.
            for f in sorted((root / cls).glob("*.jpg")):
                self.samples.append((f, idx))
        if not self.samples:
            raise RuntimeError(f"Нет патчей в {root}. Запустите prepare_data.py")
        self.transform = build_transforms(train)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, i: int):
        """Return (transformed image tensor, int label) for sample *i*."""
        path, label = self.samples[i]
        # imread_unicode handles non-ASCII paths; OpenCV returns BGR,
        # convert to RGB before albumentations/Normalize.
        img = imread_unicode(path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)["image"]
        return img, label
|
|
|
|
def _make_balanced_sampler(labels: np.ndarray) -> WeightedRandomSampler:
    """Build a sampler that draws both classes with equal expected frequency.

    Each sample's weight is ``1 / count(its class)``, so minority-class
    patches are oversampled (with replacement) within every epoch.

    Args:
        labels: 1-D integer array of class indices (0/1).
    """
    class_counts = np.bincount(labels, minlength=2).astype(np.float32)
    # max(count, 1) guards against division by zero if a class is empty.
    class_weights = 1.0 / np.maximum(class_counts, 1.0)
    sample_weights = class_weights[labels]
    return WeightedRandomSampler(weights=sample_weights.tolist(),
                                 num_samples=len(sample_weights),
                                 replacement=True)


def make_loaders(batch_size: int = C.BATCH_SIZE, num_workers: int = C.NUM_WORKERS):
    """Create ``(train_loader, val_loader)`` over the prepared patch folders.

    The train loader uses class-balanced sampling with replacement (small,
    imbalanced dataset); the val loader iterates sequentially.
    """
    train_ds = PatchDataset(C.DATA_PATCHES / "train", train=True)
    val_ds = PatchDataset(C.DATA_PATCHES / "val", train=False)

    labels = np.array([lbl for _, lbl in train_ds.samples])
    sampler = _make_balanced_sampler(labels)

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    class_counts = np.bincount(labels, minlength=2).astype(np.float32)
    print(f"train: {len(train_ds)} (классы={class_counts.tolist()}) val: {len(val_ds)}")
    return train_loader, val_loader
|
|