File size: 6,284 Bytes
fdafd05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Extract best agentic upsampling images from an output directory."""

from __future__ import annotations

import argparse
import csv
import json
import shutil
from pathlib import Path
from typing import Any

from agentic_upsampling.io_utils import append_jsonl

# Lower-cased file extensions that exported copies may keep; copied_image_name()
# substitutes ".jpg" for any source suffix that is not listed here.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the export script.

    Returns:
        Namespace carrying ``output_dir`` (required ``Path``), ``export_dir``
        (``Path``, or ``None`` when the flag is omitted) and ``overwrite``
        (``bool``, ``False`` unless ``--overwrite`` is passed).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    # Flags are registered in this order so the generated --help is stable.
    option_table = (
        (
            "--output-dir",
            {"type": Path, "required": True, "help": "Agentic upsampler run output directory."},
        ),
        (
            "--export-dir",
            {
                "type": Path,
                "default": None,
                "help": "Directory for copied best images and manifests. Defaults to OUTPUT_DIR/best_generations.",
            },
        ),
        (
            "--overwrite",
            {"action": "store_true", "help": "Replace existing copied images/manifests."},
        ),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def iter_best_jsons(output_dir: Path) -> list[Path]:
    """List every ``best.json`` exactly one directory below *output_dir*, sorted.

    Files whose parent directory is named ``best_generations`` are left out.
    """
    matches = [
        candidate
        for candidate in output_dir.glob("*/best.json")
        if candidate.parent.name != "best_generations"
    ]
    # Sorting makes the result independent of filesystem enumeration order.
    matches.sort()
    return matches


def resolve_image_path(raw_path: str, *, output_dir: Path, best_json_path: Path) -> Path:
    """Locate the image file that a ``best.json`` record points at.

    Candidates are tried in order and the first one that is a regular file wins:

    1. *raw_path* exactly as recorded (relative paths resolve against the
       current working directory);
    2. for relative paths only, ``output_dir / raw_path`` and then
       ``output_dir.parent / raw_path``;
    3. a file with the same basename located next to *best_json_path*.

    Args:
        raw_path: Image path as stored in the record; absolute or relative.
        output_dir: Run output directory that contains the per-prompt folders.
        best_json_path: The ``best.json`` file the record was read from.

    Returns:
        The first candidate path that is an existing regular file.

    Raises:
        FileNotFoundError: If no candidate is an existing file.
    """
    image_path = Path(raw_path)
    candidates = [image_path]
    if not image_path.is_absolute():
        candidates.extend([output_dir / image_path, output_dir.parent / image_path])
    # The sibling lookup also covers absolute paths recorded before the run
    # directory was moved or copied elsewhere.
    candidates.append(best_json_path.parent / image_path.name)
    for candidate in candidates:
        # is_file() rather than exists(): a directory that happens to share the
        # name must not be returned as the image to copy.
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(f"Best image does not exist: {raw_path}")


def copied_image_name(record: dict[str, Any], image_path: Path) -> str:
    """Return the filename for the exported copy: ``<prompt_id><suffix>``.

    The suffix is the source file's extension, lower-cased. An extension that
    is not in ``IMAGE_SUFFIXES`` is replaced with ``.jpg``.
    """
    stem = str(record["prompt_id"])
    extension = image_path.suffix.lower()
    return stem + (extension if extension in IMAGE_SUFFIXES else ".jpg")


def extract_record(best_json_path: Path, *, output_dir: Path, images_dir: Path, overwrite: bool) -> dict[str, Any]:
    """Copy the best image described by one ``best.json`` and build its manifest row.

    Args:
        best_json_path: Per-prompt ``best.json`` file to read.
        output_dir: Run output directory, used when resolving the image path.
        images_dir: Destination directory for the copy; created when missing.
        overwrite: Allow replacing an image that already exists at the destination.

    Returns:
        Manifest record including the ``copied_image_path`` of the new copy.

    Raises:
        ValueError: The JSON is not an object, has no ``best`` object, or the
            best candidate has no ``image_path``.
        FileNotFoundError: The referenced image could not be located.
        KeyError: The JSON object has no ``prompt_id``.
        FileExistsError: The destination exists and *overwrite* is false.
    """
    payload = json.loads(best_json_path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        raise ValueError(f"{best_json_path} must contain a JSON object.")

    winner = payload.get("best")
    if not isinstance(winner, dict):
        raise ValueError(f"{best_json_path} is missing best candidate metadata.")

    recorded_path = str(winner.get("image_path") or "")
    if recorded_path == "":
        raise ValueError(f"{best_json_path} best candidate is missing image_path.")
    source_path = resolve_image_path(recorded_path, output_dir=output_dir, best_json_path=best_json_path)

    # Key order is the field order of every line written to the JSONL manifest.
    record: dict[str, Any] = {}
    record["prompt_id"] = str(payload["prompt_id"])
    record["prompt"] = str(payload.get("prompt") or "")
    record["best_score"] = payload.get("best_score")
    record["best_iteration"] = payload.get("best_iteration")
    # Fall back to "sample_index" when "selected_sample_index" is absent.
    record["selected_sample_index"] = winner.get("selected_sample_index", winner.get("sample_index"))
    record["threshold_cleared_any"] = bool(payload.get("threshold_cleared_any"))
    record["source_image_path"] = str(source_path)
    record["best_json_path"] = str(best_json_path)
    record["analysis_path"] = str(winner.get("analysis_path") or "")

    target_path = images_dir / copied_image_name(record, source_path)
    if target_path.exists() and not overwrite:
        raise FileExistsError(f"Refusing to overwrite existing image: {target_path}")

    images_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source_path, target_path)
    record["copied_image_path"] = str(target_path)
    return record


def write_csv(path: Path, records: list[dict[str, Any]]) -> None:
    """Dump *records* to *path* as a fixed-column CSV with a header row.

    Keys missing from a record become empty cells, and keys that are not one
    of the known columns are dropped.
    """
    columns = (
        "prompt_id",
        "best_score",
        "best_iteration",
        "selected_sample_index",
        "threshold_cleared_any",
        "copied_image_path",
        "source_image_path",
        "best_json_path",
        "analysis_path",
        "prompt",
    )
    with path.open("w", newline="", encoding="utf-8") as handle:
        # restval fills absent keys with ""; extrasaction="ignore" discards
        # any key outside the column list instead of raising.
        table = csv.DictWriter(handle, fieldnames=list(columns), restval="", extrasaction="ignore")
        table.writeheader()
        table.writerows(records)


def extract_best_images(output_dir: Path, export_dir: Path, *, overwrite: bool = False) -> list[dict[str, Any]]:
    """Export every prompt's best image and write the JSONL and CSV manifests.

    Args:
        output_dir: Run directory holding one sub-directory per prompt.
        export_dir: Destination root; images are copied to ``export_dir/images``.
        overwrite: Remove existing manifests and allow replacing copied images.

    Returns:
        The manifest records, in the order the ``best.json`` files were processed.

    Raises:
        FileNotFoundError: *output_dir* does not exist.
        RuntimeError: No per-prompt ``best.json`` was found.
        FileExistsError: A manifest already exists and *overwrite* is false.
    """
    run_dir = output_dir.expanduser()
    target_dir = export_dir.expanduser()
    if not run_dir.exists():
        raise FileNotFoundError(f"Missing output directory: {run_dir}")

    sources = iter_best_jsons(run_dir)
    if not sources:
        raise RuntimeError(f"No per-prompt best.json files found under {run_dir}")

    jsonl_path = target_dir / "best_generations.jsonl"
    summary_path = target_dir / "best_generations.csv"
    manifests = (jsonl_path, summary_path)
    if not overwrite:
        if any(manifest.exists() for manifest in manifests):
            raise FileExistsError(f"Export manifests already exist in {target_dir}; pass --overwrite to replace them.")
    else:
        for manifest in manifests:
            manifest.unlink(missing_ok=True)

    copies_dir = target_dir / "images"
    exported: list[dict[str, Any]] = []
    for source in sources:
        entry = extract_record(source, output_dir=run_dir, images_dir=copies_dir, overwrite=overwrite)
        exported.append(entry)
        # The JSONL manifest grows one record at a time as images are copied.
        append_jsonl(jsonl_path, entry)
    write_csv(summary_path, exported)
    return exported


def main() -> int:
    """Run the export from the command line and print where the results went.

    Returns:
        ``0`` once the export has finished.
    """
    args = parse_args()
    # Without --export-dir the export lands in OUTPUT_DIR/best_generations.
    export_dir = args.export_dir if args.export_dir else args.output_dir / "best_generations"
    records = extract_best_images(args.output_dir, export_dir, overwrite=args.overwrite)
    summary_lines = (
        f"Exported {len(records)} best images to {export_dir}",
        f"Images: {export_dir / 'images'}",
        f"JSONL:  {export_dir / 'best_generations.jsonl'}",
        f"CSV:    {export_dir / 'best_generations.csv'}",
    )
    for line in summary_lines:
        print(line, flush=True)
    return 0


# Script entry point: main()'s return value is passed to SystemExit and so
# becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())