therealestcoder commited on
Commit
ff1ef32
·
verified ·
1 Parent(s): 3771fa4

Upload src\dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src//dataset.py +96 -0
src//dataset.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DataLoader'ы с агрессивной аугментацией под малый датасет."""
2
+ from __future__ import annotations
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
8
+ import albumentations as A
9
+ from albumentations.pytorch import ToTensorV2
10
+
11
+ from . import config as C
12
+ from .prepare_data import imread_unicode
13
+
14
+
15
+ CLASS_TO_IDX = {"clean": 0, "defect": 1}
16
+
17
+
18
+ def build_transforms(train: bool) -> A.Compose:
19
+ if train:
20
+ return A.Compose([
21
+ A.LongestMaxSize(max_size=C.IMG_SIZE + 32),
22
+ A.PadIfNeeded(min_height=C.IMG_SIZE + 32, min_width=C.IMG_SIZE + 32,
23
+ border_mode=cv2.BORDER_REFLECT_101),
24
+ A.RandomCrop(height=C.IMG_SIZE, width=C.IMG_SIZE),
25
+ A.HorizontalFlip(p=0.5),
26
+ A.VerticalFlip(p=0.5),
27
+ A.RandomRotate90(p=0.5),
28
+ A.OneOf([
29
+ A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=1.0),
30
+ A.HueSaturationValue(hue_shift_limit=8, sat_shift_limit=20, val_shift_limit=20, p=1.0),
31
+ A.CLAHE(clip_limit=2.0, p=1.0),
32
+ ], p=0.7),
33
+ A.OneOf([
34
+ A.GaussianBlur(blur_limit=(3, 5), p=1.0),
35
+ A.MotionBlur(blur_limit=5, p=1.0),
36
+ A.GaussNoise(var_limit=(5.0, 25.0), p=1.0),
37
+ ], p=0.4),
38
+ # имитируем блики/тени из реального цеха
39
+ A.RandomShadow(p=0.2),
40
+ A.RandomSunFlare(src_radius=80, num_flare_circles_lower=1,
41
+ num_flare_circles_upper=2, p=0.15),
42
+ A.CoarseDropout(max_holes=4, max_height=48, max_width=48, p=0.3),
43
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
44
+ ToTensorV2(),
45
+ ])
46
+ return A.Compose([
47
+ A.LongestMaxSize(max_size=C.IMG_SIZE),
48
+ A.PadIfNeeded(min_height=C.IMG_SIZE, min_width=C.IMG_SIZE,
49
+ border_mode=cv2.BORDER_REFLECT_101),
50
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
51
+ ToTensorV2(),
52
+ ])
53
+
54
+
55
+ class PatchDataset(Dataset):
56
+ """Каталог: <root>/<class>/*.jpg, метки по имени папки."""
57
+
58
+ def __init__(self, root: Path, train: bool):
59
+ self.samples: list[tuple[Path, int]] = []
60
+ for cls, idx in CLASS_TO_IDX.items():
61
+ for f in (root / cls).glob("*.jpg"):
62
+ self.samples.append((f, idx))
63
+ if not self.samples:
64
+ raise RuntimeError(f"Нет патчей в {root}. Запустите prepare_data.py")
65
+ self.transform = build_transforms(train)
66
+
67
+ def __len__(self) -> int:
68
+ return len(self.samples)
69
+
70
+ def __getitem__(self, i: int):
71
+ path, label = self.samples[i]
72
+ img = imread_unicode(path)
73
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
74
+ img = self.transform(image=img)["image"]
75
+ return img, label
76
+
77
+
78
+ def make_loaders(batch_size: int = C.BATCH_SIZE, num_workers: int = C.NUM_WORKERS):
79
+ train_ds = PatchDataset(C.DATA_PATCHES / "train", train=True)
80
+ val_ds = PatchDataset(C.DATA_PATCHES / "val", train=False)
81
+
82
+ # балансировка классов через WeightedRandomSampler
83
+ labels = np.array([lbl for _, lbl in train_ds.samples])
84
+ class_counts = np.bincount(labels, minlength=2).astype(np.float32)
85
+ class_weights = 1.0 / np.maximum(class_counts, 1.0)
86
+ sample_weights = class_weights[labels]
87
+ sampler = WeightedRandomSampler(weights=sample_weights.tolist(),
88
+ num_samples=len(sample_weights),
89
+ replacement=True)
90
+
91
+ train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
92
+ num_workers=num_workers, pin_memory=True, drop_last=False)
93
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
94
+ num_workers=num_workers, pin_memory=True)
95
+ print(f"train: {len(train_ds)} (классы={class_counts.tolist()}) val: {len(val_ds)}")
96
+ return train_loader, val_loader