Spaces:
Sleeping
Sleeping
| """src/training/manifests.py — Build and validate manifest CSV files.""" | |
| from __future__ import annotations | |
| import csv | |
| import random | |
| from pathlib import Path | |
| from typing import List, Tuple | |
# Lower-cased file suffixes (including the leading dot) that are treated as images.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".webp"}


def _collect_records(image_dirs: List[Tuple[Path, int, int]]) -> List[dict]:
    """Return one manifest row per image file found recursively under each directory."""
    records = []
    for img_dir, label, generator in image_dirs:
        # Sorting fixes the pre-shuffle order, so a given seed always yields the same split.
        for p in sorted(Path(img_dir).rglob("*")):
            # rglob also yields directories; skip any whose name merely looks like an image.
            if p.suffix.lower() in IMAGE_EXTS and p.is_file():
                records.append({
                    "filepath": str(p),
                    "label": label,
                    "generator": generator,
                })
    return records


def _split_records(records: List[dict], train_ratio: float, val_ratio: float) -> dict:
    """Cut ``records`` into contiguous train/val/test slices; test takes the remainder."""
    n = len(records)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    return {
        "train": records[:n_train],
        "val": records[n_train:n_train + n_val],
        "test": records[n_train + n_val:],
    }


def _write_split_csv(out: Path, rows: List[dict]) -> None:
    """Write ``rows`` to ``out`` as CSV with a filepath/label/generator header."""
    with open(out, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["filepath", "label", "generator"])
        writer.writeheader()
        writer.writerows(rows)


def build_manifest(
    image_dirs: List[Tuple[Path, int, int]],  # (dir, label, generator_idx)
    output_path: Path,
    seed: int = 42,
    train_ratio: float = 0.80,
    val_ratio: float = 0.10,
) -> dict:
    """Walk image directories and write seeded train/val/test manifest CSVs.

    Every regular file under each directory (searched recursively) whose
    lower-cased suffix is in ``IMAGE_EXTS`` becomes one row
    ``(filepath, label, generator)``. All rows are shuffled together with
    ``random.Random(seed)`` and then sliced by ratio. The test split receives
    whatever remains after train and val, so it absorbs the rounding-down of
    ``int(n * ratio)``.

    Args:
        image_dirs: ``(directory, label, generator_idx)`` triples. The label and
            generator index are copied verbatim into every row from that directory.
        output_path: Template path. Files are written beside it as
            ``<stem>_train.csv``, ``<stem>_val.csv`` and ``<stem>_test.csv``;
            existing files are overwritten. Its parent directory is created if
            needed. A ``str`` is also accepted.
        seed: Shuffle seed; identical inputs and seed give an identical split.
        train_ratio: Fraction of rows assigned to the train split.
        val_ratio: Fraction of rows assigned to the val split.

    Returns:
        Mapping of ``"train"``, ``"val"`` and ``"test"`` to the written CSV paths.

    Raises:
        ValueError: if either ratio is negative or NaN, or if they sum to more than 1.
    """
    # Reject impossible splits up front. An over-full train+val would otherwise
    # silently truncate val and leave test empty. The check is written positively
    # so that NaN fails it too; the small tolerance absorbs float rounding in the sum.
    ratios_ok = (
        train_ratio >= 0
        and val_ratio >= 0
        and train_ratio + val_ratio <= 1.0 + 1e-9
    )
    if not ratios_ok:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1; "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    output_path = Path(output_path)
    rng = random.Random(seed)
    records = _collect_records(image_dirs)
    rng.shuffle(records)
    splits = _split_records(records, train_ratio, val_ratio)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    manifest_paths = {}
    for split, rows in splits.items():
        out = output_path.parent / f"{output_path.stem}_{split}.csv"
        _write_split_csv(out, rows)
        manifest_paths[split] = out
    return manifest_paths
def validate_manifest(manifest_path: Path) -> dict:
    """Count the rows of a manifest CSV and report entries whose file is absent.

    Args:
        manifest_path: CSV with at least ``filepath`` and ``label`` columns,
            such as a split written by ``build_manifest``.

    Returns:
        ``{"counts": ..., "missing": ..., "ok": ...}``, where:

        - ``counts`` holds ``total``, ``real`` (rows with label 0) and ``fake``
          (rows with any other label);
        - ``missing`` lists, in file order, each ``filepath`` that is not an
          existing regular file;
        - ``ok`` is True iff ``missing`` is empty.

    Raises:
        KeyError: if the header lacks a ``filepath`` or ``label`` column
            (raised when the first data row is read).
        ValueError: if a ``label`` value is not an integer string.
    """
    missing = []
    counts = {"total": 0, "real": 0, "fake": 0}
    # newline="" is what the csv module requires so that quoted fields are parsed correctly.
    with open(manifest_path, newline="") as f:
        for row in csv.DictReader(f):
            counts["total"] += 1
            counts["real" if int(row["label"]) == 0 else "fake"] += 1
            filepath = row["filepath"]
            # is_file(), not exists(): a directory at this path is not a usable image.
            if not Path(filepath).is_file():
                missing.append(filepath)
    return {"counts": counts, "missing": missing, "ok": not missing}