| | |
| | import contextlib |
| | from itertools import repeat |
| | from multiprocessing.pool import ThreadPool |
| | from pathlib import Path |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | import torchvision |
| | from PIL import Image |
| |
|
| | from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable |
| | from ultralytics.utils.ops import resample_segments |
| | from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms |
| | from .base import BaseDataset |
| | from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label |
| |
|
| | |
# Version tag written into every saved dataset cache (see save_dataset_cache_file) and compared
# against the stored value when a cache is loaded; a mismatch causes the cache to be rebuilt.
DATASET_CACHE_VERSION = "1.0.3"
| |
|
| |
|
class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """
        Initialize the dataset and derive the task flags from `task`.

        Args:
            *args: Positional arguments forwarded unchanged to the base class initializer.
            data (dict, optional): Dataset YAML dictionary; `cache_labels` reads its 'names' and 'kpt_shape' keys.
            task (str): One of the task names; 'segment', 'pose' and 'obb' switch on the matching flag.
            **kwargs: Keyword arguments forwarded unchanged to the base class initializer.
        """
        # The flags are assigned before super().__init__().
        # NOTE(review): presumably the base initializer already calls get_labels()/build_transforms(),
        # which read these attributes - confirm against BaseDataset.
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Every (image, label) pair is passed to `verify_image_label` on a thread pool; the per-file results are
        accumulated into a dictionary which is then written to `path` via `save_dataset_cache_file`.

        Args:
            path (Path): Path where to save the cache file. Default is Path('./labels.cache').

        Returns:
            (dict): Cache dictionary with the keys 'labels', 'hash', 'results' and 'msgs' (plus 'version', which
                is added by `save_dataset_cache_file`).

        Raises:
            ValueError: If the task is pose and 'kpt_shape' is missing from the data dictionary or is invalid.
        """
        x = {"labels": []}
        # Running totals reported by verify_image_label; shown as "{nf} images, {nm + ne} backgrounds, {nc} corrupt".
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:  # entries with a falsy im_file are counted above but not stored
                    x["labels"].append(
                        dict(
                            im_file=im_file,
                            shape=shape,
                            cls=lb[:, 0:1],  # first column, kept 2-D
                            bboxes=lb[:, 1:],  # remaining columns
                            segments=segments,
                            keypoints=keypoint,
                            normalized=True,
                            bbox_format="xywh",
                        )
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs
        save_dataset_cache_file(self.prefix, path, x)
        return x

    def get_labels(self):
        """
        Return the list of label dictionaries for YOLO training, using the on-disk cache when it is valid.

        The cache is reused only if its stored version equals DATASET_CACHE_VERSION and its stored hash equals the
        hash of the current label and image file lists; otherwise it is rebuilt with `cache_labels`. As a side
        effect `self.label_files` is set and `self.im_files` is replaced by the image files present in the cache.

        Returns:
            (list): One dictionary per image, as produced by `cache_labels`. May be empty, in which case warnings
                are logged instead of raising.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache = load_dataset_cache_file(cache_path)
            # Explicit comparisons instead of `assert`, so stale caches are still rejected under `python -O`.
            exists = cache["version"] == DATASET_CACHE_VERSION and cache["hash"] == get_hash(
                self.label_files + self.im_files
            )
        except (FileNotFoundError, AssertionError, AttributeError):
            exists = False
        if not exists:
            cache = self.cache_labels(cache_path)

        # Report the cached scan results (a fresh scan has already shown its own progress bar).
        nf, nm, ne, nc, n = cache.pop("results")
        if exists and LOCAL_RANK in (-1, 0):
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))

        # Drop bookkeeping entries so only the labels remain in the cache dictionary.
        for k in ("hash", "version", "msgs"):
            cache.pop(k)
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]

        # Total class, box and segment counts over the dataset. Accumulating in a loop (rather than unpacking
        # zip(*lengths)) keeps this well-defined when `labels` is empty: all three totals are then 0.
        len_cls = len_boxes = len_segments = 0
        for lb in labels:
            len_cls += len(lb["cls"])
            len_boxes += len(lb["bboxes"])
            len_segments += len(lb["segments"])
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp=None):
        """
        Build the transform pipeline and append the final `Format` step.

        Args:
            hyp: Hyperparameter object. Although the default is None, its `mask_ratio` and `overlap_mask`
                attributes are always read (and `mosaic`, `mixup`, `bgr` when augmenting), so a real object
                must be supplied. Note that `hyp.mosaic` and `hyp.mixup` are overwritten in place when augmenting.

        Returns:
            The composed transforms: `v8_transforms` when `self.augment` is set, otherwise a single `LetterBox`,
            followed in both cases by `Format`.
        """
        if self.augment:
            # Mosaic and mixup are switched off for rectangular batches.
            hyp.mosaic = hyp.mosaic if not self.rect else 0.0
            hyp.mixup = hyp.mixup if not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only applied when augmenting
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """
        Set mosaic, copy_paste and mixup options to 0.0 on `hyp` and rebuild `self.transforms`.

        Args:
            hyp: Hyperparameter object; modified in place and then passed to `build_transforms`.
        """
        hyp.mosaic = 0.0
        hyp.copy_paste = 0.0
        hyp.mixup = 0.0
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Replaces the 'bboxes', 'segments', 'keypoints', 'bbox_format' and 'normalized' entries of `label` with a
        single 'instances' entry holding an `Instances` object built from them.

        Args:
            label (dict): Label dictionary; must contain 'bboxes', 'bbox_format' and 'normalized'.

        Returns:
            (dict): The same dictionary, modified in place.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # Number of points requested per segment: fewer for OBB than for the other tasks.
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # Each segment is resampled to `segment_resamples` points so the list can be stacked along a new
            # leading axis into one array.
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            # No segments: use an empty array with the same trailing dimensions.
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """
        Collate a list of per-sample dictionaries into one batch dictionary.

        Args:
            batch (list): Sample dictionaries; the keys of the first sample define the keys of the batch, and a
                'batch_idx' entry is required.

        Returns:
            (dict): 'img' values are stacked along a new first axis; 'masks', 'keypoints', 'bboxes', 'cls',
                'segments', 'obb' and 'batch_idx' values are concatenated along axis 0; all other keys hold a
                tuple of the per-sample values.
        """
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)
            if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
                value = torch.cat(value, 0)
            new_batch[k] = value
        # Offset each sample's batch_idx entry by its position in the batch so rows stay attributable to
        # their image after concatenation.
        batch_idx = list(new_batch["batch_idx"])
        for i in range(len(batch_idx)):
            batch_idx[i] += i
        new_batch["batch_idx"] = torch.cat(batch_idx, 0)
        return new_batch
| |
|
| |
|
| | |
class ClassificationDataset(torchvision.datasets.ImageFolder):
    """
    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
    learning models, with optional image transformations and caching mechanisms to speed up training.

    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
    to ensure data integrity and consistency.

    Attributes:
        cache_ram (bool): Indicates if caching in RAM is enabled.
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
                        file (if caching on disk), and optionally the loaded image array (if caching in RAM).
        torch_transforms (callable): PyTorch transforms to be applied to the images.
    """

    def __init__(self, root, args, augment=False, prefix=""):
        """
        Initialize YOLO object with root, image size, augmentations, and cache settings.

        Args:
            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
                debugging. Default is an empty string.
        """
        super().__init__(root=root)
        # The dataset is truncated to `fraction` only when augmenting.
        if augment and args.fraction < 1.0:
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = args.cache is True or args.cache == "ram"
        self.cache_disk = args.cache == "disk"
        self.samples = self.verify_images()
        # Extend each sample to [file, class index, .npy cache path, cached image (initially None)].
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]
        scale = (1.0 - args.scale, 1.0)
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
        )

    def __getitem__(self, i):
        """
        Return the transformed image and class index for sample `i`.

        Args:
            i (int): Index into `self.samples`.

        Returns:
            (dict): {'img': transformed image, 'cls': class index}.
        """
        f, j, fn, im = self.samples[i]
        if self.cache_ram:
            # Two separate `if` statements are required here: merging them into
            # `if self.cache_ram and im is None` sends an already-cached sample to the final `else`
            # branch, which re-reads the file from disk and defeats the RAM cache.
            if im is None:
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # write the .npy copy on first access
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:
            im = cv2.imread(f)
        # Convert the BGR NumPy array to an RGB PIL image before applying the transforms.
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """
        Verify all images in the dataset, reusing the on-disk cache when it is valid.

        The cache file is `<root>.cache`. It is reused only if its stored version equals DATASET_CACHE_VERSION and
        its stored hash equals the hash of the current sample file paths; otherwise every sample is passed to
        `verify_image` on a thread pool and a new cache is saved.

        Returns:
            (list): The samples for which `verify_image` reported the image as found.
        """
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")

        # Try the cache first; a missing or unreadable cache falls through to a full scan.
        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            cache = load_dataset_cache_file(path)
            # Explicit comparisons instead of `assert`, so stale caches are still rejected under `python -O`.
            if cache["version"] == DATASET_CACHE_VERSION and cache["hash"] == get_hash(
                [s[0] for s in self.samples]
            ):
                nf, nc, n, samples = cache.pop("results")
                if LOCAL_RANK in (-1, 0):
                    d = f"{desc} {nf} images, {nc} corrupt"
                    TQDM(None, desc=d, total=n, initial=n)
                    if cache["msgs"]:
                        LOGGER.info("\n".join(cache["msgs"]))
                return samples

        # Full scan.
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            pbar = TQDM(results, desc=desc, total=len(self.samples))
            for sample, nf_f, nc_f, msg in pbar:
                if nf_f:
                    samples.append(sample)
                if msg:
                    msgs.append(msg)
                nf += nf_f
                nc += nc_f
                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        # The hash covers all scanned sample paths, not only the verified ones, so it matches the check above.
        x["hash"] = get_hash([s[0] for s in self.samples])
        x["results"] = nf, nc, len(samples), samples
        x["msgs"] = msgs
        save_dataset_cache_file(self.prefix, path, x)
        return samples
| |
|
| |
|
def load_dataset_cache_file(path):
    """
    Load an Ultralytics *.cache dictionary from path.

    The garbage collector is switched off while the file is unpickled and restored to its previous state
    afterwards, including when loading fails.

    Args:
        path (str | Path): Location of the cache file written by `np.save`.

    Returns:
        (dict): The cached dictionary.

    Raises:
        FileNotFoundError: If no file exists at `path` (raised by `np.load`).
    """
    import gc

    # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects; only load cache files from trusted locations.
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        cache = np.load(str(path), allow_pickle=True).item()
    finally:
        # Without `finally`, a missing cache file (the normal first-run case) would leave the collector
        # disabled for the rest of the process. Only re-enable it if it was enabled on entry.
        if gc_was_enabled:
            gc.enable()
    return cache
| |
|
| |
|
def save_dataset_cache_file(prefix, path, x):
    """
    Write the dataset cache dictionary `x` to `path`, stamping it with the current cache version.

    Args:
        prefix (str): Text prepended to the log messages.
        path (Path): Destination of the cache file.
        x (dict): Cache dictionary; its 'version' key is set in place before saving.
    """
    x["version"] = DATASET_CACHE_VERSION
    if not is_dir_writeable(path.parent):
        LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.")
        return

    # Remove any previous cache file before writing the new one.
    if path.exists():
        path.unlink()
    np.save(str(path), x)
    # np.save appends '.npy' to the file name, so rename the result back to the requested path.
    saved_file = path.with_suffix(".cache.npy")
    saved_file.rename(path)
    LOGGER.info(f"{prefix}New cache created: {path}")
| |
|
| |
|
| | |
class SemanticDataset(BaseDataset):
    """
    Dataset for semantic segmentation tasks.

    Subclass of BaseDataset intended to hold the dataset handling for semantic segmentation.

    Note:
        This class is currently a placeholder: it adds no methods or attributes of its own yet and still needs
        to be populated to support semantic segmentation tasks.
    """

    def __init__(self):
        """Create a SemanticDataset by calling the base class initializer with no arguments."""
        super(SemanticDataset, self).__init__()
| |
|