# mingyuliutw's picture
# Super-squash branch 'main' using huggingface_hub
# fdafd05
# raw
# history blame
# 19.1 kB
"""Agentic text-to-image prompt upsampling orchestration."""
from __future__ import annotations
import json
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol
from agentic_upsampling.clients import GenerationOutput
from agentic_upsampling.constants import DEFAULT_JPEG_QUALITY, DEFAULT_MAX_ITERATIONS, DEFAULT_SAMPLES_PER_ITERATION
from agentic_upsampling.data import PromptItem, prompt_dir_name
from agentic_upsampling.io_utils import read_json, write_json_atomic
from agentic_upsampling.rubric import candidate_sort_key
class RewriterLike(Protocol):
    """Structural interface for the component that writes and rewrites prompts.

    The runner asks it for a first prompt on iteration 0 and, on every later
    iteration, for a joint rewrite of the previous positive/negative prompt pair.
    """

    def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
        """Create an initial prompt.

        Args:
            item: The prompt item being upsampled.

        Returns:
            The structured (JSON-serializable) positive prompt for iteration 0.
        """
        ...

    def rewrite_prompt_pair(
        self,
        item: PromptItem,
        previous_prompt: dict[str, Any],
        previous_negative_prompt: str,
        previous_analysis: dict[str, Any],
        history: list[dict[str, Any]],
    ) -> tuple[dict[str, Any], str]:
        """Jointly rewrite a positive prompt and negative prompt.

        Args:
            item: The prompt item being upsampled.
            previous_prompt: Positive prompt used by the previous iteration.
            previous_negative_prompt: Negative prompt used by the previous iteration.
            previous_analysis: Judge output for the previous iteration's selected sample.
            history: Candidate records of all iterations completed so far.

        Returns:
            A ``(prompt_json, negative_prompt)`` pair; the runner strips the
            negative prompt before use.
        """
        ...
class GeneratorLike(Protocol):
    """Structural interface for the image generator driven by the runner."""

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = DEFAULT_JPEG_QUALITY,
    ) -> GenerationOutput:
        """Generate one image.

        Args:
            prompt_json: Structured positive prompt for this iteration.
            prompt_id: Identifier of the prompt item being rendered.
            output_dir: Directory the generator should write its outputs into.
            seed: Sampling seed, or ``None`` when the run has no ``seed_base``.
            negative_prompt: Already-stripped negative prompt text.
            jpeg_quality: Encoding quality forwarded from the run configuration.

        Returns:
            A ``GenerationOutput``; the runner reads its ``image_path`` and
            ``meta_path`` attributes.
        """
        ...
class JudgeLike(Protocol):
    """Structural interface for the component that scores generated images."""

    def score_image(
        self,
        *,
        item: PromptItem,
        image_path: Path,
    ) -> dict[str, Any]:
        """Score one image.

        Args:
            item: The prompt item the image was generated for.
            image_path: Location of the generated image on disk.

        Returns:
            An analysis dict. The runner reads ``overall_score`` and
            ``threshold_cleared`` from it and persists it as ``analysis.json``.
        """
        ...
@dataclass(frozen=True, slots=True)
class RunnerConfig:
    """Runtime settings for the agentic loop.

    Instances are immutable; invalid loop sizes are rejected at construction.
    """

    # Root directory; each prompt item gets its own sub-directory below it.
    output_dir: Path
    # Upper bound on rewrite/generate/judge iterations per prompt item.
    max_iterations: int = DEFAULT_MAX_ITERATIONS
    # Images generated (concurrently) per iteration; also the worker-pool size.
    samples_per_iteration: int = DEFAULT_SAMPLES_PER_ITERATION
    # When True, iterations already on disk are recomputed instead of reused.
    overwrite: bool = False
    # Sample i receives seed ``seed_base + i``; None passes ``seed=None`` through.
    seed_base: int | None = None
    # Forwarded unchanged to the generator.
    jpeg_quality: int = DEFAULT_JPEG_QUALITY
    # Negative prompt (stripped) used for the first iteration.
    initial_negative_prompt: str = ""
    # Stop iterating once the judge reports ``threshold_cleared``.
    early_stop: bool = True
    # Print progress messages to stdout.
    verbose: bool = True

    def __post_init__(self) -> None:
        """Reject loop sizes that would make the runner do no work."""
        for field_name in ("max_iterations", "samples_per_iteration"):
            if getattr(self, field_name) < 1:
                raise ValueError(f"{field_name} must be >= 1.")
@dataclass(frozen=True, slots=True)
class IterationPrompt:
"""Positive and negative prompts prepared for one iteration."""
prompt_json: dict[str, Any]
negative_prompt: str
class AgenticUpsamplerRunner:
    """Run the iterative prompt rewrite, generate, and judge loop.

    For each prompt item the runner performs up to ``config.max_iterations``
    iterations. Each iteration does the following:

    1. Prepare a positive/negative prompt pair (initial prompt on the first
       iteration, a joint rewrite afterwards).
    2. Generate ``config.samples_per_iteration`` images concurrently.
    3. Judge every successfully generated image.
    4. Keep the best sample, ranked by ``candidate_sort_key``.

    Artifacts are written under ``config.output_dir / prompt_dir_name(item)``.
    Unless ``config.overwrite`` is set, iterations already completed on disk are
    reused when a run is resumed.
    """

    rewriter: RewriterLike
    generator: GeneratorLike
    judge: JudgeLike
    config: RunnerConfig

    def __init__(
        self,
        *,
        rewriter: RewriterLike,
        generator: GeneratorLike,
        judge: JudgeLike,
        config: RunnerConfig,
    ) -> None:
        """Store the collaborators and runtime configuration.

        Args:
            rewriter: Produces the initial prompt and later prompt-pair rewrites.
            generator: Renders one image per call.
            judge: Scores a rendered image against its prompt item.
            config: Loop sizes, seeding, output location, and logging switches.
        """
        self.rewriter = rewriter
        self.generator = generator
        self.judge = judge
        self.config = config

    def run_item(self, item: PromptItem) -> dict[str, Any]:
        """Run all iterations for one prompt item and persist the best candidate.

        Iterations already on disk are reused unless ``config.overwrite`` is set.
        If an iteration raises after at least one candidate exists, the error is
        written to ``incomplete.json`` and the loop stops. The summary is then
        built from the candidates gathered so far.

        Args:
            item: The prompt item to upsample.

        Returns:
            The summary dict that is also written to ``best.json``.

        Raises:
            Exception: Whatever the first executed iteration raises when no
                candidate exists yet.
        """
        item_dir = self.config.output_dir / prompt_dir_name(item)
        item_dir.mkdir(parents=True, exist_ok=True)
        # Status markers from a previous attempt must not describe this run.
        (item_dir / "failure.json").unlink(missing_ok=True)
        (item_dir / "incomplete.json").unlink(missing_ok=True)
        self._log(f"[prompt {item.prompt_id}] start")
        candidates: list[dict[str, Any]] = []
        previous_prompt: dict[str, Any] | None = None
        previous_analysis: dict[str, Any] | None = None
        previous_negative_prompt = self.config.initial_negative_prompt.strip()
        incomplete_error: dict[str, Any] | None = None
        for iteration in range(self.config.max_iterations):
            iteration_dir = item_dir / f"iter_{iteration:02d}"
            candidate = None if self.config.overwrite else self._load_iteration(iteration_dir, iteration)
            if candidate is None:
                try:
                    candidate = self._run_iteration(
                        item,
                        iteration_dir,
                        iteration,
                        previous_prompt,
                        previous_analysis,
                        previous_negative_prompt,
                        candidates,
                    )
                except Exception as exc:
                    if not candidates:
                        # Nothing to salvage; let the caller record the failure.
                        raise
                    incomplete_error = {
                        "iteration": iteration,
                        "error": repr(exc),
                        "traceback": traceback.format_exc(),
                    }
                    write_json_atomic(item_dir / "incomplete.json", incomplete_error)
                    self._log(f"[prompt {item.prompt_id}] incomplete at iter={iteration}: {exc!r}")
                    break
            candidates.append(candidate)
            # The next rewrite builds on this iteration's prompt pair and judgement.
            previous_prompt = candidate["prompt_json"]
            previous_analysis = candidate["analysis"]
            previous_negative_prompt = str(candidate.get("negative_prompt") or "")
            if self.config.early_stop and bool(candidate["analysis"].get("threshold_cleared")):
                self._log(f"[prompt {item.prompt_id}] early stop at iter={iteration}")
                break
        return self.finalize_item(item, candidates, incomplete_error=incomplete_error)

    def run_item_safely(self, item: PromptItem) -> dict[str, Any]:
        """Run one item and convert failures into structured records.

        On failure, the record is written to ``failure.json`` in the item's
        directory. A compact ``{"prompt_id", "error", "failure_path"}`` dict is
        returned instead of raising.
        """
        try:
            return self.run_item(item)
        except Exception as exc:
            self._log(f"[prompt {item.prompt_id}] failed: {exc!r}")
            failure = {
                "prompt_id": item.prompt_id,
                "prompt": item.prompt,
                "error": repr(exc),
                "traceback": traceback.format_exc(),
            }
            failure_path = self.config.output_dir / prompt_dir_name(item) / "failure.json"
            # run_item may have failed before creating the item directory.
            failure_path.parent.mkdir(parents=True, exist_ok=True)
            write_json_atomic(failure_path, failure)
            return {"prompt_id": item.prompt_id, "error": repr(exc), "failure_path": str(failure_path)}

    def _run_iteration(
        self,
        item: PromptItem,
        iteration_dir: Path,
        iteration: int,
        previous_prompt: dict[str, Any] | None,
        previous_analysis: dict[str, Any] | None,
        previous_negative_prompt: str,
        candidates: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Prepare prompts, generate and judge all samples, and pick the best one."""
        prepared = self.prepare_iteration_prompt(
            item,
            iteration_dir,
            iteration,
            previous_prompt,
            previous_analysis,
            previous_negative_prompt,
            candidates,
        )
        sample_candidates, sample_errors = self._run_iteration_samples(
            item,
            iteration_dir,
            iteration,
            prepared.prompt_json,
            prepared.negative_prompt,
        )
        return self.finalize_iteration(item, iteration_dir, iteration, sample_candidates, sample_errors)

    def _run_iteration_samples(
        self,
        item: PromptItem,
        iteration_dir: Path,
        iteration: int,
        prompt_json: dict[str, Any],
        negative_prompt: str,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        """Generate seed samples concurrently, then judge successful images in sample order.

        Returns:
            ``(sample_candidates, sample_errors)``. Candidates are in sample-index
            order. Errors are sorted by ``sample_index`` so the persisted report is
            deterministic regardless of thread completion order.
        """
        generation_outputs: dict[int, GenerationOutput] = {}
        sample_errors: list[dict[str, Any]] = []
        with ThreadPoolExecutor(max_workers=self.config.samples_per_iteration) as executor:
            future_to_sample_index = {
                executor.submit(
                    self.run_generation_sample,
                    item,
                    iteration_dir,
                    sample_index,
                    prompt_json,
                    negative_prompt,
                ): sample_index
                for sample_index in range(self.config.samples_per_iteration)
            }
            for future in as_completed(future_to_sample_index):
                sample_index = future_to_sample_index[future]
                try:
                    generation_outputs[sample_index] = future.result()
                except Exception as exc:
                    sample_errors.append(self._record_sample_error(item, iteration_dir, iteration, sample_index, exc))
        # Judging runs sequentially on the calling thread, in sample-index order.
        sample_candidates: list[dict[str, Any]] = []
        for sample_index in range(self.config.samples_per_iteration):
            generation = generation_outputs.get(sample_index)
            if generation is None:
                continue
            try:
                sample_candidates.append(
                    self.judge_iteration_sample(
                        item,
                        iteration_dir,
                        iteration,
                        sample_index,
                        prompt_json,
                        negative_prompt,
                        generation,
                    )
                )
            except Exception as exc:
                sample_errors.append(self._record_sample_error(item, iteration_dir, iteration, sample_index, exc))
        # Generation failures arrive in completion order; normalize for stable output.
        sample_errors.sort(key=lambda error: error["sample_index"])
        return sample_candidates, sample_errors

    def _record_sample_error(
        self,
        item: PromptItem,
        iteration_dir: Path,
        iteration: int,
        sample_index: int,
        exc: Exception,
    ) -> dict[str, Any]:
        """Persist one per-sample failure record.

        Must be called from inside an ``except`` block so that
        ``traceback.format_exc()`` captures the active exception.
        """
        error = {"sample_index": sample_index, "error": repr(exc), "traceback": traceback.format_exc()}
        write_json_atomic(self._sample_dir(iteration_dir, sample_index) / "failure.json", error)
        self._log(f"[prompt {item.prompt_id}] iter={iteration} sample={sample_index} failed: {exc!r}")
        return error

    def prepare_iteration_prompt(
        self,
        item: PromptItem,
        iteration_dir: Path,
        iteration: int,
        previous_prompt: dict[str, Any] | None,
        previous_analysis: dict[str, Any] | None,
        previous_negative_prompt: str,
        candidates: list[dict[str, Any]],
    ) -> IterationPrompt:
        """Prepare and persist the positive/negative prompt pair for one iteration.

        The first iteration, or any iteration without a previous prompt/analysis,
        uses the rewriter's initial prompt and the configured initial negative
        prompt. Later iterations ask the rewriter for a joint rewrite, passing
        all prior candidates as history. The negative prompt is always stripped.
        """
        iteration_dir.mkdir(parents=True, exist_ok=True)
        self._log(f"[prompt {item.prompt_id}] iter={iteration} start")
        if iteration == 0 or previous_prompt is None or previous_analysis is None:
            prompt_json = self.rewriter.initial_prompt(item)
            negative_prompt = self.config.initial_negative_prompt.strip()
        else:
            prompt_json, negative_prompt = self.rewriter.rewrite_prompt_pair(
                item,
                previous_prompt,
                previous_negative_prompt,
                previous_analysis,
                candidates,
            )
            negative_prompt = negative_prompt.strip()
        write_json_atomic(iteration_dir / "prompt.json", prompt_json)
        write_json_atomic(iteration_dir / "negative_prompt.json", {"negative_prompt": negative_prompt})
        return IterationPrompt(prompt_json=prompt_json, negative_prompt=negative_prompt)

    def _run_iteration_sample(
        self,
        item: PromptItem,
        iteration_dir: Path,
        iteration: int,
        sample_index: int,
        prompt_json: dict[str, Any],
        negative_prompt: str,
    ) -> dict[str, Any]:
        """Generate and judge a single sample sequentially.

        NOTE(review): not called within this class. ``_run_iteration_samples``
        calls ``run_generation_sample`` and ``judge_iteration_sample`` directly.
        It is kept for external callers or subclasses that may rely on it.
        """
        generation = self.run_generation_sample(item, iteration_dir, sample_index, prompt_json, negative_prompt)
        return self.judge_iteration_sample(
            item,
            iteration_dir,
            iteration,
            sample_index,
            prompt_json,
            negative_prompt,
            generation,
        )

    def run_generation_sample(
        self,
        item: PromptItem,
        iteration_dir: Path,
        sample_index: int,
        prompt_json: dict[str, Any],
        negative_prompt: str,
    ) -> GenerationOutput:
        """Generate one sample image for an iteration.

        Runs on a worker thread when called from ``_run_iteration_samples``.
        """
        sample_dir = self._sample_dir(iteration_dir, sample_index)
        sample_dir.mkdir(parents=True, exist_ok=True)
        self._log(f"[prompt {item.prompt_id}] sample={sample_index} generate")
        return self.generator.generate(
            prompt_json=prompt_json,
            prompt_id=item.prompt_id,
            output_dir=sample_dir,
            seed=self._sample_seed(sample_index),
            negative_prompt=negative_prompt,
            jpeg_quality=self.config.jpeg_quality,
        )

    def judge_iteration_sample(
        self,
        item: PromptItem,
        iteration_dir: Path,
        iteration: int,
        sample_index: int,
        prompt_json: dict[str, Any],
        negative_prompt: str,
        generation: GenerationOutput,
    ) -> dict[str, Any]:
        """Judge one generated sample and persist its candidate metadata.

        Writes ``analysis.json`` and ``meta.json`` into the sample directory and
        returns the candidate record.
        """
        sample_dir = self._sample_dir(iteration_dir, sample_index)
        analysis = self.judge.score_image(item=item, image_path=generation.image_path)
        self._log(f"[prompt {item.prompt_id}] iter={iteration} sample={sample_index} score={analysis.get('overall_score')}")
        analysis_path = sample_dir / "analysis.json"
        write_json_atomic(analysis_path, analysis)
        candidate = {
            "prompt_id": item.prompt_id,
            "iteration": iteration,
            "sample_index": sample_index,
            "prompt_path": str(iteration_dir / "prompt.json"),
            "image_path": str(generation.image_path),
            "analysis_path": str(analysis_path),
            "generation_meta_path": str(generation.meta_path),
            "negative_prompt_path": str(iteration_dir / "negative_prompt.json"),
            "negative_prompt": negative_prompt,
            "prompt_json": prompt_json,
            "analysis": analysis,
        }
        write_json_atomic(sample_dir / "meta.json", candidate)
        return candidate

    def finalize_iteration(
        self,
        item: PromptItem,
        iteration_dir: Path,
        iteration: int,
        sample_candidates: list[dict[str, Any]],
        sample_errors: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Select and persist the best sample candidate for one iteration.

        ``sample_failures.json`` is refreshed first: it is written when any sample
        failed and removed otherwise. The on-disk report therefore always
        matches this attempt, including when every sample failed. The
        iteration-level ``meta.json`` is written last; ``_load_iteration`` only
        reuses an iteration whose ``meta.json`` exists.

        Raises:
            RuntimeError: If no sample produced a judged candidate.
        """
        failures_path = iteration_dir / "sample_failures.json"
        if sample_errors:
            write_json_atomic(failures_path, sample_errors)
        else:
            # A report left by an earlier attempt no longer applies.
            failures_path.unlink(missing_ok=True)
        if not sample_candidates:
            first_error = sample_errors[0]["error"] if sample_errors else "no sample error recorded"
            raise RuntimeError(
                f"All {self.config.samples_per_iteration} samples failed for iteration {iteration}. "
                f"First error: {first_error}"
            )
        write_json_atomic(iteration_dir / "samples.json", sample_candidates)
        # Shallow copy so the per-sample record inside "samples" stays unmodified.
        candidate = dict(max(sample_candidates, key=candidate_sort_key))
        candidate["samples"] = sample_candidates
        candidate["sample_count"] = len(sample_candidates)
        candidate["selected_sample_index"] = candidate["sample_index"]
        if sample_errors:
            candidate["sample_errors"] = sample_errors
        write_json_atomic(iteration_dir / "meta.json", candidate)
        self._log(
            f"[prompt {item.prompt_id}] iter={iteration} best_sample={candidate['selected_sample_index']} "
            f"score={candidate['analysis'].get('overall_score')} samples={len(sample_candidates)}"
        )
        return candidate

    def finalize_item(
        self,
        item: PromptItem,
        candidates: list[dict[str, Any]],
        *,
        incomplete_error: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Persist and return the best candidate summary for a completed or incomplete item.

        Raises:
            RuntimeError: If ``candidates`` is empty.
        """
        if not candidates:
            raise RuntimeError(f"No candidates produced for prompt {item.prompt_id}.")
        item_dir = self.config.output_dir / prompt_dir_name(item)
        best = max(candidates, key=candidate_sort_key)
        summary = {
            "prompt_id": item.prompt_id,
            "prompt": item.prompt,
            "best_iteration": best["iteration"],
            "best_score": best["analysis"].get("overall_score"),
            "threshold_cleared_any": any(bool(candidate["analysis"].get("threshold_cleared")) for candidate in candidates),
            "best": best,
            "iterations": candidates,
        }
        if incomplete_error is not None:
            summary["incomplete_error"] = incomplete_error
        write_json_atomic(item_dir / "best.json", summary)
        self._log(f"[prompt {item.prompt_id}] done best_iter={summary['best_iteration']} best_score={summary['best_score']}")
        return summary

    def _log(self, message: str) -> None:
        """Print ``message`` to stdout when ``config.verbose`` is enabled."""
        if self.config.verbose:
            print(message, flush=True)

    def _sample_seed(self, sample_index: int) -> int | None:
        """Return ``seed_base + sample_index``, or None when no seed base is configured.

        The seed depends only on the sample index, so every iteration reuses the
        same seeds (presumably to compare prompts under identical sampling noise).
        """
        if self.config.seed_base is None:
            return None
        return self.config.seed_base + sample_index

    def _sample_dir(self, iteration_dir: Path, sample_index: int) -> Path:
        """Return the directory for one sample.

        With a single sample per iteration this is the iteration directory itself.
        """
        if self.config.samples_per_iteration == 1:
            return iteration_dir
        return iteration_dir / f"sample_{sample_index:02d}"

    @staticmethod
    def _load_iteration(iteration_dir: Path, iteration: int) -> dict[str, Any] | None:
        """Reload a previously completed iteration from disk.

        Returns None (so the iteration is recomputed) unless ``meta.json`` and
        ``prompt.json`` exist and the analysis and image files they reference are
        present. The loaded record is refreshed with the prompt and analysis read
        from disk, plus the negative prompt and sample list when available.
        """
        meta_path = iteration_dir / "meta.json"
        prompt_path = iteration_dir / "prompt.json"
        if not (meta_path.exists() and prompt_path.exists()):
            return None
        meta = read_json(meta_path)
        analysis_path = Path(str(meta.get("analysis_path") or iteration_dir / "analysis.json"))
        image_path = Path(str(meta.get("image_path") or iteration_dir / "image.jpg"))
        if not (analysis_path.exists() and image_path.exists()):
            return None
        meta["iteration"] = iteration
        meta["prompt_json"] = read_json(prompt_path)
        meta["analysis"] = read_json(analysis_path)
        negative_prompt_path = iteration_dir / "negative_prompt.json"
        if "negative_prompt" not in meta and negative_prompt_path.exists():
            negative_prompt_data = read_json(negative_prompt_path)
            meta["negative_prompt"] = str(negative_prompt_data.get("negative_prompt") or "")
            meta["negative_prompt_path"] = str(negative_prompt_path)
        meta.setdefault("negative_prompt", "")
        samples_path = iteration_dir / "samples.json"
        if samples_path.exists():
            samples = json.loads(samples_path.read_text(encoding="utf-8"))
            if isinstance(samples, list):
                meta["samples"] = samples
                meta["sample_count"] = len(samples)
        return meta
def write_run_manifest(output_dir: Path, results: list[dict[str, Any]]) -> None:
    """Write compact run-level manifest files.

    Each result is serialized as one compact ASCII JSON line. Results with a
    truthy ``"error"`` key go to ``failures.jsonl``; all others go to
    ``manifest.jsonl``. Both files are rebuilt from scratch on every call. A
    file that would receive no lines is removed, so its absence means "none".

    Args:
        output_dir: Run-level output directory; created if missing.
        results: Per-item records (item summaries or failure records).
    """
    manifest_lines: list[str] = []
    failure_lines: list[str] = []
    # Serialize everything before touching the filesystem so an unserializable
    # record raises without destroying the previous manifests.
    for result in results:
        line = json.dumps(result, ensure_ascii=True, separators=(",", ":")) + "\n"
        if result.get("error"):
            failure_lines.append(line)
        else:
            manifest_lines.append(line)
    output_dir.mkdir(parents=True, exist_ok=True)
    for path, lines in (
        (output_dir / "manifest.jsonl", manifest_lines),
        (output_dir / "failures.jsonl", failure_lines),
    ):
        if not lines:
            path.unlink(missing_ok=True)
            continue
        # Write once to a sibling temp file, then rename, so readers never see
        # a partially written manifest.
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text("".join(lines), encoding="utf-8")
        tmp_path.replace(path)