"""Experiment runners: maps `--experiment` value to a callable.

All runners take (pipe, args, output_dir) and write images / metadata under
`output_dir`. See README for the per-experiment output layout.
"""

import math
import os
import random
import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from diffsynth.core import load_state_dict

from ..masks import (
    create_foveation_mask,
    create_foveation_mask_full_res,
    gaussian_blur_mask_2d,
    generate_foveation_trajectory_masks,
)
from .visualize import (
    create_tokenization_mask_vis,
    draw_foveation_outline,
    load_prompt_dataset,
)


# ---------------------------------------------------------------------------
# Common pipeline call
# ---------------------------------------------------------------------------

def _pipe_call(
    pipe,
    args,
    prompt: str,
    foveation_mask,
    full_res_foveation_mask=None,
    decode_mode: str = None,
    extra_kwargs: dict = None,
):
    """Single call to the foveated FLUX2 pipeline. Returns a PIL.Image.

    Args:
        pipe: Callable pipeline object; invoked as ``pipe(**kwargs)``.
        args: Parsed CLI namespace. ``height``, ``width``, ``seed``,
            ``num_inference_steps``, ``guidance_scale``, ``decode_mode`` and
            ``prediction_type`` are read directly; ``soft_foveation_blend``
            and ``lr_downsample_factor`` are optional (defaults False / 2).
        prompt: Text prompt forwarded to the pipeline.
        foveation_mask: Mask forwarded as ``foveation_mask`` (may be None).
        full_res_foveation_mask: Only forwarded when not None.
        decode_mode: Overrides ``args.decode_mode`` when not None.
        extra_kwargs: Optional mapping merged last, so it can override any of
            the keyword arguments built here.

    Returns:
        Whatever ``pipe(**kwargs)`` returns.
    """
    kwargs = dict(
        prompt=prompt,
        height=args.height,
        width=args.width,
        seed=args.seed,
        rand_device="cuda",
        num_inference_steps=args.num_inference_steps,
        cfg_scale=args.guidance_scale,
        foveation_mask=foveation_mask,
        decode_mode=decode_mode if decode_mode is not None else args.decode_mode,
        prediction_type=args.prediction_type,
        soft_foveation_blend=getattr(args, "soft_foveation_blend", False),
        lr_downsample_factor=getattr(args, "lr_downsample_factor", 2),
    )
    if full_res_foveation_mask is not None:
        kwargs["full_res_foveation_mask"] = full_res_foveation_mask
    if extra_kwargs:
        kwargs.update(extra_kwargs)
    return pipe(**kwargs)


# ---------------------------------------------------------------------------
# Single-prompt generation: high_res / naive_mixed_res / ours
# ---------------------------------------------------------------------------

def _generate_and_save(
    tag,
    pipe,
    args,
    output_dir,
    prompt,
    prompt_idx,
    foveation_mask=None,
    full_res_foveation_mask=None,
    decode_mode=None,
):
    """Log the prompt, run one pipeline call and save ``img_<prompt_idx>.png``."""
    print(f"[{tag}] prompt {prompt_idx:010d}: {prompt}")
    image = _pipe_call(
        pipe,
        args,
        prompt,
        foveation_mask,
        full_res_foveation_mask,
        decode_mode=decode_mode,
    )
    image.save(os.path.join(output_dir, f"img_{prompt_idx:010d}.png"))


def _run_high_res_one(pipe, args, output_dir, prompt, prompt_idx,
                      foveation_mask=None, full_res_foveation_mask=None):
    """Generate the baseline image for one prompt.

    The mask arguments exist only for signature parity with the other
    single-prompt runners; they are ignored and the pipeline is called with no
    foveation mask and ``decode_mode="direct"``.
    """
    _generate_and_save(
        "high_res", pipe, args, output_dir, prompt, prompt_idx,
        foveation_mask=None, full_res_foveation_mask=None, decode_mode="direct",
    )


def _run_naive_mixed_res_one(pipe, args, output_dir, prompt, prompt_idx,
                             foveation_mask, full_res_foveation_mask=None):
    """Generate one image with the given foveation masks (logged as naive_mixed_res)."""
    _generate_and_save(
        "naive_mixed_res", pipe, args, output_dir, prompt, prompt_idx,
        foveation_mask, full_res_foveation_mask,
    )


def _run_ours_one(pipe, args, output_dir, prompt, prompt_idx,
                  foveation_mask, full_res_foveation_mask=None):
    """Generate one image with the given foveation masks (logged as ours).

    The pipeline call is identical to the naive runner; any difference in the
    result comes from the weights currently loaded into ``pipe``.
    """
    _generate_and_save(
        "ours", pipe, args, output_dir, prompt, prompt_idx,
        foveation_mask, full_res_foveation_mask,
    )


SINGLE_PROMPT_FUNCS = {
    "high_res": _run_high_res_one,
    "naive_mixed_res": _run_naive_mixed_res_one,
    "ours": _run_ours_one,
}


def _run_full_eval(single_prompt_func, pipe, args, output_dir, prompts):
    """Iterate over prompts; for naive/ours apply a default centered foveation mask.

    Args:
        single_prompt_func: One of the values of ``SINGLE_PROMPT_FUNCS``.
        pipe: Pipeline object (``pipe.device`` is used for mask creation).
        args: Parsed CLI namespace. Optional fields read with defaults:
            ``full_eval_mask`` ("square"), ``mask_radius`` (0.5),
            ``lr_downsample_factor`` (2), ``subset_idx`` (0),
            ``num_subsets`` (1), ``num_prompts`` (None = no limit).
        output_dir: Directory receiving the images and the metadata CSV.
        prompts: Iterable of prompt strings.

    Side effects:
        Writes one image per processed prompt plus
        ``metadata_<subset_idx>.csv`` listing image names and prompts.
    """
    default_mask = None
    full_res_mask = None
    # The high-res baseline ignores masks, so skip building them for it.
    if single_prompt_func is not _run_high_res_one:
        mask_shape = getattr(args, "full_eval_mask", "square")
        radius = getattr(args, "mask_radius", 0.5)
        lr_factor = getattr(args, "lr_downsample_factor", 2)
        default_mask = create_foveation_mask(
            args.height, args.width, (0, 0), radius, mask_shape, pipe.device,
            lr_factor=lr_factor,
        )
        full_res_mask = create_foveation_mask_full_res(
            args.height, args.width, (0, 0), radius, mask_shape, pipe.device,
        )

    subset_idx = getattr(args, "subset_idx", 0)
    num_subsets = max(1, getattr(args, "num_subsets", 1))
    # Read defensively, like every other optional field in this function.
    num_prompts = getattr(args, "num_prompts", None)

    prompt_list, image_names = [], []
    for global_idx, prompt in enumerate(prompts):
        if num_prompts is not None and global_idx >= num_prompts:
            break
        # Round-robin sharding: this worker only handles its own residue class.
        if (global_idx % num_subsets) != subset_idx:
            continue
        single_prompt_func(
            pipe, args, output_dir, prompt, global_idx,
            foveation_mask=default_mask,
            full_res_foveation_mask=full_res_mask,
        )
        prompt_list.append(prompt)
        image_names.append(f"img_{global_idx:010d}.png")

    pd.DataFrame({"image": image_names, "prompt": prompt_list}).to_csv(
        os.path.join(output_dir, f"metadata_{subset_idx:05d}.csv"),
        index=False,
    )


def _resolve_prompts(args):
    """Return the list of prompts selected by the CLI arguments.

    With ``args.full_eval`` set, prompts are loaded from
    ``args.prompt_dataset_path``; otherwise the single ``args.prompt`` is used.

    Raises:
        ValueError: ``--full_eval`` was requested without a dataset path.
    """
    if not args.full_eval:
        return [args.prompt]
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if args.prompt_dataset_path is None:
        raise ValueError("--prompt_dataset_path required for --full_eval")
    prompts = load_prompt_dataset(args.prompt_dataset_path)
    print(f"Loaded {len(prompts)} prompts from {args.prompt_dataset_path}")
    return prompts


def run_single_prompt_experiment(pipe, args, output_dir):
    """Dispatcher for high_res / naive_mixed_res / ours (each uses _run_full_eval)."""
    prompts = _resolve_prompts(args)
    _run_full_eval(SINGLE_PROMPT_FUNCS[args.experiment], pipe, args, output_dir, prompts)


# ---------------------------------------------------------------------------
# Circular trajectory (single prompt, mask center orbits)
# ---------------------------------------------------------------------------

def run_circular_traj(pipe, args, output_dir):
    """Render ``args.num_frames`` images while the mask center moves on a circle.

    Frame ``i`` uses angle ``2*pi*i/num_frames`` and center
    ``orbit_radius * (cos, sin)``. For every frame the full-resolution mask is
    saved as ``foveation_mask_<i>.png`` and the generated image as
    ``img_<i>.png`` under ``output_dir``.
    """
    print(f"[circular_traj] {args.num_frames} frames, orbit={args.orbit_radius}, r={args.mask_radius}")
    lr_factor = getattr(args, "lr_downsample_factor", 2)

    for i in range(args.num_frames):
        angle = 2 * math.pi * (i / args.num_frames)
        center = (args.orbit_radius * math.cos(angle), args.orbit_radius * math.sin(angle))
        foveation_mask = create_foveation_mask(
            args.height, args.width, center, args.mask_radius, args.mask_shape,
            pipe.device, lr_factor,
        )
        full_res_mask = create_foveation_mask_full_res(
            args.height, args.width, center, args.mask_radius, args.mask_shape,
            pipe.device,
        )
        _save_mask_image(full_res_mask, pipe, args, os.path.join(output_dir, f"foveation_mask_{i:03d}.png"))

        torch.cuda.empty_cache()
        # perf_counter is monotonic, so the interval cannot be skewed by clock adjustments.
        t0 = time.perf_counter()
        image = _pipe_call(pipe, args, args.prompt, foveation_mask, full_res_mask)
        elapsed = time.perf_counter() - t0
        print(f" step {i:02d}/{args.num_frames} angle={math.degrees(angle):.1f}deg time={elapsed:.2f}s")
        image.save(os.path.join(output_dir, f"img_{i:03d}.png"))


# ---------------------------------------------------------------------------
# Vary radius (single prompt, sweep mask radius from 0.1 -> 1.0)
# ---------------------------------------------------------------------------

def run_vary_radius(pipe, args, output_dir):
    """Render ``args.num_frames`` images with a centered mask of growing radius.

    Radii are ``np.linspace(0.1, 1.0, num_frames)``. For each radius the
    full-resolution mask is saved as ``foveation_mask_<idx>.png`` and the image
    as ``img_<idx>.png`` under ``output_dir``.
    """
    print("[vary_radius]")
    lr_factor = getattr(args, "lr_downsample_factor", 2)

    for r_idx, r in enumerate(np.linspace(0.1, 1.0, args.num_frames)):
        tag = f"{r_idx:02d}"
        foveation_mask = create_foveation_mask(
            args.height, args.width, (0, 0), r, args.mask_shape, pipe.device, lr_factor,
        )
        full_res_mask = create_foveation_mask_full_res(
            args.height, args.width, (0, 0), r, args.mask_shape, pipe.device,
        )
        _save_mask_image(full_res_mask, pipe, args, os.path.join(output_dir, f"foveation_mask_{tag}.png"))

        torch.cuda.empty_cache()
        t0 = time.perf_counter()
        image = _pipe_call(pipe, args, args.prompt, foveation_mask, full_res_mask)
        elapsed = time.perf_counter() - t0
        print(f" r={r:.2f} time={elapsed:.2f}s")
        image.save(os.path.join(output_dir, f"img_{tag}.png"))


# ---------------------------------------------------------------------------
# Runtime benchmark
# ---------------------------------------------------------------------------

def _runtime_radius_schedule(height, width, lr_factor):
    """Return the benchmark's HR side lengths: full grid down to 0, first entry repeated.

    Sides are expressed in low-res latent cells (image pixels // 16 //
    ``lr_factor``). The longer image side is used so the sweep always starts at
    full high-res coverage; for 1024x1024 this equals the previous hard-coded
    ``64 // lr_factor``.
    """
    max_side = max(height, width) // 16 // lr_factor
    sides = np.arange(0, max_side + 1)[::-1]
    # The first configuration is run twice (presumably as a warm-up pass —
    # NOTE(review): confirm; both runs are recorded in the output).
    return np.concatenate([[sides[0]], sides])


def _corner_square_mask(height, width, lr_factor, side, device):
    """Build a token-grid mask whose top-left ``side`` x ``side`` low-res cells are 1."""
    low_res = torch.zeros(
        height // 16 // lr_factor,
        width // 16 // lr_factor,
        device=device,
        dtype=torch.float32,
    )
    low_res[:side, :side] = 1.0
    # Nearest-neighbour upsampling keeps the mask binary on the token grid.
    return F.interpolate(
        low_res.unsqueeze(0).unsqueeze(0),
        size=(height // 16, width // 16),
        mode="nearest",
    ).squeeze(0).squeeze(0)


def run_runtime_experiments(pipe, args, output_dir):
    """Benchmark ``pipe.vit_timing`` against the size of the high-res region.

    For every side length in the schedule, a corner-anchored HR mask is built,
    the pipeline is run 3 times, and the mean of ``pipe.vit_timing`` is
    recorded. The image of the first repeat is saved as ``img_<idx>.png`` and
    all results go to ``runtime_and_radius_list.npz`` (``runtime_list``,
    ``r_list``, ``token_ratio_list``).

    Side effects:
        Sets ``args.mask_shape = "square"`` on the caller's namespace.
    """
    print("[runtime] benchmarking foveation radius vs runtime")
    lr_factor = getattr(args, "lr_downsample_factor", 2)
    args.mask_shape = "square"

    r_list = _runtime_radius_schedule(args.height, args.width, lr_factor)
    runtime_list, token_ratio_list = [], []

    for r_idx, r in enumerate(r_list):
        # Build a square HR region with side `r` in low-res latent units, then upsample to token grid.
        foveation_mask = _corner_square_mask(args.height, args.width, lr_factor, r, pipe.device)

        total_orig_tokens = foveation_mask.shape[-2] * foveation_mask.shape[-1]
        num_hr = foveation_mask.sum()
        # Each LR token stands in for lr_factor**2 original tokens.
        num_lr = (total_orig_tokens - num_hr) // (lr_factor ** 2)
        token_ratio = float(((num_hr + num_lr) / total_orig_tokens).item())
        token_ratio_list.append(token_ratio)
        print(f" r={r} HR={int(num_hr)} LR={int(num_lr)} ratio={token_ratio:.3f}")

        num_repeats = 3
        runtimes, saved_image = [], None
        for rep in range(num_repeats):
            torch.cuda.empty_cache()
            image = _pipe_call(pipe, args, args.prompt, foveation_mask)
            runtimes.append(pipe.vit_timing)
            if rep == 0:
                saved_image = image
        avg = float(np.mean(runtimes))
        print(f" -> avg vit_time={avg:.2f}s")
        runtime_list.append(avg)
        if saved_image is not None:
            saved_image.save(os.path.join(output_dir, f"img_{r_idx:02d}.png"))

    np.savez(
        os.path.join(output_dir, "runtime_and_radius_list.npz"),
        runtime_list=np.array(runtime_list),
        r_list=np.array(r_list),
        token_ratio_list=np.array(token_ratio_list),
    )


# ---------------------------------------------------------------------------
# Foveation trajectory grid (N prompts x M masks)
# ---------------------------------------------------------------------------

def _parse_outline_color(args):
    """Return ``args.outline_color`` as a tuple; accepts an "R,G,B" string or a tuple."""
    outline_color = getattr(args, "outline_color", (255, 0, 0))
    if isinstance(outline_color, str):
        outline_color = tuple(int(x) for x in outline_color.split(","))
    return outline_color


def run_foveation_trajectory_grid(pipe, args, output_dir):
    """N prompts (rows) x M trajectory masks (cols).

    Same trajectory for every prompt except `random_circular`, which samples a
    fresh mask per prompt.

    Output layout under ``output_dir``:
        ``tokenization_masks/tokenization_mask_<col>.png`` (shared trajectories only),
        ``<prompt folder>/prompt.txt``, ``img_<col>.png`` and ``mask_<col>.png``.

    Side effects:
        When no dataset path is given, sets ``args.num_prompts = 1``.

    Raises:
        ValueError: ``--prompt_ids`` is empty or out of range, or neither
            ``--prompt_ids`` nor a positive ``--num_prompts`` was supplied.
    """
    import imageio

    if args.prompt_dataset_path is not None:
        all_prompts = load_prompt_dataset(args.prompt_dataset_path)
    else:
        all_prompts = [args.prompt]
        args.num_prompts = 1

    prompt_ids = getattr(args, "prompt_ids", None)
    if prompt_ids is not None:
        if not prompt_ids:
            raise ValueError("--prompt_ids must contain at least one index")
        max_idx = len(all_prompts) - 1
        bad = [i for i in prompt_ids if i < 0 or i > max_idx]
        if bad:
            raise ValueError(f"--prompt_ids contains invalid indices {bad}; valid range [0, {max_idx}]")
        prompts = [all_prompts[i] for i in prompt_ids]
    else:
        num_prompts = getattr(args, "num_prompts", None)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if num_prompts is None or num_prompts <= 0:
            raise ValueError("--num_prompts required when --prompt_ids not given")
        prompts = all_prompts[:num_prompts]
    print(f"[foveation_trajectory_grid] {len(prompts)} prompts")

    outline_width_frac = getattr(args, "outline_width_frac", 0.01)
    outline_color = _parse_outline_color(args)

    height, width, device = args.height, args.width, pipe.device
    lr_factor = getattr(args, "lr_downsample_factor", 2)
    traj_type = getattr(args, "foveation_trajectory_type", "circular")
    per_prompt_random_mask = (traj_type == "random_circular")

    if not per_prompt_random_mask:
        foveation_masks, full_res_foveation_masks = generate_foveation_trajectory_masks(
            height, width, args, device, lr_factor=lr_factor,
        )
        num_cols = len(foveation_masks)
        print(f" trajectory={traj_type}, {num_cols} masks (shared across prompts)")
        tok_dir = os.path.join(output_dir, "tokenization_masks")
        os.makedirs(tok_dir, exist_ok=True)
        for col, m in enumerate(foveation_masks):
            tok_vis = create_tokenization_mask_vis(m, height, width, lr_factor=lr_factor)
            imageio.imwrite(os.path.join(tok_dir, f"tokenization_mask_{col:04d}.png"), tok_vis)
    else:
        rng = np.random.default_rng(getattr(args, "seed", 0))
        print(f" trajectory={traj_type}, 1 fresh mask per prompt")

    for prompt_idx, prompt in enumerate(prompts):
        # Folders are named after the dataset index when --prompt_ids is used.
        folder_id = prompt_ids[prompt_idx] if prompt_ids is not None else prompt_idx
        prompt_dir = os.path.join(output_dir, f"{folder_id:03d}")
        os.makedirs(prompt_dir, exist_ok=True)
        # Explicit UTF-8 so non-ASCII prompts do not depend on the platform locale.
        with open(os.path.join(prompt_dir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)

        if per_prompt_random_mask:
            orbit_radius = getattr(args, "orbit_radius", 0.25)
            mask_radius = getattr(args, "mask_radius", 0.30)
            angle = float(rng.uniform(0, 2 * math.pi))
            center = (orbit_radius * math.cos(angle), orbit_radius * math.sin(angle))
            shape = getattr(args, "mask_shape", "circular")
            foveation_masks = [
                create_foveation_mask(height, width, center, mask_radius, shape, device, lr_factor)
            ]
            full_res_foveation_masks = [
                create_foveation_mask_full_res(height, width, center, mask_radius, shape, device)
            ]

        for col, foveation_mask in enumerate(foveation_masks):
            # NOTE(review): the pipeline receives the nearest-upsampled token
            # mask here, while the outline and saved mask use
            # full_res_foveation_masks[col]; other runners pass the full-res
            # mask to the pipeline. Confirm this difference is intentional.
            foveation_mask_upsampled = F.interpolate(
                foveation_mask.unsqueeze(0).unsqueeze(0),
                size=(height, width),
                mode="nearest",
            ).squeeze(0).squeeze(0)

            torch.cuda.empty_cache()
            image = _pipe_call(
                pipe, args, prompt, foveation_mask,
                full_res_foveation_mask=foveation_mask_upsampled,
            )
            image_np = np.array(image)
            if image_np.ndim == 2:
                # Promote grayscale to 3 channels before drawing the outline.
                image_np = np.stack([image_np] * 3, axis=-1)
            if getattr(args, "foveation_outline", True):
                draw_foveation_outline(
                    image_np, full_res_foveation_masks[col], height, width,
                    outline_width_frac=outline_width_frac, color=outline_color,
                )
            imageio.imwrite(os.path.join(prompt_dir, f"img_{col:04d}.png"), image_np)
            _save_mask_image(
                full_res_foveation_masks[col], pipe, args,
                os.path.join(prompt_dir, f"mask_{col:04d}.png"),
            )
            print(f" saved {folder_id:03d}/img_{col:04d}.png")


# ---------------------------------------------------------------------------
# User study (triplets: high_res, naive, ours per prompt)
# ---------------------------------------------------------------------------

def run_user_study(pipe, args, output_dir):
    """For each sampled prompt, produce (high_res, naive, ours) with the same seed
    and a random foveation mask. LoRA / DiT checkpoint is loaded only for the
    'ours' pass.

    Per prompt, a sub-directory ``<idx>`` receives ``prompt.txt``,
    ``img_high_res.png``, ``img_naive.png``, ``img_ours.png``, ``mask.png`` and
    ``fixation_point.png``.

    Side effects:
        Clears LoRA on ``pipe`` before the baseline passes and, if a LoRA
        checkpoint was loaded, again at the end. A DiT checkpoint loaded via
        ``args.dit_checkpoint`` is left in place.

    Raises:
        ValueError: ``--prompt_dataset_path`` was not provided.
    """
    # Not used directly here; imported up front so a missing dependency fails
    # before any generation starts (the helpers below need it).
    import imageio  # noqa: F401

    # Explicit raise (not `assert`) so the check survives `python -O`.
    if not args.prompt_dataset_path:
        raise ValueError("--prompt_dataset_path required for user_study")
    num_prompts = getattr(args, "num_prompts", None) or 5

    all_prompts = load_prompt_dataset(args.prompt_dataset_path)
    k = min(num_prompts, len(all_prompts))
    prompts = all_prompts if num_prompts == len(all_prompts) else random.sample(all_prompts, k)
    print(f"[user_study] {len(prompts)} prompts seed={args.seed}")

    height, width, device = args.height, args.width, pipe.device
    lr_factor = getattr(args, "lr_downsample_factor", 2)
    outline_width_frac = getattr(args, "outline_width_frac", 0.01)
    outline_color = _parse_outline_color(args)
    draw_outline = getattr(args, "foveation_outline", True)

    # The single-prompt runners always write img_0000000000.png (prompt_idx 0);
    # each result is then moved to its final name.
    tmp_name = "img_0000000000.png"

    pipe.clear_lora(verbose=0)
    items = []
    for idx, prompt in enumerate(prompts):
        subdir = os.path.join(output_dir, f"{idx:05d}")
        os.makedirs(subdir, exist_ok=True)
        # Explicit UTF-8 so non-ASCII prompts do not depend on the platform locale.
        with open(os.path.join(subdir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)

        center = (random.uniform(-0.3, 0.3), random.uniform(-0.3, 0.3))
        r, shape = 0.33, "circular"
        foveation_mask = create_foveation_mask(height, width, center, r, shape, device, lr_factor)
        full_res_mask = create_foveation_mask_full_res(height, width, center, r, shape, device)
        items.append((subdir, prompt, foveation_mask, full_res_mask, center))

        # 1) high-res baseline
        _run_high_res_one(pipe, args, subdir, prompt, 0)
        # os.replace overwrites an existing target on every platform, so
        # re-running into the same output directory does not fail.
        os.replace(os.path.join(subdir, tmp_name), os.path.join(subdir, "img_high_res.png"))

        # 2) naive mixed-resolution (base DiT)
        _run_naive_mixed_res_one(pipe, args, subdir, prompt, 0, foveation_mask, full_res_mask)
        os.replace(os.path.join(subdir, tmp_name), os.path.join(subdir, "img_naive.png"))
        if draw_outline:
            _overlay_outline(os.path.join(subdir, "img_naive.png"), full_res_mask,
                             height, width, outline_width_frac, outline_color)

        _save_mask_image(full_res_mask, pipe, args, os.path.join(subdir, "mask.png"))
        _save_fixation_dot(center, height, width, os.path.join(subdir, "fixation_point.png"))
        print(f" saved {idx:05d}/ (high_res, naive, mask, fixation)")

    # Swap to LoRA / DiT checkpoint for the "ours" pass
    if args.lora_checkpoint is not None:
        pipe.load_lora(pipe.dit, args.lora_checkpoint)
        print(f"Loaded LoRA checkpoint from {args.lora_checkpoint}")
    if args.dit_checkpoint is not None:
        state_dict = load_state_dict(args.dit_checkpoint, torch_dtype=torch.bfloat16)
        pipe.dit.load_state_dict(state_dict)
        print(f"Loaded DiT checkpoint from {args.dit_checkpoint}")

    for idx, (subdir, prompt, foveation_mask, full_res_mask, _) in enumerate(items):
        _run_ours_one(pipe, args, subdir, prompt, 0, foveation_mask, full_res_mask)
        os.replace(os.path.join(subdir, tmp_name), os.path.join(subdir, "img_ours.png"))
        if draw_outline:
            _overlay_outline(os.path.join(subdir, "img_ours.png"), full_res_mask,
                             height, width, outline_width_frac, outline_color)
        print(f" saved {idx:05d}/img_ours.png")

    if args.lora_checkpoint is not None:
        pipe.clear_lora(verbose=0)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _save_mask_image(full_res_mask, pipe, args, path: str):
    """Save a full-resolution mask, optionally Gaussian-blurred to match the soft blend.

    Args:
        full_res_mask: Mask tensor; two leading singleton dims are added before
            saving.
        pipe: Pipeline object (only ``pipe.device`` is used).
        args: Parsed CLI namespace; ``soft_foveation_blend`` enables the blur
            and ``height`` sets its scale.
        path: Destination file. PNG files whose *file name* contains
            ``foveation_mask`` or ``mask_``, or ends with ``mask.png``, are
            written as uint8 via imageio; anything else goes through
            ``torchvision.utils.save_image``.
    """
    mask = full_res_mask.unsqueeze(0).unsqueeze(0).float().to(pipe.device)
    if getattr(args, "soft_foveation_blend", False):
        upscale = args.height / (args.height // 16)
        mask = gaussian_blur_mask_2d(mask, upscale, device=pipe.device, dtype=mask.dtype)

    # Decide on the file name only, so a directory that happens to contain
    # "mask_" does not change which writer is used.
    filename = os.path.basename(path)
    is_mask_png = filename.endswith(".png") and (
        "foveation_mask" in filename or "mask_" in filename or filename.endswith("mask.png")
    )
    if is_mask_png:
        # Use uint8 PNG via numpy for user-study / trajectory layouts
        import imageio

        arr = np.clip(mask.squeeze().cpu().numpy(), 0.0, 1.0)
        imageio.imwrite(path, (arr * 255).astype(np.uint8))
    else:
        save_image(mask, path)


def _overlay_outline(image_path, full_res_mask, height, width, outline_width_frac, outline_color):
    """Read the image at ``image_path``, draw the foveation outline and overwrite the file."""
    import imageio

    img = imageio.imread(image_path)
    if img.ndim == 2:
        # Promote grayscale to 3 channels before drawing the outline.
        img = np.stack([img] * 3, axis=-1)
    draw_foveation_outline(img, full_res_mask, height, width, outline_width_frac, outline_color)
    imageio.imwrite(image_path, img)


def _save_fixation_dot(center, height, width, path: str):
    """Write a black image with a red dot marking the fixation point.

    ``center`` is mapped to pixels as ``(c + 0.5) * size`` per axis, so
    ``(0, 0)`` lands on the image center. The dot radius is 1% of the shorter
    side, but at least 2 pixels.
    """
    import imageio

    cx_px = (center[0] + 0.5) * width
    cy_px = (center[1] + 0.5) * height
    dot_radius = max(2, int(0.01 * min(width, height)))

    y = np.arange(height, dtype=np.float64)[:, None]
    x = np.arange(width, dtype=np.float64)[None, :]
    in_circle = (x - cx_px) ** 2 + (y - cy_px) ** 2 <= dot_radius ** 2

    img = np.zeros((height, width, 3), dtype=np.uint8)
    img[in_circle] = [255, 0, 0]
    imageio.imwrite(path, img)


EXPERIMENTS = {
    "high_res": run_single_prompt_experiment,
    "naive_mixed_res": run_single_prompt_experiment,
    "ours": run_single_prompt_experiment,
    "circular_traj": run_circular_traj,
    "vary_radius": run_vary_radius,
    "runtime": run_runtime_experiments,
    "foveation_trajectory_grid": run_foveation_trajectory_grid,
    "user_study": run_user_study,
}