mingyuliutw's picture
Super-squash branch 'main' using huggingface_hub
fdafd05
"""Extract best agentic upsampling images from an output directory."""
from __future__ import annotations
import argparse
import csv
import json
import shutil
from pathlib import Path
from typing import Any
from agentic_upsampling.io_utils import append_jsonl
# Lower-cased file extensions that copied_image_name keeps as-is; any other
# source extension is replaced with ".jpg" in the copied filename.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp"}
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the best-image export.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, matching the
            previous behavior; passing a list makes the parser testable and
            reusable from other scripts.

    Returns:
        Namespace with ``output_dir`` (Path), ``export_dir`` (Path or None)
        and ``overwrite`` (bool).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--output-dir", type=Path, required=True, help="Agentic upsampler run output directory.")
    parser.add_argument(
        "--export-dir",
        type=Path,
        default=None,
        help="Directory for copied best images and manifests. Defaults to OUTPUT_DIR/best_generations.",
    )
    parser.add_argument("--overwrite", action="store_true", help="Replace existing copied images/manifests.")
    return parser.parse_args(argv)
def iter_best_jsons(output_dir: Path) -> list[Path]:
    """Collect the ``best.json`` of every direct prompt subdirectory.

    A subdirectory named ``best_generations`` (the default export location)
    is skipped. The list is sorted so prompts are always processed in the
    same order.
    """
    found: list[Path] = []
    for candidate in output_dir.glob("*/best.json"):
        if candidate.parent.name == "best_generations":
            continue
        found.append(candidate)
    found.sort()
    return found
def resolve_image_path(raw_path: str, *, output_dir: Path, best_json_path: Path) -> Path:
    """Resolve image paths written by runs launched with relative or absolute output dirs.

    Candidates are tried in order, and the first one that is an existing
    *file* wins:

    1. ``raw_path`` as written (absolute, or relative to the current directory);
    2. for relative paths, ``raw_path`` under ``output_dir`` and under its parent;
    3. a file with the same name next to ``best_json_path``. This is tried for
       absolute paths too, so paths recorded before the run directory was
       moved or copied still resolve.

    Raises:
        FileNotFoundError: if no candidate is an existing file.
    """
    image_path = Path(raw_path)
    candidates = [image_path]
    if not image_path.is_absolute():
        candidates.append(output_dir / image_path)
        candidates.append(output_dir.parent / image_path)
    if image_path.name:
        # Sibling-of-best.json fallback, for relative and absolute paths alike.
        candidates.append(best_json_path.parent / image_path.name)
    # Keep the first occurrence of each candidate, preserving order.
    unique_candidates = list(dict.fromkeys(candidates))
    for candidate in unique_candidates:
        # is_file() rather than exists(): a directory (e.g. raw_path ".") is not
        # an image, and shutil.copy2 would fail on it later.
        if candidate.is_file():
            return candidate
    tried = ", ".join(str(candidate) for candidate in unique_candidates)
    raise FileNotFoundError(f"Best image does not exist: {raw_path} (tried: {tried})")
def copied_image_name(record: dict[str, Any], image_path: Path) -> str:
    """Build the flat filename used for a copied best image.

    The name is ``<prompt_id><suffix>``. Path separators in the prompt id are
    replaced with ``_`` so the copy always lands directly inside the images
    directory. Otherwise an id like ``"a/b"`` targets a missing subdirectory,
    and ``"../x"`` escapes the export directory. The source suffix is kept
    (lower-cased) when it is in ``IMAGE_SUFFIXES``; otherwise ``.jpg`` is used.
    """
    prompt_id = str(record["prompt_id"])
    # Replace both separators so ids behave the same on POSIX and Windows.
    safe_id = prompt_id.replace("/", "_").replace("\\", "_")
    suffix = image_path.suffix.lower()
    if suffix not in IMAGE_SUFFIXES:
        suffix = ".jpg"
    return f"{safe_id}{suffix}"
def extract_record(best_json_path: Path, *, output_dir: Path, images_dir: Path, overwrite: bool) -> dict[str, Any]:
    """Copy one best image and return its export manifest record.

    Args:
        best_json_path: Per-prompt ``best.json``. It must hold a JSON object
            with a ``prompt_id`` key and a ``best`` object carrying
            ``image_path``.
        output_dir: Run output directory, used to resolve relative image paths.
        images_dir: Destination directory for the copy (created if missing).
        overwrite: Replace an existing copied image instead of failing.

    Returns:
        Manifest record with prompt/score metadata, the resolved source image
        path and the copied image path.

    Raises:
        ValueError: if ``best.json`` is not valid JSON, is not an object, or
            lacks the best candidate or its ``image_path``.
        KeyError: if ``prompt_id`` is missing.
        FileNotFoundError: if the best image cannot be located.
        FileExistsError: if the destination exists and ``overwrite`` is false.
    """
    try:
        best_data = json.loads(best_json_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as err:
        # A bare JSONDecodeError does not say which of many best.json files is
        # broken. Re-raise with the path; ValueError is JSONDecodeError's base,
        # so existing handlers still catch it.
        raise ValueError(f"{best_json_path} is not valid JSON: {err}") from err
    if not isinstance(best_data, dict):
        raise ValueError(f"{best_json_path} must contain a JSON object.")
    best = best_data.get("best")
    if not isinstance(best, dict):
        raise ValueError(f"{best_json_path} is missing best candidate metadata.")
    raw_image_path = str(best.get("image_path") or "")
    if not raw_image_path:
        raise ValueError(f"{best_json_path} best candidate is missing image_path.")
    image_path = resolve_image_path(raw_image_path, output_dir=output_dir, best_json_path=best_json_path)
    record = {
        "prompt_id": str(best_data["prompt_id"]),
        "prompt": str(best_data.get("prompt") or ""),
        "best_score": best_data.get("best_score"),
        "best_iteration": best_data.get("best_iteration"),
        # Prefer the explicit selected index and fall back to sample_index.
        "selected_sample_index": best.get("selected_sample_index", best.get("sample_index")),
        "threshold_cleared_any": bool(best_data.get("threshold_cleared_any")),
        "source_image_path": str(image_path),
        "best_json_path": str(best_json_path),
        "analysis_path": str(best.get("analysis_path") or ""),
    }
    dest_path = images_dir / copied_image_name(record, image_path)
    if dest_path.exists() and not overwrite:
        raise FileExistsError(f"Refusing to overwrite existing image: {dest_path}")
    images_dir.mkdir(parents=True, exist_ok=True)
    # copy2 also copies file metadata such as the modification time.
    shutil.copy2(image_path, dest_path)
    record["copied_image_path"] = str(dest_path)
    return record
def write_csv(path: Path, records: list[dict[str, Any]]) -> None:
    """Dump manifest records as a CSV table for spreadsheet inspection.

    Columns appear in a fixed order. Keys missing from a record become empty
    cells, and keys outside the column list are dropped.
    """
    columns = (
        "prompt_id",
        "best_score",
        "best_iteration",
        "selected_sample_index",
        "threshold_cleared_any",
        "copied_image_path",
        "source_image_path",
        "best_json_path",
        "analysis_path",
        "prompt",
    )
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({column: row.get(column, "") for column in columns} for row in records)
def extract_best_images(output_dir: Path, export_dir: Path, *, overwrite: bool = False) -> list[dict[str, Any]]:
    """Export every prompt's best image from a run, plus JSONL/CSV manifests.

    Args:
        output_dir: Run output directory containing ``<prompt>/best.json`` files.
        export_dir: Destination. Images go to ``export_dir/images``; manifests
            go to ``best_generations.jsonl`` and ``best_generations.csv``.
        overwrite: Delete existing manifests and replace copied images
            instead of refusing.

    Returns:
        Manifest records, in the order the best.json files were processed.

    Raises:
        FileNotFoundError: if ``output_dir`` does not exist.
        RuntimeError: if no per-prompt best.json files are found.
        FileExistsError: if manifests already exist and ``overwrite`` is false.
    """
    run_dir = output_dir.expanduser()
    target_dir = export_dir.expanduser()
    if not run_dir.exists():
        raise FileNotFoundError(f"Missing output directory: {run_dir}")

    best_jsons = iter_best_jsons(run_dir)
    if not best_jsons:
        raise RuntimeError(f"No per-prompt best.json files found under {run_dir}")

    images_dir = target_dir / "images"
    manifest_path = target_dir / "best_generations.jsonl"
    csv_path = target_dir / "best_generations.csv"

    if not overwrite:
        if manifest_path.exists() or csv_path.exists():
            raise FileExistsError(f"Export manifests already exist in {target_dir}; pass --overwrite to replace them.")
    else:
        # Remove old manifests so a previous export's records are not mixed in.
        for stale_manifest in (manifest_path, csv_path):
            stale_manifest.unlink(missing_ok=True)

    records: list[dict[str, Any]] = []
    for best_json_path in best_jsons:
        record = extract_record(best_json_path, output_dir=run_dir, images_dir=images_dir, overwrite=overwrite)
        records.append(record)
        # The JSONL is written per record (append_jsonl presumably appends one
        # line per call); the CSV is only written after every prompt succeeds.
        append_jsonl(manifest_path, record)
    write_csv(csv_path, records)
    return records
def main() -> int:
    """Command-line entry point: export best images and report where they went."""
    args = parse_args()
    export_dir = args.export_dir if args.export_dir else args.output_dir / "best_generations"
    records = extract_best_images(args.output_dir, export_dir, overwrite=args.overwrite)
    summary_lines = (
        f"Exported {len(records)} best images to {export_dir}",
        f"Images: {export_dir / 'images'}",
        f"JSONL: {export_dir / 'best_generations.jsonl'}",
        f"CSV: {export_dir / 'best_generations.csv'}",
    )
    for line in summary_lines:
        print(line, flush=True)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)