"""Agentic text-to-image prompt upsampling orchestration.""" from __future__ import annotations import json import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path from typing import Any, Protocol from agentic_upsampling.clients import GenerationOutput from agentic_upsampling.constants import DEFAULT_JPEG_QUALITY, DEFAULT_MAX_ITERATIONS, DEFAULT_SAMPLES_PER_ITERATION from agentic_upsampling.data import PromptItem, prompt_dir_name from agentic_upsampling.io_utils import read_json, write_json_atomic from agentic_upsampling.rubric import candidate_sort_key class RewriterLike(Protocol): def initial_prompt(self, item: PromptItem) -> dict[str, Any]: """Create an initial prompt.""" def rewrite_prompt_pair( self, item: PromptItem, previous_prompt: dict[str, Any], previous_negative_prompt: str, previous_analysis: dict[str, Any], history: list[dict[str, Any]], ) -> tuple[dict[str, Any], str]: """Jointly rewrite a positive prompt and negative prompt.""" class GeneratorLike(Protocol): def generate( self, *, prompt_json: dict[str, Any], prompt_id: str, output_dir: Path, seed: int | None = None, negative_prompt: str = "", jpeg_quality: int = DEFAULT_JPEG_QUALITY, ) -> GenerationOutput: """Generate one image.""" class JudgeLike(Protocol): def score_image( self, *, item: PromptItem, image_path: Path, ) -> dict[str, Any]: """Score one image.""" @dataclass(frozen=True, slots=True) class RunnerConfig: """Runtime settings for the agentic loop.""" output_dir: Path max_iterations: int = DEFAULT_MAX_ITERATIONS samples_per_iteration: int = DEFAULT_SAMPLES_PER_ITERATION overwrite: bool = False seed_base: int | None = None jpeg_quality: int = DEFAULT_JPEG_QUALITY initial_negative_prompt: str = "" early_stop: bool = True verbose: bool = True def __post_init__(self) -> None: if self.max_iterations < 1: raise ValueError("max_iterations must be >= 1.") if self.samples_per_iteration < 1: raise ValueError("samples_per_iteration must be >= 1.") @dataclass(frozen=True, slots=True) class IterationPrompt: """Positive and negative prompts prepared for one iteration.""" prompt_json: dict[str, Any] negative_prompt: str class AgenticUpsamplerRunner: """Run the iterative prompt rewrite, generate, and judge loop.""" rewriter: RewriterLike generator: GeneratorLike judge: JudgeLike config: RunnerConfig def __init__( self, *, rewriter: RewriterLike, generator: GeneratorLike, judge: JudgeLike, config: RunnerConfig, ) -> None: self.rewriter = rewriter self.generator = generator self.judge = judge self.config = config def run_item(self, item: PromptItem) -> dict[str, Any]: """Run all iterations for one prompt item and persist the best candidate.""" item_dir = self.config.output_dir / prompt_dir_name(item) item_dir.mkdir(parents=True, exist_ok=True) (item_dir / "failure.json").unlink(missing_ok=True) (item_dir / "incomplete.json").unlink(missing_ok=True) self._log(f"[prompt {item.prompt_id}] start") candidates: list[dict[str, Any]] = [] previous_prompt: dict[str, Any] | None = None previous_analysis: dict[str, Any] | None = None previous_negative_prompt = self.config.initial_negative_prompt.strip() incomplete_error: dict[str, Any] | None = None for iteration in range(self.config.max_iterations): iteration_dir = item_dir / f"iter_{iteration:02d}" candidate = None if self.config.overwrite else self._load_iteration(iteration_dir, iteration) if candidate is None: try: candidate = self._run_iteration( item, iteration_dir, iteration, previous_prompt, previous_analysis, previous_negative_prompt, candidates, ) except Exception as exc: if not candidates: raise incomplete_error = { "iteration": iteration, "error": repr(exc), "traceback": traceback.format_exc(), } write_json_atomic(item_dir / "incomplete.json", incomplete_error) self._log(f"[prompt {item.prompt_id}] incomplete at iter={iteration}: {exc!r}") break candidates.append(candidate) previous_prompt = candidate["prompt_json"] previous_analysis = candidate["analysis"] previous_negative_prompt = str(candidate.get("negative_prompt") or "") if self.config.early_stop and bool(candidate["analysis"].get("threshold_cleared")): self._log(f"[prompt {item.prompt_id}] early stop at iter={iteration}") break return self.finalize_item(item, candidates, incomplete_error=incomplete_error) def run_item_safely(self, item: PromptItem) -> dict[str, Any]: """Run one item and convert failures into structured records.""" try: return self.run_item(item) except Exception as exc: self._log(f"[prompt {item.prompt_id}] failed: {exc!r}") failure = { "prompt_id": item.prompt_id, "prompt": item.prompt, "error": repr(exc), "traceback": traceback.format_exc(), } failure_path = self.config.output_dir / prompt_dir_name(item) / "failure.json" write_json_atomic(failure_path, failure) return {"prompt_id": item.prompt_id, "error": repr(exc), "failure_path": str(failure_path)} def _run_iteration( self, item: PromptItem, iteration_dir: Path, iteration: int, previous_prompt: dict[str, Any] | None, previous_analysis: dict[str, Any] | None, previous_negative_prompt: str, candidates: list[dict[str, Any]], ) -> dict[str, Any]: prepared = self.prepare_iteration_prompt( item, iteration_dir, iteration, previous_prompt, previous_analysis, previous_negative_prompt, candidates, ) sample_candidates, sample_errors = self._run_iteration_samples( item, iteration_dir, iteration, prepared.prompt_json, prepared.negative_prompt, ) return self.finalize_iteration(item, iteration_dir, iteration, sample_candidates, sample_errors) def _run_iteration_samples( self, item: PromptItem, iteration_dir: Path, iteration: int, prompt_json: dict[str, Any], negative_prompt: str, ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """Generate seed samples concurrently, then judge successful images in sample order.""" generation_outputs: dict[int, GenerationOutput] = {} sample_errors: list[dict[str, Any]] = [] with ThreadPoolExecutor(max_workers=self.config.samples_per_iteration) as executor: future_to_sample_index = { executor.submit( self.run_generation_sample, item, iteration_dir, sample_index, prompt_json, negative_prompt, ): sample_index for sample_index in range(self.config.samples_per_iteration) } for future in as_completed(future_to_sample_index): sample_index = future_to_sample_index[future] try: generation_outputs[sample_index] = future.result() except Exception as exc: sample_errors.append(self._record_sample_error(item, iteration_dir, iteration, sample_index, exc)) sample_candidates: list[dict[str, Any]] = [] for sample_index in range(self.config.samples_per_iteration): generation = generation_outputs.get(sample_index) if generation is None: continue try: sample_candidates.append( self.judge_iteration_sample( item, iteration_dir, iteration, sample_index, prompt_json, negative_prompt, generation, ) ) except Exception as exc: sample_errors.append(self._record_sample_error(item, iteration_dir, iteration, sample_index, exc)) return sample_candidates, sample_errors def _record_sample_error( self, item: PromptItem, iteration_dir: Path, iteration: int, sample_index: int, exc: Exception, ) -> dict[str, Any]: """Persist one per-sample failure record.""" error = {"sample_index": sample_index, "error": repr(exc), "traceback": traceback.format_exc()} write_json_atomic(self._sample_dir(iteration_dir, sample_index) / "failure.json", error) self._log(f"[prompt {item.prompt_id}] iter={iteration} sample={sample_index} failed: {exc!r}") return error def prepare_iteration_prompt( self, item: PromptItem, iteration_dir: Path, iteration: int, previous_prompt: dict[str, Any] | None, previous_analysis: dict[str, Any] | None, previous_negative_prompt: str, candidates: list[dict[str, Any]], ) -> IterationPrompt: """Prepare and persist the positive/negative prompt pair for one iteration.""" iteration_dir.mkdir(parents=True, exist_ok=True) self._log(f"[prompt {item.prompt_id}] iter={iteration} start") if iteration == 0 or previous_prompt is None or previous_analysis is None: prompt_json = self.rewriter.initial_prompt(item) negative_prompt = self.config.initial_negative_prompt.strip() else: prompt_json, negative_prompt = self.rewriter.rewrite_prompt_pair( item, previous_prompt, previous_negative_prompt, previous_analysis, candidates, ) negative_prompt = negative_prompt.strip() write_json_atomic(iteration_dir / "prompt.json", prompt_json) write_json_atomic(iteration_dir / "negative_prompt.json", {"negative_prompt": negative_prompt}) return IterationPrompt(prompt_json=prompt_json, negative_prompt=negative_prompt) def _run_iteration_sample( self, item: PromptItem, iteration_dir: Path, iteration: int, sample_index: int, prompt_json: dict[str, Any], negative_prompt: str, ) -> dict[str, Any]: generation = self.run_generation_sample(item, iteration_dir, sample_index, prompt_json, negative_prompt) return self.judge_iteration_sample( item, iteration_dir, iteration, sample_index, prompt_json, negative_prompt, generation, ) def run_generation_sample( self, item: PromptItem, iteration_dir: Path, sample_index: int, prompt_json: dict[str, Any], negative_prompt: str, ) -> GenerationOutput: """Generate one sample image for an iteration.""" sample_dir = self._sample_dir(iteration_dir, sample_index) sample_dir.mkdir(parents=True, exist_ok=True) self._log(f"[prompt {item.prompt_id}] sample={sample_index} generate") return self.generator.generate( prompt_json=prompt_json, prompt_id=item.prompt_id, output_dir=sample_dir, seed=self._sample_seed(sample_index), negative_prompt=negative_prompt, jpeg_quality=self.config.jpeg_quality, ) def judge_iteration_sample( self, item: PromptItem, iteration_dir: Path, iteration: int, sample_index: int, prompt_json: dict[str, Any], negative_prompt: str, generation: GenerationOutput, ) -> dict[str, Any]: """Judge one generated sample and persist its candidate metadata.""" sample_dir = self._sample_dir(iteration_dir, sample_index) analysis = self.judge.score_image(item=item, image_path=generation.image_path) self._log(f"[prompt {item.prompt_id}] iter={iteration} sample={sample_index} score={analysis.get('overall_score')}") analysis_path = sample_dir / "analysis.json" write_json_atomic(analysis_path, analysis) candidate = { "prompt_id": item.prompt_id, "iteration": iteration, "sample_index": sample_index, "prompt_path": str(iteration_dir / "prompt.json"), "image_path": str(generation.image_path), "analysis_path": str(analysis_path), "generation_meta_path": str(generation.meta_path), "negative_prompt_path": str(iteration_dir / "negative_prompt.json"), "negative_prompt": negative_prompt, "prompt_json": prompt_json, "analysis": analysis, } write_json_atomic(sample_dir / "meta.json", candidate) return candidate def finalize_iteration( self, item: PromptItem, iteration_dir: Path, iteration: int, sample_candidates: list[dict[str, Any]], sample_errors: list[dict[str, Any]], ) -> dict[str, Any]: """Select and persist the best sample candidate for one iteration.""" if not sample_candidates: raise RuntimeError(f"All {self.config.samples_per_iteration} samples failed for iteration {iteration}.") write_json_atomic(iteration_dir / "samples.json", sample_candidates) candidate = dict(max(sample_candidates, key=candidate_sort_key)) candidate["samples"] = sample_candidates candidate["sample_count"] = len(sample_candidates) candidate["selected_sample_index"] = candidate["sample_index"] if sample_errors: candidate["sample_errors"] = sample_errors write_json_atomic(iteration_dir / "sample_failures.json", sample_errors) write_json_atomic(iteration_dir / "meta.json", candidate) self._log( f"[prompt {item.prompt_id}] iter={iteration} best_sample={candidate['selected_sample_index']} " f"score={candidate['analysis'].get('overall_score')} samples={len(sample_candidates)}" ) return candidate def finalize_item( self, item: PromptItem, candidates: list[dict[str, Any]], *, incomplete_error: dict[str, Any] | None = None, ) -> dict[str, Any]: """Persist and return the best candidate summary for a completed or incomplete item.""" if not candidates: raise RuntimeError(f"No candidates produced for prompt {item.prompt_id}.") item_dir = self.config.output_dir / prompt_dir_name(item) best = max(candidates, key=candidate_sort_key) summary = { "prompt_id": item.prompt_id, "prompt": item.prompt, "best_iteration": best["iteration"], "best_score": best["analysis"].get("overall_score"), "threshold_cleared_any": any(bool(candidate["analysis"].get("threshold_cleared")) for candidate in candidates), "best": best, "iterations": candidates, } if incomplete_error is not None: summary["incomplete_error"] = incomplete_error write_json_atomic(item_dir / "best.json", summary) self._log(f"[prompt {item.prompt_id}] done best_iter={summary['best_iteration']} best_score={summary['best_score']}") return summary def _log(self, message: str) -> None: if self.config.verbose: print(message, flush=True) def _sample_seed(self, sample_index: int) -> int | None: if self.config.seed_base is None: return None return self.config.seed_base + sample_index def _sample_dir(self, iteration_dir: Path, sample_index: int) -> Path: if self.config.samples_per_iteration == 1: return iteration_dir return iteration_dir / f"sample_{sample_index:02d}" @staticmethod def _load_iteration(iteration_dir: Path, iteration: int) -> dict[str, Any] | None: meta_path = iteration_dir / "meta.json" prompt_path = iteration_dir / "prompt.json" if not (meta_path.exists() and prompt_path.exists()): return None meta = read_json(meta_path) analysis_path = Path(str(meta.get("analysis_path") or iteration_dir / "analysis.json")) image_path = Path(str(meta.get("image_path") or iteration_dir / "image.jpg")) if not (analysis_path.exists() and image_path.exists()): return None meta["iteration"] = iteration meta["prompt_json"] = read_json(prompt_path) meta["analysis"] = read_json(analysis_path) negative_prompt_path = iteration_dir / "negative_prompt.json" if "negative_prompt" not in meta and negative_prompt_path.exists(): negative_prompt_data = read_json(negative_prompt_path) meta["negative_prompt"] = str(negative_prompt_data.get("negative_prompt") or "") meta["negative_prompt_path"] = str(negative_prompt_path) meta.setdefault("negative_prompt", "") samples_path = iteration_dir / "samples.json" if samples_path.exists(): samples = json.loads(samples_path.read_text(encoding="utf-8")) if isinstance(samples, list): meta["samples"] = samples meta["sample_count"] = len(samples) return meta def write_run_manifest(output_dir: Path, results: list[dict[str, Any]]) -> None: """Write compact run-level manifest files.""" manifest_path = output_dir / "manifest.jsonl" failures_path = output_dir / "failures.jsonl" manifest_path.unlink(missing_ok=True) failures_path.unlink(missing_ok=True) for result in results: target = failures_path if result.get("error") else manifest_path with target.open("a", encoding="utf-8") as f: f.write(json.dumps(result, ensure_ascii=True, separators=(",", ":")) + "\n")