# Source: src/egg_damage/data_discovery.py - uploaded by budijuarto (revision 5aba42e).
from __future__ import annotations
import argparse
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Any
import pandas as pd
from sklearn.model_selection import train_test_split
from .config import load_config, save_config_snapshot
from .paths import ensure_dir
from .utils import get_logger
LOGGER = get_logger(__name__)

# Class names in label-id order; both id mappings below are derived from this
# tuple so the three can never disagree.
CANONICAL_LABELS = ("Not Damaged", "Damaged")
LABEL_TO_ID = {label: index for index, label in enumerate(CANONICAL_LABELS)}
ID_TO_LABEL = dict(enumerate(CANONICAL_LABELS))

# Folder-name spellings accepted for each split; compared after normalize_name()
# has been applied to both the folder name and the alias.
SPLIT_ALIASES = {
    "train": {"train", "training", "trn"},
    "val": {"val", "valid", "validation", "dev"},
    "test": {"test", "testing", "tst"},
}

# Lower-cased path components that are always skipped when collecting images.
IGNORED_PARTS = {"__macosx", ".ipynb_checkpoints"}
def normalize_name(value: str) -> str:
    """Return *value* lower-cased with every non-alphanumeric character removed.

    ``"Not Damaged"``, ``"not_damaged"`` and ``"Not-Damaged"`` all become
    ``"notdamaged"``, so folder and file names compare spelling-insensitively.
    """
    return "".join(ch for ch in value.lower() if ch.isalnum())


# Substrings (in normalize_name() form) that mark an intact egg.
_NOT_DAMAGED_MARKERS = (
    "notdamaged",
    "undamaged",
    "nodamage",
    "nondamaged",
    "intact",
    "normal",
    "healthy",
    "good",
    "clean",
    "fresh",
)
# Substrings (in normalize_name() form) that mark a damaged egg.
_DAMAGED_MARKERS = ("damaged", "damage", "cracked", "crack", "broken", "defect", "defective")
# Prefixes that negate a directly following damaged marker
# ("nondefective", "uncracked", "notbroken").
_NEGATION_PREFIXES = ("not", "non", "un")


def detect_label_from_name(name: str) -> str | None:
    """Infer the class label from a folder or file name.

    Matching is substring-based on the normalize_name() form of *name*.
    Not-damaged evidence is checked first because several not-damaged markers
    contain a damaged one (``"undamaged"`` contains ``"damaged"``).  A damaged
    marker directly preceded by a negation prefix (``"non-defective"``,
    ``"uncracked"``, ``"not broken"``) also counts as not damaged.  The word
    ``"abnormal"`` is never read as the ``"normal"`` marker.

    Args:
        name: A single path component or file stem.

    Returns:
        ``"Not Damaged"``, ``"Damaged"``, or ``None`` when no marker matches.
    """
    normalized = normalize_name(name)
    # Mask "abnormal" with a non-alphanumeric separator so its "normal" suffix
    # cannot match, without gluing neighbouring characters into a new marker.
    without_abnormal = normalized.replace("abnormal", "_")
    if any(marker in without_abnormal for marker in _NOT_DAMAGED_MARKERS):
        return "Not Damaged"
    if any(
        prefix + marker in normalized
        for prefix in _NEGATION_PREFIXES
        for marker in _DAMAGED_MARKERS
    ):
        return "Not Damaged"
    if any(marker in normalized for marker in _DAMAGED_MARKERS):
        return "Damaged"
    return None
def detect_split_from_name(name: str) -> str | None:
    """Map a folder name to ``"train"``, ``"val"`` or ``"test"``.

    The comparison is an exact match after normalize_name() is applied to both
    *name* and each alias in SPLIT_ALIASES; the first split (in SPLIT_ALIASES
    order) whose aliases contain the name is returned, otherwise ``None``.
    """
    key = normalize_name(name)
    return next(
        (
            split
            for split, aliases in SPLIT_ALIASES.items()
            if key in {normalize_name(alias) for alias in aliases}
        ),
        None,
    )
def is_hidden_or_system(path: Path) -> bool:
    """Return True if any component of *path* is dot-prefixed or in IGNORED_PARTS.

    Every entry of ``path.parts`` is inspected (IGNORED_PARTS is compared
    case-insensitively), so for an absolute path the ancestors are checked too.
    """
    for part in path.parts:
        if part.startswith("."):
            return True
        if part.lower() in IGNORED_PARTS:
            return True
    return False
def iter_image_files(root: Path, extensions: set[str]) -> list[Path]:
    """Return every image file below *root*, resolved and sorted.

    A file is kept when its lower-cased suffix is in *extensions* and none of
    its path components *below root* is hidden or a system folder (see
    is_hidden_or_system).  The components of *root* itself are deliberately
    not inspected: a dataset stored under a dot-directory such as
    ``~/.cache/...`` would otherwise have every image rejected.

    Args:
        root: Directory searched recursively.
        extensions: Suffixes to accept, lower-case and including the dot
            (compared against ``Path.suffix.lower()``), e.g. ``{".jpg"}``.

    Returns:
        Sorted list of resolved file paths.
    """
    images: list[Path] = []
    for path in root.rglob("*"):
        # Cheap suffix test first so non-image entries are never stat()ed.
        if path.suffix.lower() not in extensions:
            continue
        # rglob yields paths prefixed by root, so relative_to cannot fail here.
        if is_hidden_or_system(path.relative_to(root)):
            continue
        if path.is_file():
            images.append(path.resolve())
    return sorted(images)
def label_for_path(path: Path, root: Path) -> str | None:
    """Infer an image's label from its parent directories, then its file stem.

    Parent directories are examined from the deepest one upwards and the first
    recognised name wins.  When *path* lies inside *root* only the directories
    below *root* are considered; otherwise every parent component of *path* is.
    Falls back to the file stem and returns ``None`` if nothing matches.
    """
    try:
        relative = path.relative_to(root)
    except ValueError:
        relative = path
    directories = relative.parts[:-1]
    candidates = (detect_label_from_name(directory) for directory in reversed(directories))
    found = next((label for label in candidates if label), None)
    if found:
        return found
    return detect_label_from_name(path.stem)
def split_for_path(path: Path, root: Path) -> str | None:
    """Return the split named by the outermost recognised directory of *path*.

    Directories are scanned from *root* downwards, or from the start of *path*
    when it is not inside *root*.  The file name itself is ignored.  Returns
    ``None`` when no directory is a known split name.
    """
    try:
        directories = path.relative_to(root).parts[:-1]
    except ValueError:
        directories = path.parts[:-1]
    matches = (detect_split_from_name(directory) for directory in directories)
    return next((split for split in matches if split), None)
def build_labeled_dataframe(root: str | Path, config: dict[str, Any]) -> pd.DataFrame:
    """Scan *root* for images and return one row per labelled image.

    Args:
        root: Dataset directory; ``~`` is expanded and the path is resolved.
        config: Configuration; reads ``config["data"]["image_extensions"]``.
            Extensions are matched case-insensitively and may be given with or
            without the leading dot (``".jpg"`` or ``"jpg"``).

    Returns:
        DataFrame with string columns ``filepath`` (resolved path), ``label``
        (``"Not Damaged"`` or ``"Damaged"``) and ``split`` (``"train"``,
        ``"val"``, ``"test"``, or ``""`` when no split folder was found),
        de-duplicated on ``filepath``.

    Raises:
        FileNotFoundError: *root* does not exist.
        ValueError: no labelled image was found, or a canonical class is absent.
    """
    root = Path(root).expanduser().resolve()
    if not root.exists():
        raise FileNotFoundError(f"Dataset path does not exist: {root}")
    # Path.suffix always carries the leading dot, so add it when the config omits it.
    extensions = {
        ext.lower() if ext.startswith(".") else f".{ext.lower()}"
        for ext in config["data"]["image_extensions"]
    }
    rows: list[dict[str, str]] = []
    unlabeled = 0
    for path in iter_image_files(root, extensions):
        label = label_for_path(path, root)
        if label is None:
            unlabeled += 1
            continue
        rows.append(
            {
                "filepath": str(path),
                "label": label,
                "split": split_for_path(path, root) or "",
            }
        )
    if unlabeled:
        # Make a mismatched folder layout visible instead of silently shrinking the dataset.
        LOGGER.warning("Skipped %d image(s) with no recognisable label under %s.", unlabeled, root)
    if not rows:
        raise ValueError(
            "No labeled images were detected. Expected folders or filenames resembling "
            "'Damaged', 'Not Damaged', 'cracked', 'undamaged', 'normal', or similar."
        )
    df = pd.DataFrame(rows).drop_duplicates(subset=["filepath"]).reset_index(drop=True)
    labels = set(df["label"])
    missing = set(CANONICAL_LABELS) - labels
    if missing:
        raise ValueError(f"Detected labels {sorted(labels)}, but missing classes: {sorted(missing)}")
    return df
def create_stratified_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Split *df* into label-stratified train/val/test partitions.

    Ratios come from ``config["data"]["train_size" | "val_size" | "test_size"]``
    and are rescaled to sum to 1 when they do not already.  Any existing
    ``split`` column is discarded.  Both train_test_split calls use
    ``config["seed"]`` as ``random_state`` so the result is reproducible.

    Returns:
        The rows of *df* with ``split`` set to "train", "val" or "test", sorted
        by split, label and filepath, with a fresh 0..n-1 index.

    Raises:
        ValueError: a ratio is not positive, *df* is empty, or some class has
            fewer than three images.
    """
    seed = int(config["seed"])
    train_size = float(config["data"]["train_size"])
    val_size = float(config["data"]["val_size"])
    test_size = float(config["data"]["test_size"])
    # Every partition must be non-empty; a zero val+test would also divide by zero below.
    if min(train_size, val_size, test_size) <= 0:
        raise ValueError(
            "train_size, val_size and test_size must all be positive; "
            f"got {train_size}, {val_size}, {test_size}."
        )
    total = train_size + val_size + test_size
    if abs(total - 1.0) > 1e-6:
        train_size, val_size, test_size = train_size / total, val_size / total, test_size / total
    if df.empty:
        raise ValueError("Cannot create stratified splits from an empty dataset.")
    smallest_class = int(df["label"].value_counts().min())
    if smallest_class < 3:
        raise ValueError(
            "Each class needs at least 3 images for a "
            f"{train_size:.0%}/{val_size:.0%}/{test_size:.0%} stratified split; "
            f"the smallest class has {smallest_class}."
        )
    train_df, temp_df = train_test_split(
        df.drop(columns=["split"], errors="ignore"),
        train_size=train_size,
        stratify=df["label"],
        random_state=seed,
    )
    # Share of the held-out remainder that goes to test (the rest becomes val).
    relative_test = test_size / (val_size + test_size)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=relative_test,
        stratify=temp_df["label"],
        random_state=seed,
    )
    train_df = train_df.assign(split="train")
    val_df = val_df.assign(split="val")
    test_df = test_df.assign(split="test")
    return (
        pd.concat([train_df, val_df, test_df], ignore_index=True)
        .sort_values(["split", "label", "filepath"])
        .reset_index(drop=True)
    )
def complete_or_create_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Honour split folders found on disk, filling gaps where possible.

    * No row has a train/val/test split: stratified splits are created from
      scratch with create_stratified_splits.
    * Some rows do: rows without one are dropped (with a warning) and the
      on-disk assignment is kept.
    * A train split but no val split: a stratified validation subset of
      fraction ``config["data"]["val_size"]`` is carved out of train, provided
      every class has at least two training images.
    Any split still missing afterwards is reported with a warning.

    Returns:
        DataFrame sorted by split, label and filepath with a fresh 0..n-1 index.
    """
    split_names = ["train", "val", "test"]
    known_mask = df["split"].isin(split_names)
    if not known_mask.any():
        LOGGER.info("No existing train/val/test split folders detected; creating stratified splits.")
        # Pass the unfiltered frame: filtering first would leave nothing to split.
        return create_stratified_splits(df, config)

    unsplit = int((~known_mask).sum())
    if unsplit:
        # Mixed layouts keep the on-disk assignment; make the discarded images visible.
        LOGGER.warning("Dropping %d image(s) that are not under a train/val/test folder.", unsplit)
    df = df[known_mask].copy()

    present = set(df["split"].unique())
    if set(split_names).issubset(present):
        LOGGER.info("Existing train/val/test split folders detected.")
        return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True)

    if "train" in present and "val" not in present:
        LOGGER.info("Existing split lacks validation data; carving validation from train only.")
        train_mask = df["split"] == "train"
        train_part = df[train_mask].drop(columns=["split"])
        # Stratification needs at least two members per class.
        if train_part["label"].value_counts().min() >= 2:
            new_train, new_val = train_test_split(
                train_part,
                test_size=float(config["data"]["val_size"]),
                stratify=train_part["label"],
                random_state=int(config["seed"]),
            )
            df = pd.concat(
                [new_train.assign(split="train"), new_val.assign(split="val"), df[~train_mask]],
                ignore_index=True,
            )
        else:
            LOGGER.warning(
                "A class has fewer than 2 training images; cannot carve a stratified validation split."
            )

    missing = set(split_names) - set(df["split"].unique())
    if missing:
        LOGGER.warning("Missing split(s) %s; evaluation will use the available splits.", sorted(missing))
    return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True)
def add_label_ids(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with an integer ``label_id`` column from LABEL_TO_ID.

    Raises:
        ValueError: *df* contains a label that is not a key of LABEL_TO_ID.
    """
    out = df.copy()
    ids = out["label"].map(LABEL_TO_ID)
    # Unmapped labels become NaN; report them by name instead of failing in the int cast.
    unknown = out.loc[ids.isna(), "label"]
    if not unknown.empty:
        raise ValueError(
            f"Unknown label(s) {sorted(map(str, unknown.unique()))}; "
            f"expected one of {sorted(LABEL_TO_ID)}."
        )
    out["label_id"] = ids.astype(int)
    return out
def discover_dataset(config: dict[str, Any], data_dir: str | Path | None = None) -> pd.DataFrame:
    """Build the labelled, split and id-annotated image table for a dataset.

    *data_dir* overrides ``config["paths"]["data_dir"]`` when it is truthy.

    Returns:
        DataFrame with columns ``filepath``, ``label``, ``label_id`` and
        ``split`` and a 0..n-1 index.
    """
    source = data_dir if data_dir else config["paths"]["data_dir"]
    dataset_root = Path(source).expanduser().resolve()
    table = build_labeled_dataframe(dataset_root, config)
    table = complete_or_create_splits(table, config)
    table = add_label_ids(table)
    columns = ["filepath", "label", "label_id", "split"]
    return table[columns].reset_index(drop=True)
def class_distribution(df: pd.DataFrame) -> pd.DataFrame:
    """Count images per (split, label) pair.

    Returns:
        DataFrame with columns ``split``, ``label`` and ``count``, sorted by
        split then label.  ``observed=False`` only matters if the key columns
        are categorical, in which case unobserved combinations are kept.
    """
    grouped = df.groupby(["split", "label"], observed=False)
    counts = grouped.size().reset_index(name="count")
    return counts.sort_values(["split", "label"])
def print_class_distribution(df: pd.DataFrame) -> None:
    """Log per-split class counts and, for two-class splits, the imbalance ratio.

    Despite the name, output goes through LOGGER rather than stdout.  The
    ratio is the largest class count divided by the smallest.
    """
    table = class_distribution(df).to_string(index=False)
    LOGGER.info("Class distribution:\n%s", table)
    for split_name, group in df.groupby("split"):
        per_label = group["label"].value_counts()
        if len(per_label) != 2:
            continue
        imbalance = per_label.max() / max(per_label.min(), 1)
        LOGGER.info("%s imbalance ratio: %.2f", split_name, imbalance)
def save_split_metadata(df: pd.DataFrame, config: dict[str, Any]) -> Path:
    """Persist the split table, its class counts and a config snapshot.

    Writes *df* to ``config["paths"]["split_csv"]`` (creating the parent
    directory), writes ``class_distribution.csv`` into
    ``config["paths"]["output_dir"]``, and calls save_config_snapshot for that
    same directory.

    Returns:
        Path of the written split CSV.
    """
    out_dir = ensure_dir(config["paths"]["output_dir"])
    csv_path = Path(config["paths"]["split_csv"])
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(csv_path, index=False)
    class_distribution(df).to_csv(out_dir / "class_distribution.csv", index=False)
    save_config_snapshot(config, out_dir)
    LOGGER.info("Saved split metadata: %s", csv_path)
    return csv_path
def kaggle_credentials_available() -> bool:
    """Return True when Kaggle API credentials appear to be configured.

    Checked in order:

    * ``kaggle.json`` inside ``$KAGGLE_CONFIG_DIR`` when that variable is set
      (the Kaggle API's documented override of the config location);
    * ``~/.kaggle/kaggle.json``;
    * non-empty ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` environment variables.

    Only presence is checked; the credentials themselves are not validated.
    """
    candidates = [Path.home() / ".kaggle" / "kaggle.json"]
    config_dir = os.environ.get("KAGGLE_CONFIG_DIR")
    if config_dir:
        candidates.insert(0, Path(config_dir).expanduser() / "kaggle.json")
    if any(candidate.is_file() for candidate in candidates):
        return True
    # Variables that exist but are empty cannot authenticate, so require values.
    return all(os.environ.get(name) for name in ("KAGGLE_USERNAME", "KAGGLE_KEY"))
def download_kaggle_dataset(config: dict[str, Any]) -> Path:
    """Download and unzip ``config["kaggle"]["dataset"]`` with the Kaggle CLI.

    The target directory ``config["kaggle"]["download_dir"]`` is created via
    ensure_dir before credentials are checked.  A ``kaggle`` executable on
    PATH is preferred; otherwise ``python -m kaggle`` is run with the current
    interpreter.

    Returns:
        The download directory.

    Raises:
        RuntimeError: no Kaggle credentials were found.
        subprocess.CalledProcessError: the CLI exited with a non-zero status.
    """
    dataset = config["kaggle"]["dataset"]
    download_dir = ensure_dir(config["kaggle"]["download_dir"])
    if not kaggle_credentials_available():
        raise RuntimeError(
            "Kaggle credentials were not found. Configure ~/.kaggle/kaggle.json or "
            "KAGGLE_USERNAME/KAGGLE_KEY, then retry."
        )
    executable = shutil.which("kaggle")
    launcher = [executable] if executable else [sys.executable, "-m", "kaggle"]
    cmd = [*launcher, "datasets", "download", "-d", dataset, "-p", str(download_dir), "--unzip"]
    LOGGER.info("Downloading Kaggle dataset %s to %s", dataset, download_dir)
    subprocess.run(cmd, check=True)
    return download_dir
def prepare_data(config: dict[str, Any], data_dir: str | Path | None = None, download: bool = False) -> pd.DataFrame:
    """Discover the dataset, log and persist its splits, and plot class counts.

    When *download* is true or ``config["kaggle"]["enabled"]`` is set, the
    dataset is first fetched from Kaggle and ``config["paths"]["data_dir"]`` is
    updated in place to the download directory.  Plotting is best-effort: any
    exception there is logged as a warning and the DataFrame is still returned.
    """
    if download or config.get("kaggle", {}).get("enabled", False):
        data_dir = download_kaggle_dataset(config)
        config["paths"]["data_dir"] = str(data_dir)

    df = discover_dataset(config, data_dir)
    print_class_distribution(df)
    save_split_metadata(df, config)

    try:
        from .reporting import plot_class_distribution

        plot_target = Path(config["paths"]["output_dir"]) / "plots" / "class_distribution.png"
        plot_class_distribution(df, plot_target)
    except Exception as exc:  # optional artefact; never abort data preparation over it
        LOGGER.warning("Could not save class distribution plot: %s", exc)
    return df
def main() -> None:
    """Command-line entry point: parse arguments, load the config, run prepare_data."""
    parser = argparse.ArgumentParser(description="Discover and split egg damage image dataset.")
    parser.add_argument("--config", default="configs/default.yaml")
    parser.add_argument("--data-dir", default=None)
    parser.add_argument("--download-kaggle", action="store_true")
    args = parser.parse_args()

    config = load_config(args.config)
    if args.data_dir:
        # Store the resolved location in the config that prepare_data passes on
        # (save_split_metadata snapshots this config).
        config["paths"]["data_dir"] = str(Path(args.data_dir).expanduser().resolve())
    prepare_data(config, data_dir=args.data_dir, download=args.download_kaggle)


if __name__ == "__main__":
    main()