mingyuliutw's picture
Super-squash branch 'main' using huggingface_hub
fdafd05
raw
history blame
8.4 kB
"""CLI for standalone agentic Cosmos3 text-to-image prompt upsampling."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from agentic_upsampling.clients import (
ImageGenerationClient,
PromptRewriterClient,
VLMQualityJudge,
read_api_token,
read_optional_generation_auth_key,
)
from agentic_upsampling.constants import (
DEFAULT_ASPECT_RATIO,
DEFAULT_CRITIC_ENDPOINT_URL,
DEFAULT_CRITIC_MODEL,
DEFAULT_FLOW_SHIFT,
DEFAULT_GENERATION_AUTH_KEY_ENV,
DEFAULT_GENERATION_EXTRA_ARGS,
DEFAULT_GENERATION_MODEL,
DEFAULT_GEMINI_API_KEY_ENV,
DEFAULT_GUIDANCE,
DEFAULT_IMAGE_SIZE,
DEFAULT_LLM_EXTRA_BODY,
DEFAULT_MAX_ITERATIONS,
DEFAULT_NUM_STEPS,
DEFAULT_OPENAI_API_KEY_ENV,
DEFAULT_RESOLUTION,
DEFAULT_REWRITER_ENDPOINT_URL,
DEFAULT_REWRITER_MODEL,
DEFAULT_SAMPLES_PER_ITERATION,
DEFAULT_UPSAMPLER_ENDPOINT_URL,
DEFAULT_UPSAMPLER_MODEL,
)
from agentic_upsampling.data import load_prompt_items
from agentic_upsampling.extract_best import extract_best_images
from agentic_upsampling.io_utils import write_json_atomic
from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig, write_run_manifest
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
input_group = parser.add_mutually_exclusive_group(required=True)
input_group.add_argument("--prompt", default=None, help="Single text prompt to run.")
input_group.add_argument("--prompts", type=Path, default=None, help="Path to .txt, .jsonl, or .csv prompts.")
parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of prompts to run.")
parser.add_argument("--output-dir", type=Path, required=True)
parser.add_argument("--overwrite", action="store_true")
parser.add_argument("--max-iterations", type=int, default=DEFAULT_MAX_ITERATIONS)
parser.add_argument("--samples-per-iteration", type=int, default=DEFAULT_SAMPLES_PER_ITERATION)
parser.add_argument("--seed-base", type=int, default=None)
parser.add_argument("--disable-early-stop", action="store_true")
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--extract-best", action="store_true", help="Copy best images after the run finishes.")
parser.add_argument("--generation-endpoint", required=True)
parser.add_argument("--generation-model", default=DEFAULT_GENERATION_MODEL)
parser.add_argument("--size", default=DEFAULT_IMAGE_SIZE, help="vLLM-Omni image size in WIDTHxHEIGHT format.")
parser.add_argument("--generation-auth-key", default="")
parser.add_argument("--generation-auth-key-env", default=DEFAULT_GENERATION_AUTH_KEY_ENV)
parser.add_argument("--resolution", default=DEFAULT_RESOLUTION)
parser.add_argument("--aspect-ratio", default=DEFAULT_ASPECT_RATIO)
parser.add_argument("--num-steps", type=int, default=DEFAULT_NUM_STEPS)
parser.add_argument("--guidance", type=float, default=DEFAULT_GUIDANCE)
parser.add_argument("--flow-shift", type=float, default=DEFAULT_FLOW_SHIFT)
parser.add_argument("--generation-extra-args", type=json.loads, default=DEFAULT_GENERATION_EXTRA_ARGS)
parser.add_argument("--upsampler-endpoint-url", default=DEFAULT_UPSAMPLER_ENDPOINT_URL)
parser.add_argument("--upsampler-model", default=DEFAULT_UPSAMPLER_MODEL)
parser.add_argument("--rewriter-endpoint-url", default=DEFAULT_REWRITER_ENDPOINT_URL)
parser.add_argument("--rewriter-model", default=DEFAULT_REWRITER_MODEL)
parser.add_argument("--openai-api-key-env", default=DEFAULT_OPENAI_API_KEY_ENV)
parser.add_argument("--openai-api-key-file", type=Path, default=None)
parser.add_argument("--llm-extra-body", type=json.loads, default=DEFAULT_LLM_EXTRA_BODY)
parser.add_argument("--initial-negative-prompt", default="")
parser.add_argument("--critic-endpoint-url", default=DEFAULT_CRITIC_ENDPOINT_URL)
parser.add_argument("--critic-model", default=DEFAULT_CRITIC_MODEL)
parser.add_argument("--gemini-api-key-env", default=DEFAULT_GEMINI_API_KEY_ENV)
parser.add_argument("--gemini-api-key-file", type=Path, default=None)
return parser.parse_args()
def _validate_args(args: argparse.Namespace) -> None:
    """Reject option values that would otherwise only fail after setup work.

    Raises:
        ValueError: If ``--samples-per-iteration`` is below 1, or a JSON option
            did not decode to the expected shape.
    """
    if args.samples_per_iteration < 1:
        raise ValueError("--samples-per-iteration must be >= 1.")
    if not isinstance(args.generation_extra_args, dict):
        raise ValueError("--generation-extra-args must decode to a JSON object.")
    # Same shape check as --generation-extra-args. None is let through because the
    # default may leave the extra body unset (TODO confirm DEFAULT_LLM_EXTRA_BODY).
    if args.llm_extra_body is not None and not isinstance(args.llm_extra_body, dict):
        raise ValueError("--llm-extra-body must decode to a JSON object or null.")


def _run_config(args: argparse.Namespace, num_prompts: int) -> dict[str, object]:
    """Return the settings snapshot recorded in ``run_config.json``.

    API tokens and the generation auth key are not included.
    """
    return {
        "selected_prompts": num_prompts,
        "max_iterations": args.max_iterations,
        "samples_per_iteration": args.samples_per_iteration,
        "early_stop": not args.disable_early_stop,
        # Recorded because it is forwarded to RunnerConfig and affects the run.
        "seed_base": args.seed_base,
        "generation_endpoint": args.generation_endpoint,
        "generation_model": args.generation_model,
        "size": args.size,
        "resolution": args.resolution,
        "aspect_ratio": args.aspect_ratio,
        "num_steps": args.num_steps,
        "guidance": args.guidance,
        "flow_shift": args.flow_shift,
        "generation_extra_args": args.generation_extra_args,
        "upsampler_endpoint_url": args.upsampler_endpoint_url,
        "upsampler_model": args.upsampler_model,
        "rewriter_endpoint_url": args.rewriter_endpoint_url,
        "rewriter_model": args.rewriter_model,
        "llm_extra_body": args.llm_extra_body,
        "critic_endpoint_url": args.critic_endpoint_url,
        "critic_model": args.critic_model,
        "initial_negative_prompt": args.initial_negative_prompt,
    }


def main() -> int:
    """Run the agentic upsampler over every selected prompt.

    Writes ``run_config.json`` and ``summary.json`` into ``--output-dir``. With
    ``--extract-best`` and no failed prompts, it also copies the best images into
    ``<output-dir>/best_generations``.

    Returns:
        Exit status: 0 if every prompt completed, 1 if any prompt failed.

    Raises:
        ValueError: If an option value is invalid.
        RuntimeError: If no prompts were selected.
    """
    import sys  # only needed for the --extract-best skip notice below

    args = parse_args()
    # Validate inputs and read credentials before creating the output directory,
    # so a bad invocation does not leave an empty directory behind.
    _validate_args(args)
    items = load_prompt_items(prompt=args.prompt, prompts_path=args.prompts, limit=args.limit)
    if not items:
        raise RuntimeError("No prompts selected.")
    openai_token = read_api_token(args.openai_api_key_env, args.openai_api_key_file)
    gemini_token = read_api_token(args.gemini_api_key_env, args.gemini_api_key_file)
    generation_auth_key = read_optional_generation_auth_key(args.generation_auth_key, args.generation_auth_key_env)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    write_json_atomic(args.output_dir / "run_config.json", _run_config(args, len(items)))

    rewriter = PromptRewriterClient(
        api_token=openai_token,
        upsampler_endpoint_url=args.upsampler_endpoint_url,
        upsampler_model=args.upsampler_model,
        rewriter_endpoint_url=args.rewriter_endpoint_url,
        rewriter_model=args.rewriter_model,
        extra_body=args.llm_extra_body,
        resolution=args.resolution,
        aspect_ratio=args.aspect_ratio,
    )
    generator = ImageGenerationClient(
        endpoint=args.generation_endpoint,
        auth_key=generation_auth_key,
        model=args.generation_model,
        size=args.size,
        num_steps=args.num_steps,
        guidance=args.guidance,
        flow_shift=args.flow_shift,
        extra_args=args.generation_extra_args,
    )
    judge = VLMQualityJudge(
        api_token=gemini_token,
        endpoint_url=args.critic_endpoint_url,
        model=args.critic_model,
    )
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,
        judge=judge,
        config=RunnerConfig(
            output_dir=args.output_dir,
            max_iterations=args.max_iterations,
            samples_per_iteration=args.samples_per_iteration,
            overwrite=args.overwrite,
            seed_base=args.seed_base,
            initial_negative_prompt=args.initial_negative_prompt,
            early_stop=not args.disable_early_stop,
            verbose=not args.quiet,
        ),
    )

    results = [runner.run_item_safely(item) for item in items]
    write_run_manifest(args.output_dir, results)
    # A result counts as failed when it carries a truthy "error" entry
    # (presumably filled in by run_item_safely instead of raising).
    failures = sum(1 for result in results if result.get("error"))
    summary = {"selected_prompts": len(items), "completed": len(items) - failures, "failures": failures}
    write_json_atomic(args.output_dir / "summary.json", summary)
    print(json.dumps(summary, indent=2), flush=True)

    if args.extract_best:
        if failures:
            # Export only from fully successful runs, but tell the user why
            # nothing was exported instead of skipping silently.
            print(
                f"Skipping --extract-best: {failures} prompt(s) failed.",
                file=sys.stderr,
                flush=True,
            )
        else:
            export_dir = args.output_dir / "best_generations"
            extract_best_images(args.output_dir, export_dir, overwrite=args.overwrite)
            print(f"Exported best images to {export_dir}", flush=True)
    return 1 if failures else 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)