File size: 10,910 Bytes
5aba42e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from __future__ import annotations

import argparse
import os
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Any

import pandas as pd
from sklearn.model_selection import train_test_split

from .config import load_config, save_config_snapshot
from .paths import ensure_dir
from .utils import get_logger


# Module-wide logger from the project's ``get_logger`` helper; every progress
# and warning message emitted by this module goes through it.
LOGGER = get_logger(__name__)

# Canonical class names in label-id order: position 0 -> "Not Damaged", 1 -> "Damaged".
CANONICAL_LABELS = ("Not Damaged", "Damaged")
# Forward and reverse label/id lookups, derived from CANONICAL_LABELS so the
# three definitions cannot drift apart.
LABEL_TO_ID = {label: index for index, label in enumerate(CANONICAL_LABELS)}
ID_TO_LABEL = {index: label for label, index in LABEL_TO_ID.items()}
# Folder-name aliases accepted for each canonical split name.
SPLIT_ALIASES = {
    "train": {"train", "training", "trn"},
    "val": {"val", "valid", "validation", "dev"},
    "test": {"test", "testing", "tst"},
}
# Path components (compared lower-cased) that mark tool/OS artifact folders to skip.
IGNORED_PARTS = {"__macosx", ".ipynb_checkpoints"}


def normalize_name(value: str) -> str:
    """Return *value* lower-cased with every non-alphanumeric character removed.

    This makes loose name comparisons possible: ``"Not_Damaged"``,
    ``"not-damaged"`` and ``"NotDamaged"`` all become ``"notdamaged"``.
    """
    return "".join(filter(str.isalnum, value.lower()))


def detect_label_from_name(name: str) -> str | None:
    """Infer the canonical class label from a folder name or file stem.

    The name is normalized with ``normalize_name`` and searched for marker
    substrings, in this order:

    1. ``"abnormal"``, which is treated as Damaged. It must be checked
       before the Not Damaged markers because it contains ``"normal"``.
    2. Not Damaged markers, including negated damage words such as
       ``"undamaged"``, ``"uncracked"`` and ``"nodefect"``. These come before
       the damage markers because each negated form contains one as a
       substring.
    3. Damage markers (``"damage"``, ``"crack"``, ``"broken"``, ``"defect"``, ...).

    Args:
        name: A single path component or file stem.

    Returns:
        ``"Not Damaged"``, ``"Damaged"``, or ``None`` when no marker matches.
    """
    normalized = normalize_name(name)
    damaged_overrides = ("abnormal",)
    not_damaged_markers = (
        "notdamaged",
        "undamaged",
        "nodamage",
        "nondamaged",
        "uncracked",
        "notcracked",
        "nocrack",
        "crackfree",
        "unbroken",
        "notbroken",
        "nondefective",
        "notdefective",
        "nodefect",
        "defectfree",
        "intact",
        "normal",
        "healthy",
        "good",
        "clean",
        "fresh",
    )
    damaged_markers = ("damaged", "damage", "cracked", "crack", "broken", "defect", "defective")
    if any(marker in normalized for marker in damaged_overrides):
        return "Damaged"
    if any(marker in normalized for marker in not_damaged_markers):
        return "Not Damaged"
    if any(marker in normalized for marker in damaged_markers):
        return "Damaged"
    return None


def detect_split_from_name(name: str) -> str | None:
    """Map a folder name onto a canonical split key from ``SPLIT_ALIASES``.

    Both the name and each alias are compared after ``normalize_name``, so
    ``"Validation"`` and ``"VAL"`` both resolve to ``"val"``.

    Returns:
        The first matching split key, or ``None`` if the name is no alias.
    """
    key = normalize_name(name)
    return next(
        (
            split
            for split, aliases in SPLIT_ALIASES.items()
            if key in {normalize_name(alias) for alias in aliases}
        ),
        None,
    )


def is_hidden_or_system(path: Path) -> bool:
    """Return True if any component of *path* is hidden or an ignored artifact.

    A component counts when it starts with ``"."`` (dot-files and
    dot-directories) or, lower-cased, is listed in ``IGNORED_PARTS``.

    The navigation components ``"."`` and ``".."`` are not hidden entries.
    They are skipped so that relative paths such as ``../data/img.jpg`` are
    not rejected.
    """
    for part in path.parts:
        if part in (".", ".."):
            continue
        if part.startswith(".") or part.lower() in IGNORED_PARTS:
            return True
    return False


def iter_image_files(root: Path, extensions: set[str]) -> list[Path]:
    """Recursively collect image files below *root*, skipping hidden/system entries.

    Args:
        root: Directory to search.
        extensions: Accepted file suffixes, including the dot (e.g. ``".jpg"``).
            Each is compared against ``path.suffix.lower()``, so the entries
            should be lower-case.

    Returns:
        Sorted list of resolved file paths.
    """
    images: list[Path] = []
    for path in root.rglob("*"):
        # Check the suffix first so that non-image entries never trigger a filesystem stat.
        if path.suffix.lower() not in extensions or not path.is_file():
            continue
        # Judge only the components *below* root. A dataset that lives under a
        # hidden ancestor (e.g. ~/.cache/...) must not be dropped wholesale.
        if is_hidden_or_system(path.relative_to(root)):
            continue
        images.append(path.resolve())
    return sorted(images)


def label_for_path(path: Path, root: Path) -> str | None:
    """Return the class label for an image, preferring the closest labelled folder.

    The directory components between *root* and the file are scanned from
    the innermost outward, and the first one that yields a label wins. If
    *path* is not under *root*, all of its directory components are scanned.
    If no folder yields a label, the file stem is inspected instead.
    """
    try:
        folders = path.relative_to(root).parts[:-1]
    except ValueError:
        folders = path.parts[:-1]
    from_folders = (detect_label_from_name(folder) for folder in reversed(folders))
    found = next((candidate for candidate in from_folders if candidate), None)
    return found if found else detect_label_from_name(path.stem)


def split_for_path(path: Path, root: Path) -> str | None:
    """Return the split named by the outermost split-like folder above *path*.

    The directory components between *root* and the file are scanned from
    the outermost inward. If *path* is not under *root*, all of its
    directory components are scanned. Returns ``None`` when no component is
    a recognised split alias.
    """
    try:
        relative = path.relative_to(root)
    except ValueError:
        relative = path
    for directory in relative.parts[:-1]:
        found = detect_split_from_name(directory)
        if found:
            return found
    return None


def build_labeled_dataframe(root: str | Path, config: dict[str, Any]) -> pd.DataFrame:
    """Scan *root* for images and label each one from its folder or file name.

    Args:
        root: Dataset directory. ``~`` is expanded and the path is resolved.
        config: Configuration mapping. ``config["data"]["image_extensions"]``
            lists the accepted suffixes, compared case-insensitively.

    Returns:
        DataFrame with ``filepath``, ``label`` and ``split`` columns. ``split``
        is ``""`` for images that are not under a recognised split folder.
        Images whose name yields no label are skipped.

    Raises:
        FileNotFoundError: *root* does not exist.
        ValueError: no labelled image was found, or one of the canonical
            classes has no images.
    """
    dataset_root = Path(root).expanduser().resolve()
    if not dataset_root.exists():
        raise FileNotFoundError(f"Dataset path does not exist: {dataset_root}")
    allowed = {suffix.lower() for suffix in config["data"]["image_extensions"]}

    records: list[dict[str, str]] = []
    for image in iter_image_files(dataset_root, allowed):
        label = label_for_path(image, dataset_root)
        if label is not None:
            split = split_for_path(image, dataset_root) or ""
            records.append({"filepath": str(image), "label": label, "split": split})

    if not records:
        raise ValueError(
            "No labeled images were detected. Expected folders or filenames resembling "
            "'Damaged', 'Not Damaged', 'cracked', 'undamaged', 'normal', or similar."
        )

    frame = pd.DataFrame(records)
    frame = frame.drop_duplicates(subset=["filepath"])
    frame = frame.reset_index(drop=True)

    found_labels = set(frame["label"])
    absent = set(CANONICAL_LABELS).difference(found_labels)
    if absent:
        raise ValueError(f"Detected labels {sorted(found_labels)}, but missing classes: {sorted(absent)}")
    return frame


def create_stratified_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Assign every row to train/val/test using label-stratified random splits.

    ``config["data"]`` supplies ``train_size``, ``val_size`` and ``test_size``.
    If they do not sum to 1, they are rescaled proportionally. Any existing
    ``split`` column is discarded.

    Splitting happens in two stages, both seeded with ``config["seed"]``:
    first train vs. the remainder, then the remainder into val vs. test.

    Returns:
        A new DataFrame with a ``split`` column, sorted by split, label and
        filepath, with a fresh 0..n-1 index.

    Raises:
        ValueError: the sizes are invalid (``train_size`` not positive, a
            negative size, or ``val_size + test_size`` not positive), or a
            class has fewer than 3 images.
    """
    seed = int(config["seed"])
    train_size = float(config["data"]["train_size"])
    val_size = float(config["data"]["val_size"])
    test_size = float(config["data"]["test_size"])
    # Reject impossible ratios up front instead of failing with a
    # ZeroDivisionError or an opaque sklearn error further down.
    if train_size <= 0 or val_size < 0 or test_size < 0 or val_size + test_size <= 0:
        raise ValueError(
            "Invalid split sizes: train_size must be positive, val_size/test_size "
            f"non-negative with a positive sum; got {train_size}/{val_size}/{test_size}."
        )
    total = train_size + val_size + test_size
    if abs(total - 1.0) > 1e-6:
        train_size, val_size, test_size = train_size / total, val_size / total, test_size / total

    counts = df["label"].value_counts()
    # An empty frame gives a NaN minimum, which would slip past a plain `< 3` check.
    if counts.empty or counts.min() < 3:
        raise ValueError(
            "Each class needs at least 3 images for a "
            f"{train_size:.0%}/{val_size:.0%}/{test_size:.0%} stratified split."
        )

    train_df, temp_df = train_test_split(
        df.drop(columns=["split"], errors="ignore"),
        train_size=train_size,
        stratify=df["label"],
        random_state=seed,
    )
    # Fraction of the held-out remainder that goes to test; the rest becomes val.
    relative_test = test_size / (val_size + test_size)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=relative_test,
        stratify=temp_df["label"],
        random_state=seed,
    )
    parts = [
        train_df.assign(split="train"),
        val_df.assign(split="val"),
        test_df.assign(split="test"),
    ]
    return (
        pd.concat(parts, ignore_index=True)
        .sort_values(["split", "label", "filepath"])
        .reset_index(drop=True)
    )


def complete_or_create_splits(df: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Use on-disk train/val/test folders when present; otherwise create splits.

    Behaviour by case:

    * No row has a split: stratified splits are created for every row.
    * Some rows have a split: rows whose ``split`` is not exactly
      ``"train"``, ``"val"`` or ``"test"`` are dropped, and a warning reports
      how many. If that would leave nothing, stratified splits are created
      for the full input instead.
    * train is present but val is missing: a stratified validation set
      (``config["data"]["val_size"]`` of the train rows) is carved out of
      train. This requires every class to have at least 2 train images;
      otherwise a warning is logged and no val split is made.
    * Any split still missing at the end is reported with a warning.

    Args:
        df: Frame with ``filepath``, ``label`` and ``split`` columns, where
            ``split`` is ``""`` for rows without a detected split.
        config: Configuration mapping. This function reads ``seed`` and
            ``data.val_size`` directly.

    Returns:
        Frame sorted by split, label and filepath, with a fresh 0..n-1 index.
    """
    canonical = ["train", "val", "test"]
    known = df["split"].replace("", pd.NA).dropna()
    if known.empty:
        LOGGER.info("No existing train/val/test split folders detected; creating stratified splits.")
        return create_stratified_splits(df, config).reset_index(drop=True)

    in_canonical = df["split"].isin(canonical)
    if not in_canonical.any():
        # Re-split the full input: filtering to canonical splits would leave an empty frame.
        LOGGER.warning(
            "Split labels %s are not train/val/test; creating stratified splits instead.",
            sorted(map(str, set(known))),
        )
        return create_stratified_splits(df, config).reset_index(drop=True)

    dropped = int((~in_canonical).sum())
    if dropped:
        LOGGER.warning("Ignoring %d image(s) that are not inside a train/val/test folder.", dropped)
    df = df[in_canonical].copy()

    present = set(df["split"].unique())
    if set(canonical).issubset(present):
        LOGGER.info("Existing train/val/test split folders detected.")
        return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True)
    if "train" in present and "val" not in present:
        LOGGER.info("Existing split lacks validation data; carving validation from train only.")
        train_mask = df["split"] == "train"
        train_part = df[train_mask].drop(columns=["split"])
        if train_part["label"].value_counts().min() >= 2:
            new_train, new_val = train_test_split(
                train_part,
                test_size=float(config["data"]["val_size"]),
                stratify=train_part["label"],
                random_state=int(config["seed"]),
            )
            rest = df[~train_mask]
            df = pd.concat(
                [new_train.assign(split="train"), new_val.assign(split="val"), rest],
                ignore_index=True,
            )
        else:
            LOGGER.warning(
                "Cannot carve a stratified validation split: a class has fewer than 2 training images."
            )
    missing = set(canonical) - set(df["split"].unique())
    if missing:
        LOGGER.warning("Missing split(s) %s; evaluation will use the available splits.", sorted(missing))
    return df.sort_values(["split", "label", "filepath"]).reset_index(drop=True)


def add_label_ids(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with an integer ``label_id`` column.

    Ids come from ``LABEL_TO_ID``. The input frame is left unmodified.
    """
    ids = df["label"].map(LABEL_TO_ID).astype(int)
    return df.assign(label_id=ids)


def discover_dataset(config: dict[str, Any], data_dir: str | Path | None = None) -> pd.DataFrame:
    """Locate, label and split the dataset, returning one row per image.

    *data_dir*, when truthy, overrides ``config["paths"]["data_dir"]``.

    Returns:
        DataFrame with columns ``filepath``, ``label``, ``label_id`` and
        ``split``, with a fresh 0..n-1 index.
    """
    source = data_dir if data_dir else config["paths"]["data_dir"]
    root = Path(source).expanduser().resolve()
    labeled = build_labeled_dataframe(root, config)
    split = complete_or_create_splits(labeled, config)
    with_ids = add_label_ids(split)
    columns = ["filepath", "label", "label_id", "split"]
    return with_ids[columns].reset_index(drop=True)


def class_distribution(df: pd.DataFrame) -> pd.DataFrame:
    """Count images per (split, label) pair.

    Returns:
        DataFrame with columns ``split``, ``label`` and ``count``, sorted by
        split and then label.
    """
    sizes = df.groupby(["split", "label"], observed=False).size()
    table = sizes.reset_index(name="count")
    return table.sort_values(by=["split", "label"])


def print_class_distribution(df: pd.DataFrame) -> None:
    """Log per-split class counts and, for two-class splits, the imbalance ratio.

    The imbalance ratio is the majority count divided by the minority count.
    Splits containing only one class get no ratio line.
    """
    table = class_distribution(df)
    LOGGER.info("Class distribution:\n%s", table.to_string(index=False))
    for split_name, rows in df.groupby("split"):
        label_counts = rows["label"].value_counts()
        if len(label_counts) != 2:
            continue
        imbalance = label_counts.max() / max(label_counts.min(), 1)
        LOGGER.info("%s imbalance ratio: %.2f", split_name, imbalance)


def save_split_metadata(df: pd.DataFrame, config: dict[str, Any]) -> Path:
    """Persist the split table and related metadata.

    Writes:

    * the full split table to ``config["paths"]["split_csv"]``, creating its
      parent directory if needed;
    * the per-split class counts to ``class_distribution.csv`` inside
      ``config["paths"]["output_dir"]``.

    It then hands ``config`` and that output directory to
    ``save_config_snapshot``.

    Returns:
        Path of the written split CSV.
    """
    paths_cfg = config["paths"]
    output_dir = ensure_dir(paths_cfg["output_dir"])
    split_csv = Path(paths_cfg["split_csv"])
    split_csv.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(split_csv, index=False)

    counts = class_distribution(df)
    counts.to_csv(output_dir / "class_distribution.csv", index=False)
    save_config_snapshot(config, output_dir)

    LOGGER.info("Saved split metadata: %s", split_csv)
    return split_csv


def kaggle_credentials_available() -> bool:
    """Return True if Kaggle API credentials appear to be configured.

    Any one of the following counts:

    * ``kaggle.json`` in ``$KAGGLE_CONFIG_DIR``, when that variable is set
      (the Kaggle CLI's override for the config location);
    * ``~/.kaggle/kaggle.json``;
    * non-empty ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` environment variables.

    This only checks that credentials are present, not that they are valid.
    """
    candidates = [Path.home() / ".kaggle" / "kaggle.json"]
    config_dir = os.environ.get("KAGGLE_CONFIG_DIR")
    if config_dir:
        candidates.insert(0, Path(config_dir).expanduser() / "kaggle.json")
    if any(candidate.is_file() for candidate in candidates):
        return True
    # Empty values cannot authenticate, so they do not count as credentials.
    return bool(os.environ.get("KAGGLE_USERNAME")) and bool(os.environ.get("KAGGLE_KEY"))


def download_kaggle_dataset(config: dict[str, Any]) -> Path:
    """Download and unzip the configured Kaggle dataset.

    The ``kaggle`` executable on ``PATH`` is used when available; otherwise
    the command runs as ``python -m kaggle`` with the current interpreter.
    The dataset id comes from ``config["kaggle"]["dataset"]`` and the target
    from ``config["kaggle"]["download_dir"]``. The directory is passed
    through ``ensure_dir`` before the credential check runs.

    Returns:
        The download directory, as returned by ``ensure_dir``.

    Raises:
        RuntimeError: no Kaggle credentials were found.
        subprocess.CalledProcessError: the download command exited non-zero.
    """
    kaggle_cfg = config["kaggle"]
    dataset = kaggle_cfg["dataset"]
    download_dir = ensure_dir(kaggle_cfg["download_dir"])
    if not kaggle_credentials_available():
        raise RuntimeError(
            "Kaggle credentials were not found. Configure ~/.kaggle/kaggle.json or "
            "KAGGLE_USERNAME/KAGGLE_KEY, then retry."
        )

    executable = shutil.which("kaggle")
    launcher = [executable] if executable else [sys.executable, "-m", "kaggle"]
    cmd = [*launcher, "datasets", "download", "-d", dataset, "-p", str(download_dir), "--unzip"]

    LOGGER.info("Downloading Kaggle dataset %s to %s", dataset, download_dir)
    subprocess.run(cmd, check=True)
    return download_dir


def prepare_data(config: dict[str, Any], data_dir: str | Path | None = None, download: bool = False) -> pd.DataFrame:
    """Run the data pipeline: optional Kaggle download, discovery and reporting.

    Args:
        config: Configuration mapping. When a download happens,
            ``config["paths"]["data_dir"]`` is overwritten with the download
            directory.
        data_dir: Dataset directory overriding the configured one. It is
            replaced by the download directory when a download happens.
        download: Force a Kaggle download even if
            ``config["kaggle"]["enabled"]`` is false.

    Returns:
        The discovered dataset as returned by ``discover_dataset``.

    The class-distribution plot is best-effort: any failure while importing
    or drawing it is logged as a warning and does not abort the pipeline.
    """
    should_download = download or config.get("kaggle", {}).get("enabled", False)
    if should_download:
        target = download_kaggle_dataset(config)
        data_dir = target
        config["paths"]["data_dir"] = str(target)

    dataset = discover_dataset(config, data_dir)
    print_class_distribution(dataset)
    save_split_metadata(dataset, config)

    try:
        from .reporting import plot_class_distribution

        plot_file = Path(config["paths"]["output_dir"]) / "plots" / "class_distribution.png"
        plot_class_distribution(dataset, plot_file)
    except Exception as exc:
        LOGGER.warning("Could not save class distribution plot: %s", exc)
    return dataset


def main() -> None:
    """Command-line entry point: discover, split and report on the dataset.

    ``--data-dir``, when given, is expanded, resolved and stored in
    ``config["paths"]["data_dir"]`` before the pipeline runs.
    """
    cli = argparse.ArgumentParser(description="Discover and split egg damage image dataset.")
    cli.add_argument("--config", default="configs/default.yaml")
    cli.add_argument("--data-dir", default=None)
    cli.add_argument("--download-kaggle", action="store_true")
    options = cli.parse_args()

    config = load_config(options.config)
    if options.data_dir:
        config["paths"]["data_dir"] = str(Path(options.data_dir).expanduser().resolve())
    prepare_data(config, data_dir=options.data_dir, download=options.download_kaggle)


if __name__ == "__main__":
    main()