# foveated-diffusion/src/inference/experiments.py
# Hugging Face file-viewer header (not part of the module): uploaded by bchao1,
# commit 606581d (verified), "Upload foveated_diffusion Gradio demo", 21 kB.
"""Experiment runners: maps `--experiment` value to a callable.
All runners take (pipe, args, output_dir) and write images / metadata under
`output_dir`. See README for the per-experiment output layout.
"""
import math
import os
import random
import time
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from diffsynth.core import load_state_dict
from ..masks import (
create_foveation_mask,
create_foveation_mask_full_res,
gaussian_blur_mask_2d,
generate_foveation_trajectory_masks,
)
from .visualize import (
create_tokenization_mask_vis,
draw_foveation_outline,
load_prompt_dataset,
)
# ---------------------------------------------------------------------------
# Common pipeline call
# ---------------------------------------------------------------------------
def _pipe_call(
    pipe,
    args,
    prompt: str,
    foveation_mask,
    full_res_foveation_mask=None,
    decode_mode: str = None,
    extra_kwargs: dict = None,
):
    """Invoke the pipeline once for `prompt` and return its result.

    The keyword arguments are assembled from `args` (size, seed, step count,
    guidance scale, prediction type) plus the masks given here. Callers in this
    module treat the return value as an image object with a `.save()` method.

    Args:
        pipe: Callable pipeline; called as `pipe(**kwargs)`.
        args: Parsed options. `soft_foveation_blend` and `lr_downsample_factor`
            are optional and default to `False` and `2`.
        prompt: Text prompt.
        foveation_mask: Mask forwarded as `foveation_mask` (may be `None`).
        full_res_foveation_mask: Forwarded only when it is not `None`.
        decode_mode: Overrides `args.decode_mode` when not `None`.
        extra_kwargs: Extra keyword arguments; these win over everything above.
    """
    call_kwargs = {
        "prompt": prompt,
        "height": args.height,
        "width": args.width,
        "seed": args.seed,
        # The random device is fixed to CUDA for every call.
        "rand_device": "cuda",
        "num_inference_steps": args.num_inference_steps,
        "cfg_scale": args.guidance_scale,
        "foveation_mask": foveation_mask,
        # `args.decode_mode` is only read when no explicit override was given.
        "decode_mode": args.decode_mode if decode_mode is None else decode_mode,
        "prediction_type": args.prediction_type,
        "soft_foveation_blend": getattr(args, "soft_foveation_blend", False),
        "lr_downsample_factor": getattr(args, "lr_downsample_factor", 2),
    }

    # The key is left out entirely (rather than passed as None) when absent.
    if full_res_foveation_mask is not None:
        call_kwargs["full_res_foveation_mask"] = full_res_foveation_mask

    if extra_kwargs:
        call_kwargs.update(extra_kwargs)

    return pipe(**call_kwargs)
# ---------------------------------------------------------------------------
# Single-prompt generation: high_res / naive_mixed_res / ours
# ---------------------------------------------------------------------------
def _run_high_res_one(pipe, args, output_dir, prompt, prompt_idx, foveation_mask=None, full_res_foveation_mask=None):
    """Generate the baseline image for one prompt and save it as `img_<idx>.png`.

    Both mask parameters are accepted only so the signature matches the other
    single-prompt runners; they are ignored. The pipeline is always called with
    no foveation mask and `decode_mode="direct"`.
    """
    print(f"[high_res] prompt {prompt_idx:010d}: {prompt}")
    result = _pipe_call(pipe, args, prompt, foveation_mask=None, decode_mode="direct")
    target = os.path.join(output_dir, f"img_{prompt_idx:010d}.png")
    result.save(target)
def _run_naive_mixed_res_one(pipe, args, output_dir, prompt, prompt_idx, foveation_mask, full_res_foveation_mask=None):
    """Generate one foveated image under the `naive_mixed_res` tag and save it.

    The pipeline call is identical to `_run_ours_one`; inside this function only
    the log tag differs. Output goes to `<output_dir>/img_<idx>.png`.
    """
    print(f"[naive_mixed_res] prompt {prompt_idx:010d}: {prompt}")
    result = _pipe_call(pipe, args, prompt, foveation_mask, full_res_foveation_mask)
    target = os.path.join(output_dir, f"img_{prompt_idx:010d}.png")
    result.save(target)
def _run_ours_one(pipe, args, output_dir, prompt, prompt_idx, foveation_mask, full_res_foveation_mask=None):
    """Generate one foveated image under the `ours` tag and save it.

    The pipeline call is identical to `_run_naive_mixed_res_one`; inside this
    function only the log tag differs. Output goes to `<output_dir>/img_<idx>.png`.
    """
    print(f"[ours] prompt {prompt_idx:010d}: {prompt}")
    result = _pipe_call(pipe, args, prompt, foveation_mask, full_res_foveation_mask)
    target = os.path.join(output_dir, f"img_{prompt_idx:010d}.png")
    result.save(target)
# Per-prompt worker for each single-prompt experiment name. All three workers
# share the signature
# (pipe, args, output_dir, prompt, prompt_idx, foveation_mask, full_res_foveation_mask).
SINGLE_PROMPT_FUNCS = {
    "high_res": _run_high_res_one,
    "naive_mixed_res": _run_naive_mixed_res_one,
    "ours": _run_ours_one,
}
def _run_full_eval(single_prompt_func, pipe, args, output_dir, prompts):
    """Run `single_prompt_func` over this shard of `prompts` and write a metadata CSV.

    For every runner other than `_run_high_res_one`, a mask centred at (0, 0) is
    built once and reused for all prompts. Prompts are numbered by their position
    in `prompts`; only indices with `idx % num_subsets == subset_idx` are
    generated, and iteration stops at `args.num_prompts` when that is set.
    The (image, prompt) pairs that were generated are written to
    `metadata_<subset_idx>.csv` in `output_dir`.
    """
    if single_prompt_func is _run_high_res_one:
        # The baseline runner ignores masks, so none are built.
        default_mask = full_res_mask = None
    else:
        shape = getattr(args, "full_eval_mask", "square")
        radius = getattr(args, "mask_radius", 0.5)
        lr_factor = getattr(args, "lr_downsample_factor", 2)
        centre = (0, 0)
        default_mask = create_foveation_mask(
            args.height, args.width, centre, radius, shape, pipe.device, lr_factor=lr_factor,
        )
        full_res_mask = create_foveation_mask_full_res(
            args.height, args.width, centre, radius, shape, pipe.device,
        )

    shard = getattr(args, "subset_idx", 0)
    shard_count = max(1, getattr(args, "num_subsets", 1))

    # (image file name, prompt) for every image produced by this shard.
    records = []
    for global_idx, prompt in enumerate(prompts):
        limit = args.num_prompts
        if limit is not None and global_idx >= limit:
            break
        if global_idx % shard_count == shard:
            single_prompt_func(
                pipe, args, output_dir, prompt, global_idx,
                foveation_mask=default_mask, full_res_foveation_mask=full_res_mask,
            )
            records.append((f"img_{global_idx:010d}.png", prompt))

    table = pd.DataFrame(
        {
            "image": [name for name, _ in records],
            "prompt": [text for _, text in records],
        }
    )
    csv_path = os.path.join(output_dir, f"metadata_{shard:05d}.csv")
    table.to_csv(csv_path, index=False)
def _resolve_prompts(args):
    """Return the list of prompts an experiment should run on.

    Without `--full_eval` this is just `[args.prompt]`. With `--full_eval` the
    prompts are loaded from `args.prompt_dataset_path`.

    Raises:
        ValueError: if `--full_eval` is set but no dataset path was given.
            (Raised explicitly rather than via `assert` so the check survives
            `python -O`.)
    """
    if not args.full_eval:
        return [args.prompt]

    if args.prompt_dataset_path is None:
        raise ValueError("--prompt_dataset_path required for --full_eval")

    prompts = load_prompt_dataset(args.prompt_dataset_path)
    print(f"Loaded {len(prompts)} prompts from {args.prompt_dataset_path}")
    return prompts
def run_single_prompt_experiment(pipe, args, output_dir):
    """Entry point shared by the `high_res`, `naive_mixed_res` and `ours` experiments.

    Resolves the prompt list, picks the per-prompt worker registered under
    `args.experiment`, and hands both to `_run_full_eval`.
    """
    prompt_batch = _resolve_prompts(args)
    worker = SINGLE_PROMPT_FUNCS[args.experiment]
    _run_full_eval(worker, pipe, args, output_dir, prompt_batch)
# ---------------------------------------------------------------------------
# Circular trajectory (single prompt, mask center orbits)
# ---------------------------------------------------------------------------
def run_circular_traj(pipe, args, output_dir):
    """Generate `args.num_frames` images while the mask centre moves around a circle.

    Frame `i` places the centre at angle `2*pi*i/num_frames` on a circle of
    radius `args.orbit_radius`. Each frame writes `foveation_mask_<i>.png` and
    `img_<i>.png` into `output_dir` and logs the wall-clock time of the
    pipeline call.
    """
    print(f"[circular_traj] {args.num_frames} frames, orbit={args.orbit_radius}, r={args.mask_radius}")
    lr_factor = getattr(args, "lr_downsample_factor", 2)

    for frame in range(args.num_frames):
        theta = 2 * math.pi * (frame / args.num_frames)
        centre = (args.orbit_radius * math.cos(theta), args.orbit_radius * math.sin(theta))

        token_mask = create_foveation_mask(
            args.height, args.width, centre, args.mask_radius, args.mask_shape, pipe.device, lr_factor,
        )
        pixel_mask = create_foveation_mask_full_res(
            args.height, args.width, centre, args.mask_radius, args.mask_shape, pipe.device,
        )
        mask_path = os.path.join(output_dir, f"foveation_mask_{frame:03d}.png")
        _save_mask_image(pixel_mask, pipe, args, mask_path)

        torch.cuda.empty_cache()
        started = time.time()
        frame_image = _pipe_call(pipe, args, args.prompt, token_mask, pixel_mask)
        elapsed = time.time() - started
        print(f" step {frame:02d}/{args.num_frames} angle={math.degrees(theta):.1f}deg time={elapsed:.2f}s")

        frame_image.save(os.path.join(output_dir, f"img_{frame:03d}.png"))
# ---------------------------------------------------------------------------
# Vary radius (single prompt, sweep mask radius from 0.1 -> 1.0)
# ---------------------------------------------------------------------------
def run_vary_radius(pipe, args, output_dir):
    """Generate one image per mask radius, for radii evenly spaced from 0.1 to 1.0.

    The number of radii is `args.num_frames`; the mask stays centred at (0, 0).
    Step `k` writes `foveation_mask_<k>.png` and `img_<k>.png` into `output_dir`
    and logs the wall-clock time of the pipeline call.
    """
    print("[vary_radius]")
    lr_factor = getattr(args, "lr_downsample_factor", 2)
    radii = np.linspace(0.1, 1.0, args.num_frames)

    for step, radius in enumerate(radii):
        tag = f"{step:02d}"
        centre = (0, 0)

        token_mask = create_foveation_mask(
            args.height, args.width, centre, radius, args.mask_shape, pipe.device, lr_factor,
        )
        pixel_mask = create_foveation_mask_full_res(
            args.height, args.width, centre, radius, args.mask_shape, pipe.device,
        )
        mask_path = os.path.join(output_dir, f"foveation_mask_{tag}.png")
        _save_mask_image(pixel_mask, pipe, args, mask_path)

        torch.cuda.empty_cache()
        started = time.time()
        step_image = _pipe_call(pipe, args, args.prompt, token_mask, pixel_mask)
        elapsed = time.time() - started
        print(f" r={radius:.2f} time={elapsed:.2f}s")

        step_image.save(os.path.join(output_dir, f"img_{tag}.png"))
# ---------------------------------------------------------------------------
# Runtime benchmark
# ---------------------------------------------------------------------------
def run_runtime_experiments(pipe, args, output_dir):
    """Benchmark `pipe.vit_timing` against the side length of a square high-res region.

    Side lengths run from `64 // lr_factor` down to 0, with the first value
    repeated once at the start. For each side the pipeline is run three times
    and the mean `pipe.vit_timing` is recorded together with the token ratio
    `(HR + LR tokens) / total tokens`. The image from the first repeat is saved
    as `img_<k>.png`, and all series are written to
    `runtime_and_radius_list.npz`.

    Side effect: sets `args.mask_shape = "square"`.
    """
    print("[runtime] benchmarking foveation radius vs runtime")
    lr_factor = getattr(args, "lr_downsample_factor", 2)
    args.mask_shape = "square"

    # Descending sides; the leading entry is duplicated so it is measured twice.
    # NOTE(review): the duplicate looks like a warm-up pass - confirm.
    descending = np.arange(0, 64 // lr_factor + 1)[::-1]
    r_list = np.concatenate([descending[:1], descending])

    mean_runtimes = []
    token_ratios = []
    repeats = 3

    for r_idx, side in enumerate(r_list):
        # Mark a `side` x `side` corner on the low-res grid, then expand it to
        # the (height // 16, width // 16) grid with nearest-neighbour upsampling.
        coarse = torch.zeros(
            args.height // 16 // lr_factor, args.width // 16 // lr_factor,
            device=pipe.device, dtype=torch.float32,
        )
        coarse[:side, :side] = 1.0
        token_mask = F.interpolate(
            coarse[None, None],
            size=(args.height // 16, args.width // 16),
            mode="nearest",
        )[0, 0]

        grid_tokens = token_mask.shape[-2] * token_mask.shape[-1]
        hr_tokens = token_mask.sum()
        # Cells outside the high-res region are counted once per
        # lr_factor x lr_factor patch.
        lr_tokens = (grid_tokens - hr_tokens) // (lr_factor ** 2)
        ratio = float(((hr_tokens + lr_tokens) / grid_tokens).item())
        token_ratios.append(ratio)
        print(f" r={side} HR={int(hr_tokens)} LR={int(lr_tokens)} ratio={ratio:.3f}")

        timings = []
        first_image = None
        for attempt in range(repeats):
            torch.cuda.empty_cache()
            generated = _pipe_call(pipe, args, args.prompt, token_mask)
            timings.append(pipe.vit_timing)
            if attempt == 0:
                first_image = generated

        mean_time = float(np.mean(timings))
        print(f" -> avg vit_time={mean_time:.2f}s")
        mean_runtimes.append(mean_time)

        if first_image is not None:
            first_image.save(os.path.join(output_dir, f"img_{r_idx:02d}.png"))

    np.savez(
        os.path.join(output_dir, "runtime_and_radius_list.npz"),
        runtime_list=np.array(mean_runtimes),
        r_list=np.array(r_list),
        token_ratio_list=np.array(token_ratios),
    )
# ---------------------------------------------------------------------------
# Foveation trajectory grid (N prompts x M masks)
# ---------------------------------------------------------------------------
def run_foveation_trajectory_grid(pipe, args, output_dir):
    """Render N prompts (rows) x M trajectory masks (columns).

    Every prompt gets its own sub-folder of `output_dir` containing `prompt.txt`
    and, per mask, `img_XXXX.png` and `mask_XXXX.png`. The same trajectory masks
    are used for every prompt, except for the `random_circular` trajectory type,
    which samples one fresh mask per prompt. For shared trajectories the
    tokenization-mask visualisations are written once to
    `output_dir/tokenization_masks/`.

    Prompt selection:
        * `args.prompt_ids` (if given) picks prompts by index; folders are named
          after those indices.
        * otherwise the first `args.num_prompts` prompts are used.
        * without `args.prompt_dataset_path`, the single `args.prompt` is used
          and `args.num_prompts` is overwritten with 1.

    Raises:
        ValueError: if `--prompt_ids` is empty or contains an out-of-range
            index, or if `--prompt_ids` is absent and `--num_prompts` is
            missing or not positive.
    """
    import imageio

    if args.prompt_dataset_path is not None:
        all_prompts = load_prompt_dataset(args.prompt_dataset_path)
    else:
        all_prompts = [args.prompt]
        args.num_prompts = 1

    prompt_ids = getattr(args, "prompt_ids", None)
    if prompt_ids is not None:
        if not prompt_ids:
            raise ValueError("--prompt_ids must contain at least one index")
        max_idx = len(all_prompts) - 1
        bad = [i for i in prompt_ids if i < 0 or i > max_idx]
        if bad:
            raise ValueError(f"--prompt_ids contains invalid indices {bad}; valid range [0, {max_idx}]")
        prompts = [all_prompts[i] for i in prompt_ids]
    else:
        # Explicit raise instead of `assert` so the check survives `python -O`
        # and matches the ValueError raised for bad --prompt_ids above.
        if args.num_prompts is None or args.num_prompts <= 0:
            raise ValueError("--num_prompts required when --prompt_ids not given")
        prompts = all_prompts[: args.num_prompts]
    print(f"[foveation_trajectory_grid] {len(prompts)} prompts")

    outline_width_frac = getattr(args, "outline_width_frac", 0.01)
    outline_color = getattr(args, "outline_color", (255, 0, 0))
    if isinstance(outline_color, str):
        # Accept a comma-separated string such as "255,0,0".
        outline_color = tuple(int(x) for x in outline_color.split(","))

    height, width, device = args.height, args.width, pipe.device
    lr_factor = getattr(args, "lr_downsample_factor", 2)
    traj_type = getattr(args, "foveation_trajectory_type", "circular")
    per_prompt_random_mask = (traj_type == "random_circular")

    if not per_prompt_random_mask:
        foveation_masks, full_res_foveation_masks = generate_foveation_trajectory_masks(
            height, width, args, device, lr_factor=lr_factor,
        )
        num_cols = len(foveation_masks)
        print(f" trajectory={traj_type}, {num_cols} masks (shared across prompts)")
        tok_dir = os.path.join(output_dir, "tokenization_masks")
        os.makedirs(tok_dir, exist_ok=True)
        for col, m in enumerate(foveation_masks):
            tok_vis = create_tokenization_mask_vis(m, height, width, lr_factor=lr_factor)
            imageio.imwrite(os.path.join(tok_dir, f"tokenization_mask_{col:04d}.png"), tok_vis)
    else:
        # One generator for the whole run, so each prompt draws a different angle.
        rng = np.random.default_rng(getattr(args, "seed", 0))
        print(f" trajectory={traj_type}, 1 fresh mask per prompt")

    for prompt_idx, prompt in enumerate(prompts):
        folder_id = prompt_ids[prompt_idx] if prompt_ids is not None else prompt_idx
        prompt_dir = os.path.join(output_dir, f"{folder_id:03d}")
        os.makedirs(prompt_dir, exist_ok=True)
        # Explicit UTF-8: prompts may contain non-ASCII text, and the default
        # encoding of open() depends on the platform locale.
        with open(os.path.join(prompt_dir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)

        if per_prompt_random_mask:
            orbit_radius = getattr(args, "orbit_radius", 0.25)
            mask_radius = getattr(args, "mask_radius", 0.30)
            angle = float(rng.uniform(0, 2 * math.pi))
            center = (orbit_radius * math.cos(angle), orbit_radius * math.sin(angle))
            shape = getattr(args, "mask_shape", "circular")
            foveation_masks = [create_foveation_mask(height, width, center, mask_radius, shape, device, lr_factor)]
            full_res_foveation_masks = [create_foveation_mask_full_res(height, width, center, mask_radius, shape, device)]

        for col, foveation_mask in enumerate(foveation_masks):
            # NOTE(review): the pipeline receives a nearest-neighbour upsampling
            # of the token mask as its full-res mask, while the outline and the
            # saved mask below use `full_res_foveation_masks[col]`. Confirm this
            # difference is intended.
            foveation_mask_upsampled = F.interpolate(
                foveation_mask.unsqueeze(0).unsqueeze(0),
                size=(height, width), mode="nearest",
            ).squeeze(0).squeeze(0)
            torch.cuda.empty_cache()
            image = _pipe_call(
                pipe, args, prompt,
                foveation_mask, full_res_foveation_mask=foveation_mask_upsampled,
            )
            image_np = np.array(image)
            if image_np.ndim == 2:
                # Grayscale result: replicate to three channels before drawing.
                image_np = np.stack([image_np] * 3, axis=-1)
            if getattr(args, "foveation_outline", True):
                # The return value is ignored; `image_np` is what gets written.
                draw_foveation_outline(
                    image_np, full_res_foveation_masks[col], height, width,
                    outline_width_frac=outline_width_frac, color=outline_color,
                )
            imageio.imwrite(os.path.join(prompt_dir, f"img_{col:04d}.png"), image_np)
            _save_mask_image(
                full_res_foveation_masks[col], pipe, args,
                os.path.join(prompt_dir, f"mask_{col:04d}.png"),
            )
            print(f" saved {folder_id:03d}/img_{col:04d}.png")
# ---------------------------------------------------------------------------
# User study (triplets: high_res, naive, ours per prompt)
# ---------------------------------------------------------------------------
def run_user_study(pipe, args, output_dir):
    """Produce (high_res, naive, ours) image triplets for a user study.

    For each sampled prompt a sub-folder `output_dir/<idx>/` is created holding
    `prompt.txt`, `img_high_res.png`, `img_naive.png`, `img_ours.png`,
    `mask.png` and `fixation_point.png`. All three images of a prompt use the
    same `args.seed`; naive and ours also share one randomly placed circular
    mask (radius 0.33, centre drawn from [-0.3, 0.3] on each axis).

    The work is done in two passes so weights are swapped only once: first
    high_res and naive for every prompt with LoRA cleared, then the LoRA and/or
    DiT checkpoint is loaded and "ours" is generated for every prompt.

    Raises:
        ValueError: if `args.prompt_dataset_path` is empty or `None`.
    """
    # Not referenced directly in this function (the helpers import it
    # themselves); importing here makes a missing imageio fail before any
    # image is generated.
    import imageio  # noqa: F401

    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not args.prompt_dataset_path:
        raise ValueError("--prompt_dataset_path required for user_study")

    num_prompts = getattr(args, "num_prompts", None) or 5
    all_prompts = load_prompt_dataset(args.prompt_dataset_path)
    k = min(num_prompts, len(all_prompts))
    # NOTE(review): prompt sampling and the fixation centres below use the
    # global `random` state, which this function does not seed. `args.seed` is
    # only logged here and forwarded to the pipeline.
    prompts = all_prompts if num_prompts == len(all_prompts) else random.sample(all_prompts, k)
    print(f"[user_study] {len(prompts)} prompts seed={args.seed}")

    height, width, device = args.height, args.width, pipe.device
    lr_factor = getattr(args, "lr_downsample_factor", 2)
    outline_width_frac = getattr(args, "outline_width_frac", 0.01)
    outline_color = getattr(args, "outline_color", (255, 0, 0))
    if isinstance(outline_color, str):
        # Accept a comma-separated string such as "255,0,0".
        outline_color = tuple(int(x) for x in outline_color.split(","))

    # Pass 1 runs with LoRA cleared.
    pipe.clear_lora(verbose=0)
    items = []
    for idx, prompt in enumerate(prompts):
        subdir = os.path.join(output_dir, f"{idx:05d}")
        os.makedirs(subdir, exist_ok=True)
        # Explicit UTF-8: prompts may contain non-ASCII text, and the default
        # encoding of open() depends on the platform locale.
        with open(os.path.join(subdir, "prompt.txt"), "w", encoding="utf-8") as f:
            f.write(prompt)

        center = (random.uniform(-0.3, 0.3), random.uniform(-0.3, 0.3))
        r, shape = 0.33, "circular"
        foveation_mask = create_foveation_mask(height, width, center, r, shape, device, lr_factor)
        full_res_mask = create_foveation_mask_full_res(height, width, center, r, shape, device)
        # Kept so pass 2 reuses exactly the same mask for "ours".
        items.append((subdir, prompt, foveation_mask, full_res_mask, center))

        # The per-prompt runners always write img_0000000000.png (index 0);
        # each result is then moved to its final name. os.replace overwrites an
        # existing file on every platform, so re-running into the same
        # directory also works on Windows, where os.rename would raise.
        generated = os.path.join(subdir, "img_0000000000.png")

        # 1) high-res baseline
        _run_high_res_one(pipe, args, subdir, prompt, 0)
        os.replace(generated, os.path.join(subdir, "img_high_res.png"))

        # 2) naive mixed-resolution (base DiT)
        _run_naive_mixed_res_one(pipe, args, subdir, prompt, 0, foveation_mask, full_res_mask)
        os.replace(generated, os.path.join(subdir, "img_naive.png"))
        if getattr(args, "foveation_outline", True):
            _overlay_outline(os.path.join(subdir, "img_naive.png"), full_res_mask,
                             height, width, outline_width_frac, outline_color)

        _save_mask_image(full_res_mask, pipe, args, os.path.join(subdir, "mask.png"))
        _save_fixation_dot(center, height, width, os.path.join(subdir, "fixation_point.png"))
        print(f" saved {idx:05d}/ (high_res, naive, mask, fixation)")

    # Pass 2: swap to the LoRA / DiT checkpoint for "ours".
    lora_loaded = False
    try:
        if args.lora_checkpoint is not None:
            pipe.load_lora(pipe.dit, args.lora_checkpoint)
            lora_loaded = True
            print(f"Loaded LoRA checkpoint from {args.lora_checkpoint}")
        if args.dit_checkpoint is not None:
            # NOTE(review): the previous DiT weights are not restored afterwards.
            state_dict = load_state_dict(args.dit_checkpoint, torch_dtype=torch.bfloat16)
            pipe.dit.load_state_dict(state_dict)
            print(f"Loaded DiT checkpoint from {args.dit_checkpoint}")

        for idx, (subdir, prompt, foveation_mask, full_res_mask, _) in enumerate(items):
            _run_ours_one(pipe, args, subdir, prompt, 0, foveation_mask, full_res_mask)
            os.replace(os.path.join(subdir, "img_0000000000.png"), os.path.join(subdir, "img_ours.png"))
            if getattr(args, "foveation_outline", True):
                _overlay_outline(os.path.join(subdir, "img_ours.png"), full_res_mask,
                                 height, width, outline_width_frac, outline_color)
            print(f" saved {idx:05d}/img_ours.png")
    finally:
        # Clear the LoRA even if generation fails, so a reused `pipe` does not
        # keep it loaded for later runs.
        if lora_loaded:
            pipe.clear_lora(verbose=0)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _save_mask_image(full_res_mask, pipe, args, path: str):
    """Write `full_res_mask` to `path` as an image.

    The mask is given two leading singleton dimensions, cast to float and moved
    to `pipe.device`. When `args.soft_foveation_blend` is truthy it is first
    passed through `gaussian_blur_mask_2d`.

    Output format depends on `path`:
        * `.png` paths that contain "foveation_mask" or "mask_" anywhere in the
          string, or that end in "mask.png", are written with imageio as uint8
          after clipping to [0, 1] and scaling by 255.
        * every other path goes through torchvision's `save_image`.
    """
    batched = full_res_mask[None, None].float().to(pipe.device)

    if getattr(args, "soft_foveation_blend", False):
        upscale = args.height / (args.height // 16)
        batched = gaussian_blur_mask_2d(batched, upscale, device=pipe.device, dtype=batched.dtype)

    # The substring checks run on the whole path string, not just the file name.
    has_mask_marker = any(marker in path for marker in ("foveation_mask", "mask_"))
    write_uint8_png = path.endswith(".png") and (has_mask_marker or path.endswith("mask.png"))

    if not write_uint8_png:
        save_image(batched, path)
        return

    # Use uint8 PNG via numpy for user-study / trajectory layouts.
    import imageio

    values = np.clip(batched.squeeze().cpu().numpy(), 0.0, 1.0)
    imageio.imwrite(path, (values * 255).astype(np.uint8))
def _overlay_outline(image_path, full_res_mask, height, width, outline_width_frac, outline_color):
    """Read the image at `image_path`, outline the foveated region, and write it back.

    A 2-D (grayscale) image is stacked to three channels first. The array is
    handed to `draw_foveation_outline`, whose return value is ignored; the same
    array is then written over the original file.
    """
    import imageio

    pixels = imageio.imread(image_path)
    is_grayscale = pixels.ndim == 2
    if is_grayscale:
        pixels = np.stack((pixels, pixels, pixels), axis=-1)

    draw_foveation_outline(pixels, full_res_mask, height, width, outline_width_frac, outline_color)
    imageio.imwrite(image_path, pixels)
def _save_fixation_dot(center, height, width, path: str):
    """Write a black `height` x `width` RGB image with a red dot at the fixation point.

    `center` is an (x, y) pair; it is mapped to pixels as `(c + 0.5) * size`, so
    (0, 0) lands in the middle of the image. The dot radius is 1% of the shorter
    image side, with a minimum of 2 pixels.
    """
    import imageio

    dot_x = (center[0] + 0.5) * width
    dot_y = (center[1] + 0.5) * height
    radius_px = max(2, int(0.01 * min(width, height)))

    # Column vector of row coordinates and row vector of column coordinates;
    # broadcasting produces the full (height, width) distance grid.
    row_coords = np.arange(height, dtype=np.float64).reshape(-1, 1)
    col_coords = np.arange(width, dtype=np.float64).reshape(1, -1)
    dist_sq = (col_coords - dot_x) ** 2 + (row_coords - dot_y) ** 2
    inside = dist_sq <= radius_px ** 2

    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    canvas[inside] = (255, 0, 0)
    imageio.imwrite(path, canvas)
# Maps each `--experiment` value to its runner; every runner is called as
# runner(pipe, args, output_dir). The three single-prompt experiments share one
# dispatcher, which picks the per-prompt worker from `args.experiment`.
EXPERIMENTS = {
    "high_res": run_single_prompt_experiment,
    "naive_mixed_res": run_single_prompt_experiment,
    "ours": run_single_prompt_experiment,
    "circular_traj": run_circular_traj,
    "vary_radius": run_vary_radius,
    "runtime": run_runtime_experiments,
    "foveation_trajectory_grid": run_foveation_trajectory_grid,
    "user_study": run_user_study,
}