""" Data-loading utilities for the WAID dataset. The WAID dataset ships pre-split into train/valid/test directories under both images/ and labels/. This module validates that structure and generates the ``dataset.yaml`` file required by Ultralytics for training. Usage: from src.config import load_config from src.data.dataset import validate_dataset, generate_dataset_yaml cfg = load_config() stats = validate_dataset(cfg) yaml_path = generate_dataset_yaml(cfg) """ from __future__ import annotations import logging from pathlib import Path from typing import Any import yaml from src.config import Config logger = logging.getLogger(__name__) _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} def validate_dataset(cfg: Config) -> dict[str, Any]: """Verify WAID dataset integrity and report per-split statistics. Checks that every split directory exists, counts images and labels, and flags orphans (images without labels or vice-versa). Args: cfg: Pipeline configuration. Returns: Dict with per-split counts and overall totals. Raises: FileNotFoundError: If the dataset root or expected directories are missing. """ root = Path(str(cfg.paths.dataset_root)) if not root.exists(): raise FileNotFoundError( f"Dataset root not found: {root}\n" "Clone the WAID repo or update paths.dataset_root in config." ) split_dirs = cfg.dataset.raw.get("split_dirs", {}) stats: dict[str, Any] = {"splits": {}, "total_images": 0, "total_labels": 0} for split_name, dir_name in split_dirs.items(): img_dir = root / "images" / dir_name lbl_dir = root / "labels" / dir_name if not img_dir.exists(): logger.warning("Missing image directory: %s", img_dir) continue if not lbl_dir.exists(): logger.warning("Missing label directory: %s", lbl_dir) continue img_stems = { p.stem for p in img_dir.iterdir() if p.is_file() and p.suffix.lower() in _IMAGE_EXTS } lbl_stems = { p.stem for p in lbl_dir.iterdir() if p.is_file() and p.suffix == ".txt" } split_stats = { "images": len(img_stems), "labels": len(lbl_stems), "images_without_labels": len(img_stems - lbl_stems), "labels_without_images": len(lbl_stems - img_stems), } stats["splits"][split_name] = split_stats stats["total_images"] += split_stats["images"] stats["total_labels"] += split_stats["labels"] logger.info( " %-6s — images: %4d labels: %4d orphan_img: %d orphan_lbl: %d", split_name, split_stats["images"], split_stats["labels"], split_stats["images_without_labels"], split_stats["labels_without_images"], ) logger.info( "Dataset totals — images: %d, labels: %d", stats["total_images"], stats["total_labels"], ) return stats def generate_dataset_yaml( cfg: Config, output_path: str | Path = "data/waid.yaml", ) -> Path: """Generate the ``dataset.yaml`` file required by Ultralytics for training. Points directly at the pre-split WAID directories. Args: cfg: Pipeline configuration. output_path: Where to write the YAML file. Returns: Absolute path to the generated YAML file. """ root = Path(str(cfg.paths.dataset_root)).resolve() split_dirs = cfg.dataset.raw.get("split_dirs", {}) ds_yaml: dict[str, Any] = { "path": str(root), "train": f"images/{split_dirs.get('train', 'train')}", "val": f"images/{split_dirs.get('val', 'valid')}", "test": f"images/{split_dirs.get('test', 'test')}", "nc": int(cfg.dataset.num_classes), "names": list(cfg.dataset.class_names), } out = Path(output_path) out.parent.mkdir(parents=True, exist_ok=True) with open(out, "w", encoding="utf-8") as fh: yaml.dump(ds_yaml, fh, default_flow_style=False, sort_keys=False) logger.info("Dataset YAML written to %s", out.resolve()) return out.resolve() def get_class_distribution(cfg: Config, split: str = "train") -> dict[str, int]: """Count per-class instances across all label files in a split. Reads every .txt label file and tallies class IDs. Args: cfg: Pipeline configuration. split: Which split to scan (``"train"``, ``"val"``, ``"test"``). Returns: Dict mapping class name → instance count. """ root = Path(str(cfg.paths.dataset_root)) split_dirs = cfg.dataset.raw.get("split_dirs", {}) dir_name = split_dirs.get(split, split) lbl_dir = root / "labels" / dir_name class_names = list(cfg.dataset.class_names) counts: dict[str, int] = {name: 0 for name in class_names} if not lbl_dir.exists(): logger.warning("Label directory not found: %s", lbl_dir) return counts for lbl_file in lbl_dir.iterdir(): if not lbl_file.suffix == ".txt": continue with open(lbl_file, "r", encoding="utf-8") as fh: for line in fh: parts = line.strip().split() if not parts: continue cls_id = int(parts[0]) if 0 <= cls_id < len(class_names): counts[class_names[cls_id]] += 1 logger.info("Class distribution (%s): %s", split, counts) return counts