File size: 4,174 Bytes
ff1ef32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""DataLoader'ы с агрессивной аугментацией под малый датасет."""
from __future__ import annotations
from pathlib import Path

import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import albumentations as A
from albumentations.pytorch import ToTensorV2

from . import config as C
from .prepare_data import imread_unicode


# Maps a class folder name to its integer label. The position in the tuple is
# the label id, and the mapping's insertion order is the order in which class
# folders are scanned.
CLASS_TO_IDX = {name: label for label, name in enumerate(("clean", "defect"))}


def build_transforms(train: bool) -> A.Compose:
    """Build the Albumentations pipeline for one dataset split.

    Args:
        train: If true, return the randomized training pipeline (resize and
            pad to ``C.IMG_SIZE + 32``, random crop to ``C.IMG_SIZE``, flips,
            colour/blur/noise jitter, shadows, flares, coarse dropout).
            Otherwise return the pipeline that only resizes and pads to
            ``C.IMG_SIZE``.

    Returns:
        An ``A.Compose`` whose last two steps are ``A.Normalize`` and
        ``ToTensorV2`` in both modes.
    """
    side = C.IMG_SIZE
    # Per-channel normalization constants shared by both pipelines.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    if not train:
        eval_steps = [
            A.LongestMaxSize(max_size=side),
            A.PadIfNeeded(min_height=side, min_width=side,
                          border_mode=cv2.BORDER_REFLECT_101),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ]
        return A.Compose(eval_steps)

    # Resize/pad to a slightly larger square so the random crop has room to move.
    padded_side = side + 32
    geometry = [
        A.LongestMaxSize(max_size=padded_side),
        A.PadIfNeeded(min_height=padded_side, min_width=padded_side,
                      border_mode=cv2.BORDER_REFLECT_101),
        A.RandomCrop(height=side, width=side),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]

    # One colour/contrast perturbation, applied with probability 0.7.
    colour_jitter = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=1.0),
            A.HueSaturationValue(hue_shift_limit=8, sat_shift_limit=20,
                                 val_shift_limit=20, p=1.0),
            A.CLAHE(clip_limit=2.0, p=1.0),
        ],
        p=0.7,
    )

    # One blur/noise degradation, applied with probability 0.4.
    degradation = A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
            A.MotionBlur(blur_limit=5, p=1.0),
            A.GaussNoise(var_limit=(5.0, 25.0), p=1.0),
        ],
        p=0.4,
    )

    # Simulate the glare and shadows seen on the real shop floor.
    lighting = [
        A.RandomShadow(p=0.2),
        A.RandomSunFlare(src_radius=80, num_flare_circles_lower=1,
                         num_flare_circles_upper=2, p=0.15),
    ]

    occlusion = A.CoarseDropout(max_holes=4, max_height=48, max_width=48, p=0.3)

    train_steps = [
        *geometry,
        colour_jitter,
        degradation,
        *lighting,
        occlusion,
        A.Normalize(mean=mean, std=std),
        ToTensorV2(),
    ]
    return A.Compose(train_steps)


class PatchDataset(Dataset):
    """Patch dataset laid out as ``<root>/<class>/*.jpg``; labels come from the folder name.

    Each item is ``(image, label)``, where ``image`` is the output of the
    transform pipeline returned by ``build_transforms(train)`` and ``label``
    is the integer from ``CLASS_TO_IDX``.
    """

    @staticmethod
    def _collect_samples(root: Path, class_to_idx: dict[str, int]) -> list[tuple[Path, int]]:
        """Return ``(path, label)`` for every ``*.jpg`` in each class folder.

        Classes are visited in the mapping's iteration order. Files inside a
        class are sorted, because ``Path.glob`` yields entries in arbitrary
        filesystem order and sample indices should be stable across runs and
        machines.
        """
        return [
            (path, label)
            for cls, label in class_to_idx.items()
            for path in sorted((root / cls).glob("*.jpg"))
        ]

    def __init__(self, root: Path, train: bool):
        """Index the patches under ``root`` and build the transform pipeline.

        Args:
            root: Directory containing one sub-folder per class name.
            train: Forwarded to ``build_transforms`` to pick the pipeline.

        Raises:
            RuntimeError: If no ``*.jpg`` file is found in any class folder.
        """
        # List of (image path, integer label) pairs in a deterministic order.
        self.samples: list[tuple[Path, int]] = self._collect_samples(root, CLASS_TO_IDX)
        if not self.samples:
            raise RuntimeError(f"Нет патчей в {root}. Запустите prepare_data.py")
        self.transform = build_transforms(train)

    def __len__(self) -> int:
        """Return the number of indexed patches."""
        return len(self.samples)

    def __getitem__(self, i: int):
        """Load, colour-convert and transform the ``i``-th patch.

        Returns:
            ``(image, label)`` with ``image`` taken from the ``"image"`` key of
            the transform output.

        Raises:
            RuntimeError: If the image loader returns ``None`` for the file.
        """
        path, label = self.samples[i]
        img = imread_unicode(path)
        # NOTE(review): imread_unicode presumably returns None for an
        # unreadable or corrupt file, as cv2 decoders do; confirm. Failing
        # here names the offending file instead of surfacing an opaque
        # OpenCV assertion from cvtColor.
        if img is None:
            raise RuntimeError(f"Не удалось прочитать изображение: {path}")
        # The image is treated as BGR and converted to RGB for the pipeline.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)["image"]
        return img, label


def make_loaders(batch_size: int = C.BATCH_SIZE, num_workers: int = C.NUM_WORKERS):
    """Create the training and validation ``DataLoader`` objects.

    The training loader draws samples with replacement through a
    ``WeightedRandomSampler`` whose per-sample weight is the inverse of the
    sample's class count. The validation loader iterates in dataset order.
    A one-line summary of dataset sizes and class counts is printed.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_ds = PatchDataset(C.DATA_PATCHES / "train", train=True)
    val_ds = PatchDataset(C.DATA_PATCHES / "val", train=False)

    # Class balancing via WeightedRandomSampler.
    labels = np.array([sample[1] for sample in train_ds.samples])
    class_counts = np.bincount(labels, minlength=2).astype(np.float32)
    # Clamp counts to at least 1 so an absent class does not divide by zero.
    inverse_freq = 1.0 / np.maximum(class_counts, 1.0)
    per_sample = inverse_freq[labels].tolist()
    balanced_sampler = WeightedRandomSampler(
        weights=per_sample,
        num_samples=len(per_sample),
        replacement=True,
    )

    # Settings common to both loaders.
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, sampler=balanced_sampler, drop_last=False, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, **shared)

    print(f"train: {len(train_ds)} (классы={class_counts.tolist()})  val: {len(val_ds)}")
    return train_loader, val_loader