| from __future__ import annotations |
|
|
| import argparse |
| import os |
| import shutil |
| import subprocess |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import pandas as pd |
| from sklearn.model_selection import train_test_split |
|
|
| from .config import load_config, save_config_snapshot |
| from .paths import ensure_dir |
| from .utils import get_logger |
|
|
|
|
# Module-level logger from the project's logging helper (imported from .utils).
LOGGER = get_logger(__name__)

# The two canonical class names; a label's integer id is its position here.
CANONICAL_LABELS = ("Not Damaged", "Damaged")
LABEL_TO_ID = {label: label_id for label_id, label in enumerate(CANONICAL_LABELS)}
ID_TO_LABEL = {label_id: label for label, label_id in LABEL_TO_ID.items()}

# Folder names accepted for each split. detect_split_from_name() normalizes
# both the folder name and these aliases before comparing them.
SPLIT_ALIASES = {
    "train": {"trn", "train", "training"},
    "val": {"dev", "val", "valid", "validation"},
    "test": {"tst", "test", "testing"},
}

# Path components skipped during image discovery; compared against the
# lower-cased component name by is_hidden_or_system().
IGNORED_PARTS = {".ipynb_checkpoints", "__macosx"}
|
|
|
|
| def normalize_name(value: str) -> str: |
| return "".join(ch for ch in value.lower() if ch.isalnum()) |
|
|
|
|
| def detect_label_from_name(name: str) -> str | None: |
| normalized = normalize_name(name) |
| not_damaged_markers = { |
| "notdamaged", |
| "undamaged", |
| "nodamage", |
| "nondamaged", |
| "intact", |
| "normal", |
| "healthy", |
| "good", |
| "clean", |
| "fresh", |
| } |
| damaged_markers = {"damaged", "damage", "cracked", "crack", "broken", "defect", "defective"} |
| if any(marker in normalized for marker in not_damaged_markers): |
| return "Not Damaged" |
| if any(marker in normalized for marker in damaged_markers): |
| return "Damaged" |
| return None |
|
|
|
|
def detect_split_from_name(name: str) -> str | None:
    """Return the canonical split ("train", "val" or "test") that *name* names.

    The whole normalized name must equal one of the normalized aliases listed
    in SPLIT_ALIASES; partial matches do not count. Returns None when no
    split matches.
    """
    key = normalize_name(name)
    return next(
        (
            split
            for split, aliases in SPLIT_ALIASES.items()
            if key in {normalize_name(alias) for alias in aliases}
        ),
        None,
    )
|
|
|
|
def is_hidden_or_system(path: Path) -> bool:
    """Return True if any component of *path* is hidden or an ignored folder.

    A component is hidden when it starts with "."; it is a system folder when
    its lower-cased name is in IGNORED_PARTS. Every component of *path* is
    inspected, including all ancestors of an absolute path.
    """
    for part in path.parts:
        if part.startswith("."):
            return True
        if part.lower() in IGNORED_PARTS:
            return True
    return False
|
|
|
|
def iter_image_files(root: Path, extensions: set[str]) -> list[Path]:
    """Return the sorted, resolved paths of all image files below *root*.

    Args:
        root: Directory to search recursively.
        extensions: Lower-case file suffixes to accept, including the dot
            (the suffix is lower-cased before the lookup).

    Returns:
        Resolved paths of matching files, sorted. Hidden and system entries
        are excluded, judged only on the path components below *root*.
    """
    images: list[Path] = []
    for path in root.rglob("*"):
        if path.suffix.lower() not in extensions or not path.is_file():
            continue
        # Only inspect the part of the path below root. Ancestors of the
        # dataset directory (e.g. ~/.cache/...) may legitimately be hidden,
        # and checking them would exclude every image.
        if is_hidden_or_system(path.relative_to(root)):
            continue
        images.append(path.resolve())
    return sorted(images)
|
|
|
|
def label_for_path(path: Path, root: Path) -> str | None:
    """Infer the canonical label of the image at *path*.

    Directory names between *root* and the file are tried from the innermost
    outwards, and the first one detect_label_from_name() recognises wins. If
    none does, the file's stem is tried. When *path* is not under *root*, all
    of its parent directories are considered instead.
    """
    try:
        relative = path.relative_to(root)
    except ValueError:
        relative = path
    for folder in reversed(relative.parent.parts):
        found = detect_label_from_name(folder)
        if found:
            return found
    return detect_label_from_name(path.stem)
|
|
|
|
def split_for_path(path: Path, root: Path) -> str | None:
    """Return the split named by the outermost matching directory of *path*.

    Directories between *root* and the file are scanned from the top down.
    If *path* is not under *root*, all of its parent directories are scanned.
    Returns None when no directory names a split.
    """
    try:
        folders = path.relative_to(root).parent.parts
    except ValueError:
        folders = path.parent.parts
    return next((split for split in map(detect_split_from_name, folders) if split), None)
|
|
|
|
def build_labeled_dataframe(root: str | Path, config: dict[str, Any]) -> pd.DataFrame:
    """Scan *root* for labelled images and tabulate them.

    Args:
        root: Dataset directory; ``~`` is expanded and the path resolved.
        config: Config mapping; ``config["data"]["image_extensions"]`` lists
            the accepted file suffixes (compared case-insensitively).

    Returns:
        Frame with ``filepath``, ``label`` and ``split`` columns (``split`` is
        "" when no split folder was found), one row per unique file path.
        Images whose label cannot be inferred are skipped.

    Raises:
        FileNotFoundError: if *root* does not exist.
        ValueError: if no labelled image is found, or a canonical class is absent.
    """
    root = Path(root).expanduser().resolve()
    if not root.exists():
        raise FileNotFoundError(f"Dataset path does not exist: {root}")
    wanted_suffixes = {suffix.lower() for suffix in config["data"]["image_extensions"]}

    records: list[dict[str, str]] = []
    for image in iter_image_files(root, wanted_suffixes):
        label = label_for_path(image, root)
        if label is not None:
            records.append(
                {
                    "filepath": str(image),
                    "label": label,
                    "split": split_for_path(image, root) or "",
                }
            )
    if not records:
        raise ValueError(
            "No labeled images were detected. Expected folders or filenames resembling "
            "'Damaged', 'Not Damaged', 'cracked', 'undamaged', 'normal', or similar."
        )

    frame = pd.DataFrame(records)
    frame = frame.drop_duplicates(subset=["filepath"])
    frame = frame.reset_index(drop=True)
    found = set(frame["label"])
    absent = set(CANONICAL_LABELS) - found
    if absent:
        raise ValueError(f"Detected labels {sorted(found)}, but missing classes: {sorted(absent)}")
    return frame
|
|
|
|
def create_stratified_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Randomly assign every row to train/val/test, stratified by ``label``.

    The fractions come from ``config["data"]["train_size"]``, ``["val_size"]``
    and ``["test_size"]`` and are rescaled to sum to 1 if needed.
    ``config["seed"]`` is passed as ``random_state`` to both
    ``train_test_split`` calls. Any existing ``split`` column is replaced.

    Args:
        df: Frame with at least ``filepath`` and ``label`` columns.
        config: Config mapping providing ``seed`` and the ``data`` fractions.

    Returns:
        Frame with a ``split`` column of "train"/"val"/"test", sorted by
        split, label and filepath, with a fresh 0..n-1 index.

    Raises:
        ValueError: if a split fraction is not positive, *df* is empty, or a
            class has fewer than 3 images. scikit-learn may still reject very
            small classes during the second split.
    """
    seed = int(config["seed"])
    fractions = config["data"]
    train_size = float(fractions["train_size"])
    val_size = float(fractions["val_size"])
    test_size = float(fractions["test_size"])
    # Every split must receive data. A zero fraction would otherwise surface as
    # an opaque scikit-learn parameter error, or as a ZeroDivisionError below
    # when all fractions are zero.
    if min(train_size, val_size, test_size) <= 0:
        raise ValueError(
            "train_size, val_size and test_size must all be positive, got "
            f"{train_size}, {val_size} and {test_size}."
        )
    total = train_size + val_size + test_size
    if abs(total - 1.0) > 1e-6:
        train_size, val_size, test_size = (
            train_size / total,
            val_size / total,
            test_size / total,
        )

    if df.empty:
        raise ValueError("Cannot create train/val/test splits from an empty dataset.")
    if df["label"].value_counts().min() < 3:
        raise ValueError("Each class needs at least 3 images for a stratified train/val/test split.")

    # First carve off the training set, then split the remainder into
    # validation and test using the test share relative to that remainder.
    train_df, temp_df = train_test_split(
        df.drop(columns=["split"], errors="ignore"),
        train_size=train_size,
        stratify=df["label"],
        random_state=seed,
    )
    relative_test = test_size / (val_size + test_size)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=relative_test,
        stratify=temp_df["label"],
        random_state=seed,
    )
    combined = pd.concat(
        [
            train_df.assign(split="train"),
            val_df.assign(split="val"),
            test_df.assign(split="test"),
        ],
        ignore_index=True,
    )
    return combined.sort_values(["split", "label", "filepath"]).reset_index(drop=True)
|
|
|
|
def _carve_validation_from_train(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Move a stratified ``val_size`` fraction of the "train" rows into "val".

    Returns *df* unchanged, after logging a warning, when the train rows
    cannot be split in a stratified way.
    """
    train_mask = df["split"] == "train"
    train_part = df[train_mask].drop(columns=["split"])
    if train_part["label"].value_counts().min() < 2:
        LOGGER.warning(
            "A class has fewer than 2 training images; cannot carve a stratified validation split."
        )
        return df
    try:
        new_train, new_val = train_test_split(
            train_part,
            test_size=float(config["data"]["val_size"]),
            stratify=train_part["label"],
            random_state=int(config["seed"]),
        )
    except ValueError as exc:
        # e.g. the validation share rounds to fewer rows than there are classes.
        LOGGER.warning("Could not carve a stratified validation split: %s", exc)
        return df
    rest = df[~train_mask]
    return pd.concat(
        [new_train.assign(split="train"), new_val.assign(split="val"), rest],
        ignore_index=True,
    )


def complete_or_create_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Reuse train/val/test folders found on disk, filling gaps where possible.

    Behaviour:
      * No row has a split: create fresh stratified splits for the whole frame.
      * Otherwise, rows in a known split are kept. Rows without one are
        dropped, and the number dropped is logged as a warning.
      * If "train" exists but "val" does not, a stratified validation subset
        (``config["data"]["val_size"]`` of train) is carved out of train when
        possible; otherwise a warning is logged.
      * Any split that is still missing is reported with a warning.

    Args:
        df: Frame with at least ``filepath``, ``label`` and ``split`` columns,
            where ``split`` is "train", "val", "test" or "".
        config: Config mapping; reads ``seed`` and ``data.val_size`` here, plus
            whatever create_stratified_splits reads.

    Returns:
        Frame sorted by split, label and filepath with a fresh 0..n-1 index.
    """
    known = df["split"].replace("", pd.NA).dropna()
    if known.empty:
        LOGGER.info("No existing train/val/test split folders detected; creating stratified splits.")
        return create_stratified_splits(df, config)

    in_known_split = df["split"].isin(["train", "val", "test"])
    if not in_known_split.any():
        # Split values exist but none is recognised: split the full dataset
        # rather than an empty frame.
        return create_stratified_splits(df, config)
    unassigned = int((~in_known_split).sum())
    if unassigned:
        LOGGER.warning(
            "Ignoring %d image(s) that are not inside a train/val/test folder.", unassigned
        )
    df = df[in_known_split].copy()

    present = set(df["split"].unique())
    if {"train", "val", "test"}.issubset(present):
        LOGGER.info("Existing train/val/test split folders detected.")
        return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True)
    if "train" in present and "val" not in present:
        LOGGER.info("Existing split lacks validation data; carving validation from train only.")
        df = _carve_validation_from_train(df, config)
    missing = {"train", "val", "test"} - set(df["split"].unique())
    if missing:
        LOGGER.warning("Missing split(s) %s; evaluation will use the available splits.", sorted(missing))
    return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True)
|
|
|
|
def add_label_ids(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with an integer ``label_id`` column.

    Ids are looked up from ``label`` in LABEL_TO_ID. A label missing from
    that mapping becomes NaN, which makes the integer cast raise a pandas
    error. The input frame is not modified.
    """
    ids = df["label"].map(LABEL_TO_ID)
    return df.assign(label_id=ids.astype(int))
|
|
|
|
def discover_dataset(config: dict[str, Any], data_dir: str | Path | None = None) -> pd.DataFrame:
    """Scan, split and id-label the dataset.

    Args:
        config: Config mapping. ``config["paths"]["data_dir"]`` is used when
            *data_dir* is falsy.
        data_dir: Optional dataset directory overriding the config value.

    Returns:
        Frame with exactly ``filepath``, ``label``, ``label_id`` and ``split``
        columns and a fresh 0..n-1 index.
    """
    source = data_dir or config["paths"]["data_dir"]
    dataset_root = Path(source).expanduser().resolve()
    labeled = build_labeled_dataframe(dataset_root, config)
    split = complete_or_create_splits(labeled, config)
    with_ids = add_label_ids(split)
    columns = ["filepath", "label", "label_id", "split"]
    return with_ids[columns].reset_index(drop=True)
|
|
|
|
def class_distribution(df: pd.DataFrame) -> pd.DataFrame:
    """Count images per (split, label) pair.

    Returns a frame with ``split``, ``label`` and ``count`` columns, sorted by
    split and then label.
    """
    keys = ["split", "label"]
    counts = df.groupby(keys, observed=False).size().rename("count")
    return counts.reset_index().sort_values(keys)
|
|
|
|
def print_class_distribution(df: pd.DataFrame) -> None:
    """Log the per-split class counts and, where both classes occur, their imbalance.

    The imbalance ratio is the larger class count divided by the smaller one.
    It is only logged for splits that contain exactly two distinct labels.
    """
    table = class_distribution(df).to_string(index=False)
    LOGGER.info("Class distribution:\n%s", table)
    for split_name, group in df.groupby("split"):
        per_label = group["label"].value_counts()
        if len(per_label) != 2:
            continue
        imbalance = per_label.max() / max(per_label.min(), 1)
        LOGGER.info("%s imbalance ratio: %.2f", split_name, imbalance)
|
|
|
|
def save_split_metadata(df: pd.DataFrame, config: dict[str, Any]) -> Path:
    """Write the split table, its class distribution and a config snapshot to disk.

    The split table goes to ``config["paths"]["split_csv"]``, creating parent
    directories as needed. ``class_distribution.csv`` and the config snapshot
    go to ``config["paths"]["output_dir"]``. This assumes ``ensure_dir``
    returns a Path-like object supporting ``/`` (TODO confirm in .paths).

    Returns:
        The path of the written split CSV.
    """
    out_dir = ensure_dir(config["paths"]["output_dir"])
    csv_path = Path(config["paths"]["split_csv"])
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(csv_path, index=False)
    distribution = class_distribution(df)
    distribution.to_csv(out_dir / "class_distribution.csv", index=False)
    save_config_snapshot(config, out_dir)
    LOGGER.info("Saved split metadata: %s", csv_path)
    return csv_path
|
|
|
|
def kaggle_credentials_available() -> bool:
    """Return True if Kaggle API credentials appear to be configured.

    Checks, in order:
      1. ``kaggle.json`` inside ``$KAGGLE_CONFIG_DIR`` (when that variable is set),
         the location the Kaggle API reads instead of ``~/.kaggle``.
      2. ``~/.kaggle/kaggle.json``.
      3. Non-empty ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` environment variables.

    Only the presence of credentials is checked; they are not validated.
    """
    config_dir = os.environ.get("KAGGLE_CONFIG_DIR")
    if config_dir and Path(config_dir).expanduser().joinpath("kaggle.json").exists():
        return True
    if Path.home().joinpath(".kaggle", "kaggle.json").exists():
        return True
    # Empty values would only fail later with an authentication error.
    return all(os.environ.get(name) for name in ("KAGGLE_USERNAME", "KAGGLE_KEY"))
|
|
|
|
def download_kaggle_dataset(config: dict[str, Any]) -> Path:
    """Download and unzip ``config["kaggle"]["dataset"]`` via the Kaggle CLI.

    The target directory ``config["kaggle"]["download_dir"]`` is created before
    the credential check. The ``kaggle`` executable on PATH is preferred;
    otherwise ``python -m kaggle`` is run with the current interpreter.

    Returns:
        The value returned by ``ensure_dir`` for the download directory.

    Raises:
        RuntimeError: if no Kaggle credentials are configured.
        subprocess.CalledProcessError: if the CLI exits with a non-zero status.
    """
    dataset = config["kaggle"]["dataset"]
    download_dir = ensure_dir(config["kaggle"]["download_dir"])
    if not kaggle_credentials_available():
        raise RuntimeError(
            "Kaggle credentials were not found. Configure ~/.kaggle/kaggle.json or "
            "KAGGLE_USERNAME/KAGGLE_KEY, then retry."
        )
    executable = shutil.which("kaggle")
    launcher = [executable] if executable else [sys.executable, "-m", "kaggle"]
    cli_args = ["datasets", "download", "-d", dataset, "-p", str(download_dir), "--unzip"]
    LOGGER.info("Downloading Kaggle dataset %s to %s", dataset, download_dir)
    subprocess.run([*launcher, *cli_args], check=True)
    return download_dir
|
|
|
|
def prepare_data(config: dict[str, Any], data_dir: str | Path | None = None, download: bool = False) -> pd.DataFrame:
    """Discover, split, report and persist the dataset, optionally downloading it first.

    When *download* is true or ``config["kaggle"]["enabled"]`` is truthy, the
    Kaggle dataset is downloaded. Its directory then replaces *data_dir* and
    is written back into ``config["paths"]["data_dir"]``, mutating *config*.

    Plotting the class distribution is best-effort: any failure, including
    an import failure of ``.reporting``, is logged as a warning.

    Returns:
        The discovered and split dataset frame.
    """
    should_download = download or config.get("kaggle", {}).get("enabled", False)
    if should_download:
        data_dir = download_kaggle_dataset(config)
        config["paths"]["data_dir"] = str(data_dir)

    frame = discover_dataset(config, data_dir)
    print_class_distribution(frame)
    save_split_metadata(frame, config)

    try:
        from .reporting import plot_class_distribution

        plot_target = Path(config["paths"]["output_dir"]) / "plots" / "class_distribution.png"
        plot_class_distribution(frame, plot_target)
    except Exception as exc:
        LOGGER.warning("Could not save class distribution plot: %s", exc)
    return frame
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the dataset preparation entry point."""
    parser = argparse.ArgumentParser(description="Discover and split egg damage image dataset.")
    parser.add_argument("--config", default="configs/default.yaml")
    parser.add_argument("--data-dir", default=None)
    parser.add_argument("--download-kaggle", action="store_true")
    return parser


def main() -> None:
    """CLI entry point: load the config, apply overrides, and prepare the data.

    A ``--data-dir`` value is expanded, resolved and stored in
    ``config["paths"]["data_dir"]`` before it is passed on to prepare_data.
    """
    options = _build_arg_parser().parse_args()
    config = load_config(options.config)
    if options.data_dir:
        config["paths"]["data_dir"] = str(Path(options.data_dir).expanduser().resolve())
    prepare_data(config, data_dir=options.data_dir, download=options.download_kaggle)


# The relative imports at the top of this file mean it must be run as a module
# (python -m <package>.<module>), not as a plain script path.
if __name__ == "__main__":
    main()
|
|