bchao1's picture
Upload foveated_diffusion Gradio demo
606581d verified
Raw
History Blame Contribute Delete
4.18 kB
"""argparse for inference.py."""
import argparse
def str_to_bool(s):
    """Convert a CLI-style value to ``bool``; intended for argparse ``type=``.

    Accepted inputs:
      * ``bool`` -- returned unchanged.
      * ``str`` -- case-insensitive, surrounding whitespace ignored.
        ``"true"``, ``"1"``, ``"yes"``, ``"y"``, ``"on"``, ``"t"`` -> ``True``;
        ``"false"``, ``"0"``, ``"no"``, ``"n"``, ``"off"``, ``"f"`` -> ``False``.
      * ``int`` / ``float`` -- ordinary truthiness (``0`` -> ``False``).

    Raises:
        argparse.ArgumentTypeError: for an unrecognized string (e.g. the typo
            ``"ture"``) or an unsupported type. Raising this specific error
            makes argparse report a clean usage error instead of silently
            treating the value as ``False``.
    """
    if isinstance(s, bool):
        # Checked before int: bool is a subclass of int.
        return s
    if isinstance(s, str):
        token = s.strip().lower()
        if token in ("true", "1", "yes", "y", "on", "t"):
            return True
        if token in ("false", "0", "no", "n", "off", "f"):
            return False
        raise argparse.ArgumentTypeError(f"Expected true/false, got {s!r}")
    if isinstance(s, (int, float)):
        return bool(s)
    raise argparse.ArgumentTypeError(f"Expected true/false, got {s!r}")
_DEFAULT_PROMPT = (
"Documentary-style imagery: a lively little dog stands still on a lush green lawn, "
"filling the entire frame. The dog has brownish-yellow fur, with both ears perked up, "
"and an expression that is focused and cheerful."
)
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for FLUX2 foveated image generation.

    Options are registered in this order: model weights, generation settings,
    output location, experiment selection, full/distributed evaluation,
    foveation trajectory grid, and circular/vary_radius knobs. The parser is
    returned without parsing, so callers invoke ``parse_args`` themselves.
    """
    parser = argparse.ArgumentParser(description="FLUX2 Foveated Image Generation")
    add = parser.add_argument  # short alias; every option below goes through it

    # ---- Model weights ---------------------------------------------------
    add("--model_id", type=str, default="black-forest-labs/FLUX.2-klein-base-4B",
        help="HuggingFace model ID for FLUX2")
    add("--lora_checkpoint", type=str, default=None,
        help="Path to LoRA .safetensors")
    add("--dit_checkpoint", type=str, default=None,
        help="Path to full DiT checkpoint")

    # ---- Generation ------------------------------------------------------
    add("--height", type=int, default=1024)
    add("--width", type=int, default=1024)
    add("--num_inference_steps", type=int, default=50)
    add("--guidance_scale", type=float, default=4.0)
    add("--prompt", type=str, default=_DEFAULT_PROMPT)
    add("--seed", type=int, default=0)
    add("--decode_mode", type=str, default="direct",
        choices=["direct", "merge"])
    add("--prediction_type", type=str, default="clean",
        choices=["clean", "flow", "refiner"])
    # Converted by str_to_bool, so the flag takes an explicit value on the
    # command line (e.g. ``--soft_foveation_blend true``).
    add("--soft_foveation_blend", type=str_to_bool, default=False,
        help="Use a Gaussian-falloff foveation mask boundary in merge decode")
    add("--lr_downsample_factor", type=int, default=2, choices=[2, 4],
        help="Spatial downsampling factor for LR periphery (2 = 4x fewer LR tokens, 4 = 16x)")

    # ---- Output ----------------------------------------------------------
    add("--output_dir", type=str, default="./outputs/flux2_foveated")

    # ---- Experiment selection --------------------------------------------
    experiments = [
        "high_res", "naive_mixed_res", "ours",
        "circular_traj", "vary_radius",
        "runtime", "foveation_trajectory_grid",
        "user_study",
    ]
    add("--experiment", type=str, default="ours", choices=experiments)

    # ---- Full / distributed evaluation -----------------------------------
    add("--full_eval", action="store_true", default=False)
    add("--subset_idx", type=int, default=0)
    add("--num_subsets", type=int, default=1)
    add("--full_eval_mask", type=str, default="square",
        choices=["square", "checkerboard", "circular"])
    # No explicit default: argparse stores None when the option is omitted.
    add("--prompt_dataset_path", type=str,
        help="Path to CSV with a 'prompt' column")
    add("--num_prompts", type=int, default=None)

    # ---- Foveation trajectory grid ---------------------------------------
    add("--num_cols", type=int, default=4)
    trajectory_types = [
        "radius", "circular", "random_circular", "polygons",
        "multi_circle", "grid", "spiral",
    ]
    add("--foveation_trajectory_type", type=str, default="circular",
        choices=trajectory_types)
    add("--grid_rows", type=int, default=3)
    add("--grid_cols", type=int, default=3)
    add("--outline_width_frac", type=float, default=0.005)
    # Stored as a raw string. NOTE(review): presumably parsed as "R,G,B"
    # integers by the consumer -- confirm.
    add("--outline_color", type=str, default="255,0,0")
    add("--foveation_outline", type=str_to_bool, default=False)
    add("--prompt_ids", type=int, nargs="+", default=None,
        help="Explicit CSV row indices for foveation_trajectory_grid")

    # ---- Circular / vary_radius experiment knobs -------------------------
    add("--num_frames", type=int, default=100)
    add("--orbit_radius", type=float, default=0.25)
    add("--mask_radius", type=float, default=0.30,
        help="Foveation radius (circular) or side ratio (square)")
    add("--mask_shape", type=str, default="circular",
        choices=["circular", "square"])

    return parser