paint_defect_detector / src\dataset.py
therealestcoder's picture
Upload src\dataset.py with huggingface_hub
ff1ef32 verified
"""DataLoader'ы с агрессивной аугментацией под малый датасет."""
from __future__ import annotations
from pathlib import Path
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import albumentations as A
from albumentations.pytorch import ToTensorV2
from . import config as C
from .prepare_data import imread_unicode
# Class-folder name -> integer label. Insertion order is significant:
# PatchDataset iterates this mapping to scan the class folders in sequence.
CLASS_TO_IDX = dict(clean=0, defect=1)
def build_transforms(train: bool) -> A.Compose:
    """Return the albumentations pipeline for training or for evaluation.

    Both variants finish with ImageNet mean/std normalisation followed by
    ``ToTensorV2``.

    Args:
        train: ``True`` selects heavy augmentation for a small dataset:
            resize the longest side to ``IMG_SIZE + 32``, reflect-pad to a
            square, take a random ``IMG_SIZE`` crop, then flips/rot90,
            photometric jitter, blur/noise, simulated shadow/glare and
            coarse dropout. ``False`` selects the deterministic pipeline:
            resize the longest side to ``IMG_SIZE`` and reflect-pad to an
            ``IMG_SIZE`` square.

    Returns:
        The composed ``A.Compose`` pipeline.
    """
    size = C.IMG_SIZE
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    if not train:
        eval_steps = [
            A.LongestMaxSize(max_size=size),
            A.PadIfNeeded(min_height=size, min_width=size,
                          border_mode=cv2.BORDER_REFLECT_101),
            A.Normalize(mean=imagenet_mean, std=imagenet_std),
            ToTensorV2(),
        ]
        return A.Compose(eval_steps)

    # Oversize the canvas by 32 px so RandomCrop can shift the framing.
    canvas = size + 32
    train_steps = [
        A.LongestMaxSize(max_size=canvas),
        A.PadIfNeeded(min_height=canvas, min_width=canvas,
                      border_mode=cv2.BORDER_REFLECT_101),
        A.RandomCrop(height=size, width=size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # Photometric jitter: one of these is applied 70% of the time.
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=1.0),
            A.HueSaturationValue(hue_shift_limit=8, sat_shift_limit=20, val_shift_limit=20, p=1.0),
            A.CLAHE(clip_limit=2.0, p=1.0),
        ], p=0.7),
        # Blur or noise: one of these is applied 40% of the time.
        # NOTE(review): `var_limit` (GaussNoise), `max_holes/max_height/max_width`
        # (CoarseDropout) and `num_flare_circles_lower/upper` (RandomSunFlare)
        # are the pre-2.0 albumentations argument names; newer releases use
        # *_range arguments instead — confirm against the pinned version.
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 5), p=1.0),
            A.MotionBlur(blur_limit=5, p=1.0),
            A.GaussNoise(var_limit=(5.0, 25.0), p=1.0),
        ], p=0.4),
        # Simulate glare and shadows seen on the real shop floor.
        A.RandomShadow(p=0.2),
        A.RandomSunFlare(src_radius=80, num_flare_circles_lower=1,
                         num_flare_circles_upper=2, p=0.15),
        A.CoarseDropout(max_holes=4, max_height=48, max_width=48, p=0.3),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(train_steps)
class PatchDataset(Dataset):
    """Image-folder dataset of patches laid out as ``<root>/<class>/*.jpg``.

    The label of each file is taken from the name of its class folder
    (``CLASS_TO_IDX`` by default).
    """

    def __init__(self, root: Path, train: bool, *,
                 class_to_idx: dict[str, int] | None = None,
                 transform=None):
        """Scan ``root`` for patches and set up the transform pipeline.

        Args:
            root: directory containing one sub-folder per class; a ``str``
                is accepted too and converted to ``Path``.
            train: selects the training pipeline from ``build_transforms``
                when ``transform`` is not given.
            class_to_idx: folder-name -> label mapping; defaults to the
                module-level ``CLASS_TO_IDX``.
            transform: callable invoked as ``transform(image=img)["image"]``;
                defaults to ``build_transforms(train)``.

        Raises:
            RuntimeError: if no ``*.jpg`` file is found in any class folder.
        """
        root = Path(root)
        mapping = CLASS_TO_IDX if class_to_idx is None else class_to_idx
        self.samples: list[tuple[Path, int]] = []
        for cls, idx in mapping.items():
            # glob() yields files in filesystem order; sorting keeps sample
            # indices (and therefore seeded sampling) reproducible.
            for f in sorted((root / cls).glob("*.jpg")):
                self.samples.append((f, idx))
        if not self.samples:
            raise RuntimeError(f"Нет патчей в {root}. Запустите prepare_data.py")
        self.transform = build_transforms(train) if transform is None else transform

    def __len__(self) -> int:
        """Return the number of discovered patch files."""
        return len(self.samples)

    def __getitem__(self, i: int):
        """Load sample ``i`` and return ``(transformed_image, label)``.

        Raises:
            RuntimeError: if the image file cannot be read.
        """
        path, label = self.samples[i]
        img = imread_unicode(path)
        if img is None:
            # NOTE(review): assumes imread_unicode returns None for unreadable
            # or corrupt files (cv2.imdecode-style); without this guard the
            # failure surfaces as an opaque cv2.error from cvtColor below.
            raise RuntimeError(f"Не удалось прочитать изображение: {path}")
        # The loader's output is treated as BGR (OpenCV order); convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(image=img)["image"]
        return img, label
def make_loaders(batch_size: int = C.BATCH_SIZE, num_workers: int = C.NUM_WORKERS):
    """Build the class-balanced train loader and the validation loader.

    Train samples are drawn through a ``WeightedRandomSampler`` whose
    per-sample weight is ``1 / count(class)``. Every class present in the
    training set therefore carries the same total weight and is drawn with
    equal probability. Sampling is with replacement, and one epoch is
    ``len(train_ds)`` draws. The validation loader is not shuffled.

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_ds = PatchDataset(C.DATA_PATCHES / "train", train=True)
    val_ds = PatchDataset(C.DATA_PATCHES / "val", train=False)
    # Class balancing via WeightedRandomSampler (inverse class frequency).
    labels = np.array([lbl for _, lbl in train_ds.samples])
    # Size the histogram from the label map rather than a hard-coded 2, so
    # the counts (and the printed summary) always cover every known class.
    class_counts = np.bincount(labels, minlength=len(CLASS_TO_IDX)).astype(np.float32)
    # max(count, 1) only avoids division by zero; an absent class has no
    # samples, so its weight is never used.
    class_weights = 1.0 / np.maximum(class_counts, 1.0)
    sample_weights = class_weights[labels]
    sampler = WeightedRandomSampler(weights=sample_weights.tolist(),
                                    num_samples=len(sample_weights),
                                    replacement=True)
    # NOTE(review): drop_last=False can leave a final batch of size 1, which
    # BatchNorm1d rejects in train mode — confirm the model tolerates it.
    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    print(f"train: {len(train_ds)} (классы={class_counts.tolist()}) val: {len(val_ds)}")
    return train_loader, val_loader