from __future__ import annotations import argparse import os import shutil import subprocess import sys from pathlib import Path from typing import Any import pandas as pd from sklearn.model_selection import train_test_split from .config import load_config, save_config_snapshot from .paths import ensure_dir from .utils import get_logger LOGGER = get_logger(__name__) CANONICAL_LABELS = ("Not Damaged", "Damaged") LABEL_TO_ID = {"Not Damaged": 0, "Damaged": 1} ID_TO_LABEL = {0: "Not Damaged", 1: "Damaged"} SPLIT_ALIASES = { "train": {"train", "training", "trn"}, "val": {"val", "valid", "validation", "dev"}, "test": {"test", "testing", "tst"}, } IGNORED_PARTS = {"__macosx", ".ipynb_checkpoints"} def normalize_name(value: str) -> str: return "".join(ch for ch in value.lower() if ch.isalnum()) def detect_label_from_name(name: str) -> str | None: normalized = normalize_name(name) not_damaged_markers = { "notdamaged", "undamaged", "nodamage", "nondamaged", "intact", "normal", "healthy", "good", "clean", "fresh", } damaged_markers = {"damaged", "damage", "cracked", "crack", "broken", "defect", "defective"} if any(marker in normalized for marker in not_damaged_markers): return "Not Damaged" if any(marker in normalized for marker in damaged_markers): return "Damaged" return None def detect_split_from_name(name: str) -> str | None: normalized = normalize_name(name) for split, aliases in SPLIT_ALIASES.items(): if normalized in {normalize_name(alias) for alias in aliases}: return split return None def is_hidden_or_system(path: Path) -> bool: return any(part.startswith(".") or part.lower() in IGNORED_PARTS for part in path.parts) def iter_image_files(root: Path, extensions: set[str]) -> list[Path]: images: list[Path] = [] for path in root.rglob("*"): if path.is_file() and path.suffix.lower() in extensions and not is_hidden_or_system(path): images.append(path.resolve()) return sorted(images) def label_for_path(path: Path, root: Path) -> str | None: try: parts = path.relative_to(root).parts[:-1] except ValueError: parts = path.parts[:-1] for part in reversed(parts): label = detect_label_from_name(part) if label: return label return detect_label_from_name(path.stem) def split_for_path(path: Path, root: Path) -> str | None: try: parts = path.relative_to(root).parts[:-1] except ValueError: parts = path.parts[:-1] for part in parts: split = detect_split_from_name(part) if split: return split return None def build_labeled_dataframe(root: str | Path, config: dict[str, Any]) -> pd.DataFrame: root = Path(root).expanduser().resolve() if not root.exists(): raise FileNotFoundError(f"Dataset path does not exist: {root}") extensions = {ext.lower() for ext in config["data"]["image_extensions"]} rows: list[dict[str, str]] = [] for path in iter_image_files(root, extensions): label = label_for_path(path, root) if label is None: continue rows.append( { "filepath": str(path), "label": label, "split": split_for_path(path, root) or "", } ) if not rows: raise ValueError( "No labeled images were detected. Expected folders or filenames resembling " "'Damaged', 'Not Damaged', 'cracked', 'undamaged', 'normal', or similar." ) df = pd.DataFrame(rows).drop_duplicates(subset=["filepath"]).reset_index(drop=True) labels = set(df["label"]) missing = set(CANONICAL_LABELS) - labels if missing: raise ValueError(f"Detected labels {sorted(labels)}, but missing classes: {sorted(missing)}") return df def create_stratified_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame: seed = int(config["seed"]) train_size = float(config["data"]["train_size"]) val_size = float(config["data"]["val_size"]) test_size = float(config["data"]["test_size"]) total = train_size + val_size + test_size if abs(total - 1.0) > 1e-6: train_size, val_size, test_size = train_size / total, val_size / total, test_size / total if df["label"].value_counts().min() < 3: raise ValueError("Each class needs at least 3 images for a 70/15/15 stratified split.") train_df, temp_df = train_test_split( df.drop(columns=["split"], errors="ignore"), train_size=train_size, stratify=df["label"], random_state=seed, ) relative_test = test_size / (val_size + test_size) val_df, test_df = train_test_split( temp_df, test_size=relative_test, stratify=temp_df["label"], random_state=seed, ) train_df = train_df.assign(split="train") val_df = val_df.assign(split="val") test_df = test_df.assign(split="test") return pd.concat([train_df, val_df, test_df], ignore_index=True).sort_values( ["split", "label", "filepath"] ) def complete_or_create_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame: known = df["split"].replace("", pd.NA).dropna() if known.empty: LOGGER.info("No existing train/val/test split folders detected; creating stratified splits.") return create_stratified_splits(df, config) df = df[df["split"].isin(["train", "val", "test"])].copy() if df.empty: return create_stratified_splits(df, config) present = set(df["split"].unique()) if {"train", "val", "test"}.issubset(present): LOGGER.info("Existing train/val/test split folders detected.") return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True) if "train" in present and "val" not in present: LOGGER.info("Existing split lacks validation data; carving validation from train only.") train_mask = df["split"] == "train" train_part = df[train_mask].drop(columns=["split"]) if train_part["label"].value_counts().min() >= 2: new_train, new_val = train_test_split( train_part, test_size=float(config["data"]["val_size"]), stratify=train_part["label"], random_state=int(config["seed"]), ) rest = df[~train_mask] df = pd.concat( [new_train.assign(split="train"), new_val.assign(split="val"), rest], ignore_index=True, ) missing = {"train", "val", "test"} - set(df["split"].unique()) if missing: LOGGER.warning("Missing split(s) %s; evaluation will use the available splits.", sorted(missing)) return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True) def add_label_ids(df: pd.DataFrame) -> pd.DataFrame: out = df.copy() out["label_id"] = out["label"].map(LABEL_TO_ID).astype(int) return out def discover_dataset(config: dict[str, Any], data_dir: str | Path | None = None) -> pd.DataFrame: root = Path(data_dir or config["paths"]["data_dir"]).expanduser().resolve() df = build_labeled_dataframe(root, config) df = complete_or_create_splits(df, config) df = add_label_ids(df) return df[["filepath", "label", "label_id", "split"]].reset_index(drop=True) def class_distribution(df: pd.DataFrame) -> pd.DataFrame: return ( df.groupby(["split", "label"], observed=False) .size() .reset_index(name="count") .sort_values(["split", "label"]) ) def print_class_distribution(df: pd.DataFrame) -> None: dist = class_distribution(df) LOGGER.info("Class distribution:\n%s", dist.to_string(index=False)) for split, split_df in df.groupby("split"): counts = split_df["label"].value_counts() if len(counts) == 2: ratio = counts.max() / max(counts.min(), 1) LOGGER.info("%s imbalance ratio: %.2f", split, ratio) def save_split_metadata(df: pd.DataFrame, config: dict[str, Any]) -> Path: output_dir = ensure_dir(config["paths"]["output_dir"]) split_csv = Path(config["paths"]["split_csv"]) split_csv.parent.mkdir(parents=True, exist_ok=True) df.to_csv(split_csv, index=False) class_distribution(df).to_csv(output_dir / "class_distribution.csv", index=False) save_config_snapshot(config, output_dir) LOGGER.info("Saved split metadata: %s", split_csv) return split_csv def kaggle_credentials_available() -> bool: if Path.home().joinpath(".kaggle", "kaggle.json").exists(): return True return bool({"KAGGLE_USERNAME", "KAGGLE_KEY"}.issubset(set(os.environ))) def download_kaggle_dataset(config: dict[str, Any]) -> Path: dataset = config["kaggle"]["dataset"] download_dir = ensure_dir(config["kaggle"]["download_dir"]) if not kaggle_credentials_available(): raise RuntimeError( "Kaggle credentials were not found. Configure ~/.kaggle/kaggle.json or " "KAGGLE_USERNAME/KAGGLE_KEY, then retry." ) command = shutil.which("kaggle") if command: cmd = [command, "datasets", "download", "-d", dataset, "-p", str(download_dir), "--unzip"] else: cmd = [sys.executable, "-m", "kaggle", "datasets", "download", "-d", dataset, "-p", str(download_dir), "--unzip"] LOGGER.info("Downloading Kaggle dataset %s to %s", dataset, download_dir) subprocess.run(cmd, check=True) return download_dir def prepare_data(config: dict[str, Any], data_dir: str | Path | None = None, download: bool = False) -> pd.DataFrame: if download or config.get("kaggle", {}).get("enabled", False): data_dir = download_kaggle_dataset(config) config["paths"]["data_dir"] = str(data_dir) df = discover_dataset(config, data_dir) print_class_distribution(df) save_split_metadata(df, config) try: from .reporting import plot_class_distribution plot_class_distribution(df, Path(config["paths"]["output_dir"]) / "plots" / "class_distribution.png") except Exception as exc: LOGGER.warning("Could not save class distribution plot: %s", exc) return df def main() -> None: parser = argparse.ArgumentParser(description="Discover and split egg damage image dataset.") parser.add_argument("--config", default="configs/default.yaml") parser.add_argument("--data-dir", default=None) parser.add_argument("--download-kaggle", action="store_true") args = parser.parse_args() config = load_config(args.config) if args.data_dir: config["paths"]["data_dir"] = str(Path(args.data_dir).expanduser().resolve()) prepare_data(config, args.data_dir, args.download_kaggle) if __name__ == "__main__": main()