# Prometheus-prototype/src/data/dataset.py
# Tadiwa-M — Deploy: auto-derive dedup radius (drop hardcoded 3m bypass) (58aefd4)
# [web-page residue: Raw | History | Blame | Contribute | Delete | 5.61 kB]
"""
Data-loading utilities for the WAID dataset.

The WAID dataset ships pre-split into train/valid/test directories under
both images/ and labels/. This module validates that structure and generates
the ``dataset.yaml`` file required by Ultralytics for training.

Usage:
    from src.config import load_config
    from src.data.dataset import validate_dataset, generate_dataset_yaml

    cfg = load_config()
    stats = validate_dataset(cfg)
    yaml_path = generate_dataset_yaml(cfg)
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
import yaml
from src.config import Config
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# File suffixes that validate_dataset() counts as images. Entries are
# lower-case because suffixes are lower-cased before the membership test.
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
def validate_dataset(cfg: Config) -> dict[str, Any]:
    """Verify WAID dataset integrity and report per-split statistics.

    For every entry of the configured ``split_dirs`` mapping, the matching
    ``images/<dir>`` and ``labels/<dir>`` directories are scanned. Images and
    labels are counted, and orphans (images without labels or vice-versa)
    are found by comparing file stems.

    Args:
        cfg: Pipeline configuration. Reads ``cfg.paths.dataset_root`` and
            ``cfg.dataset.raw["split_dirs"]`` (split name -> directory name).

    Returns:
        Dict with three keys: ``"splits"`` maps each validated split name to
        its ``images`` / ``labels`` / ``images_without_labels`` /
        ``labels_without_images`` counts; ``"total_images"`` and
        ``"total_labels"`` hold the sums over those splits. A split whose
        image or label directory is missing (or is not a directory) is
        logged as a warning and left out of the result.

    Raises:
        FileNotFoundError: If the dataset root does not exist.
    """
    root = Path(str(cfg.paths.dataset_root))
    if not root.exists():
        raise FileNotFoundError(
            f"Dataset root not found: {root}\n"
            "Clone the WAID repo or update paths.dataset_root in config."
        )

    # ``or {}`` covers a ``split_dirs`` key that is present but null.
    split_dirs = cfg.dataset.raw.get("split_dirs", {}) or {}
    if not split_dirs:
        # Without this the caller would just see all-zero totals.
        logger.warning("No split_dirs configured; nothing to validate.")

    stats: dict[str, Any] = {"splits": {}, "total_images": 0, "total_labels": 0}

    for split_name, dir_name in split_dirs.items():
        img_dir = root / "images" / dir_name
        lbl_dir = root / "labels" / dir_name

        # is_dir() rather than exists(): a stray regular file at this path
        # would otherwise make iterdir() raise NotADirectoryError.
        if not img_dir.is_dir():
            logger.warning("Missing image directory: %s", img_dir)
            continue
        if not lbl_dir.is_dir():
            logger.warning("Missing label directory: %s", lbl_dir)
            continue

        # Stems (file names without suffix) pair an image with its label.
        img_stems = {
            p.stem
            for p in img_dir.iterdir()
            if p.is_file() and p.suffix.lower() in _IMAGE_EXTS
        }
        lbl_stems = {
            p.stem
            for p in lbl_dir.iterdir()
            if p.is_file() and p.suffix == ".txt"
        }

        split_stats = {
            "images": len(img_stems),
            "labels": len(lbl_stems),
            "images_without_labels": len(img_stems - lbl_stems),
            "labels_without_images": len(lbl_stems - img_stems),
        }
        stats["splits"][split_name] = split_stats
        stats["total_images"] += split_stats["images"]
        stats["total_labels"] += split_stats["labels"]

        logger.info(
            " %-6s — images: %4d labels: %4d orphan_img: %d orphan_lbl: %d",
            split_name,
            split_stats["images"],
            split_stats["labels"],
            split_stats["images_without_labels"],
            split_stats["labels_without_images"],
        )

    logger.info(
        "Dataset totals — images: %d, labels: %d",
        stats["total_images"],
        stats["total_labels"],
    )
    return stats
def generate_dataset_yaml(
    cfg: Config,
    output_path: str | Path = "data/waid.yaml",
) -> Path:
    """Write the Ultralytics ``dataset.yaml`` describing the WAID splits.

    The file points at the pre-split ``images/<dir>`` directories under the
    resolved dataset root and records the class count and class names.

    Args:
        cfg: Pipeline configuration. Reads ``cfg.paths.dataset_root``,
            ``cfg.dataset.raw["split_dirs"]``, ``cfg.dataset.num_classes``
            and ``cfg.dataset.class_names``.
        output_path: Destination of the YAML file; missing parent
            directories are created.

    Returns:
        Absolute path to the generated YAML file.
    """
    dataset_root = Path(str(cfg.paths.dataset_root)).resolve()
    split_map = cfg.dataset.raw.get("split_dirs", {})

    # Keys are inserted in the order they should appear in the file;
    # sort_keys=False below keeps that order when dumping.
    payload: dict[str, Any] = {"path": str(dataset_root)}
    # (yaml key, directory name used when the split is not configured)
    for key, fallback in (("train", "train"), ("val", "valid"), ("test", "test")):
        payload[key] = f"images/{split_map.get(key, fallback)}"
    payload["nc"] = int(cfg.dataset.num_classes)
    payload["names"] = list(cfg.dataset.class_names)

    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as stream:
        yaml.dump(payload, stream, default_flow_style=False, sort_keys=False)

    written = target.resolve()
    logger.info("Dataset YAML written to %s", written)
    return written
def get_class_distribution(cfg: Config, split: str = "train") -> dict[str, int]:
    """Count per-class instances across all label files in a split.

    Reads every ``.txt`` file in the split's label directory and tallies the
    first whitespace-separated token of each non-blank line as a class ID.

    Args:
        cfg: Pipeline configuration. Reads ``cfg.paths.dataset_root``,
            ``cfg.dataset.raw["split_dirs"]`` and ``cfg.dataset.class_names``.
        split: Which split to scan (``"train"``, ``"val"``, ``"test"``). If
            the name is not a key of ``split_dirs`` it is used directly as
            the directory name.

    Returns:
        Dict mapping class name → instance count. Every configured class is
        present; all counts are zero when the label directory is missing.
        Lines whose class ID is not an integer, or falls outside the
        configured class range, are not counted; their total is reported in
        a single warning.
    """
    root = Path(str(cfg.paths.dataset_root))
    # ``or {}`` covers a ``split_dirs`` key that is present but null.
    split_dirs = cfg.dataset.raw.get("split_dirs", {}) or {}
    dir_name = split_dirs.get(split, split)
    lbl_dir = root / "labels" / dir_name

    class_names = list(cfg.dataset.class_names)
    num_classes = len(class_names)
    counts: dict[str, int] = {name: 0 for name in class_names}

    # is_dir() rather than exists(): iterdir() fails on a regular file.
    if not lbl_dir.is_dir():
        logger.warning("Label directory not found: %s", lbl_dir)
        return counts

    skipped = 0
    for lbl_file in lbl_dir.iterdir():
        # Same filter as validate_dataset(): regular ``.txt`` files only, so
        # a directory that happens to end in ".txt" is never opened.
        if lbl_file.suffix != ".txt" or not lbl_file.is_file():
            continue
        with open(lbl_file, "r", encoding="utf-8") as fh:
            for line in fh:
                parts = line.split()
                if not parts:
                    continue
                try:
                    cls_id = int(parts[0])
                except ValueError:
                    # One corrupt line must not abort the whole scan.
                    skipped += 1
                    continue
                if 0 <= cls_id < num_classes:
                    counts[class_names[cls_id]] += 1
                else:
                    skipped += 1

    if skipped:
        logger.warning(
            "Skipped %d label line(s) with a malformed or out-of-range "
            "class ID in %s",
            skipped,
            lbl_dir,
        )
    logger.info("Class distribution (%s): %s", split, counts)
    return counts