# bchao1's picture
# Upload foveated_diffusion Gradio demo
# 606581d verified
# Raw
# History Blame Contribute Delete
# 7.15 kB
"""Foveated flow-matching SFT loss.
Ported from the fork's `diffsynth/diffusion/loss.py`.
"""
import torch
from diffsynth.diffusion.base_pipeline import BasePipeline
def _create_random_foveation_mask(h, w, device, dtype=torch.float32,
center_range=(-0.3, 0.3), r_range=(0.2, 0.5)):
"""Random circular foveation mask, shape [h, w], 1 = HR, 0 = LR."""
cx = (torch.rand(1, device=device).item() * (center_range[1] - center_range[0])
+ center_range[0] + 0.5) * w
cy = (torch.rand(1, device=device).item() * (center_range[1] - center_range[0])
+ center_range[0] + 0.5) * h
r = torch.rand(1, device=device).item() * (r_range[1] - r_range[0]) + r_range[0]
diagonal = (h ** 2 + w ** 2) ** 0.5
radius_px = r * (diagonal / 2.0)
y = torch.arange(h, device=device, dtype=dtype)
x = torch.arange(w, device=device, dtype=dtype)
yy, xx = torch.meshgrid(y, x, indexing="ij")
dist_sq = (xx - cx) ** 2 + (yy - cy) ** 2
return (dist_sq <= radius_px ** 2).to(device)
def _create_fixed_foveation_mask(h, w, device, dtype=torch.float32,
center=(0.0, 0.0), r=0.5):
"""Fixed circular foveation mask with same geometry as the random variant."""
cx = (center[0] + 0.5) * w
cy = (center[1] + 0.5) * h
diagonal = (h ** 2 + w ** 2) ** 0.5
radius_px = r * (diagonal / 2.0)
y = torch.arange(h, device=device, dtype=dtype)
x = torch.arange(w, device=device, dtype=dtype)
yy, xx = torch.meshgrid(y, x, indexing="ij")
dist_sq = (xx - cx) ** 2 + (yy - cy) ** 2
return (dist_sq <= radius_px ** 2).to(device)
def _resolve_foveation_mask(pipe: BasePipeline, inputs: dict, h: int, w: int):
"""Resolve a token-grid foveation mask based on `foveated_training_mode`.
Modes:
- `fixed`: centered, r=0.5
- `random`: sampled each step
- `saliency` / `bbox`: use `inputs["foveation_mask"]` (e.g. precomputed from a
saliency map or bounding boxes)
"""
mode = inputs.get("foveated_training_mode", "random")
if mode == "fixed":
return _create_fixed_foveation_mask(h, w, pipe.device, pipe.torch_dtype,
center=(0.0, 0.0), r=0.5)
if mode == "random":
return _create_random_foveation_mask(h, w, pipe.device, pipe.torch_dtype)
if mode in ("saliency", "bbox"):
mask = inputs.get("foveation_mask")
if mask is None:
raise ValueError(
f"foveated_training_mode='{mode}' but inputs['foveation_mask'] is None."
)
mask = mask.to(device=pipe.device, dtype=pipe.torch_dtype)
while mask.dim() > 2:
mask = mask.squeeze(0)
if mask.shape[0] != h or mask.shape[1] != w:
mask = torch.nn.functional.interpolate(
mask.unsqueeze(0).unsqueeze(0), size=(h, w), mode="nearest",
).squeeze(0).squeeze(0)
return mask
return _create_random_foveation_mask(h, w, pipe.device, pipe.torch_dtype)
def FoveatedFlowMatchSFTLoss(pipe: "BasePipeline", **inputs):
    """Foveated flow-matching SFT objective.

    Samples one timestep, noises ``inputs["input_latents"]`` and obtains the
    scheduler's training target. When the pipeline exposes a callable
    ``foveated_training_forward``, a foveation mask is resolved from
    ``inputs["foveated_training_mode"]`` and a mixed-resolution target is built:
    every ``lr_factor x lr_factor`` block that contains at least one foveal
    cell keeps all of its high-resolution tokens, while every other block is
    replaced by the single matching token of the low-resolution target. The
    loss is the MSE between that target and the pipeline's foveated prediction,
    scaled by the scheduler's training weight.

    Pipelines without ``foveated_training_forward`` fall back to the standard
    flow-matching loss via ``pipe.model_fn``.

    Inputs read (keyword arguments):
        input_latents: Clean latents. The foveated path unpacks their noised
            version as ``(batch, tokens, channels)``.
        input_latents_downsampled: Clean low-resolution latents. Required on
            the foveated path; optional on the fallback path.
        height, width: Used on the foveated path; the token grid is
            ``height // 16`` by ``width // 16``.
        lr_downsample_factor: Side length of a low-resolution block (default 2).
        max_timestep_boundary / min_timestep_boundary: Fractions of the
            scheduler's timestep list to sample from (defaults 1 and 0).
        prediction_type: Forwarded to ``foveated_training_forward``
            (default ``"clean"``).
        foveated_training_mode / foveation_mask: See the mask resolver.

    Returns:
        The weighted MSE loss tensor.

    Raises:
        KeyError: On the foveated path when ``input_latents_downsampled`` is
            missing.
        ValueError: On the foveated path when the token grid is not divisible
            by ``lr_downsample_factor``.
    """
    n_timesteps = len(pipe.scheduler.timesteps)
    max_b = int(inputs.get("max_timestep_boundary", 1) * n_timesteps)
    min_b = int(inputs.get("min_timestep_boundary", 0) * n_timesteps)
    timestep_id = torch.randint(min_b, max_b, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(
        dtype=pipe.torch_dtype, device=pipe.device,
    )

    noise = torch.randn_like(inputs["input_latents"])
    inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    # The low-resolution stream is only mandatory for the foveated path, so the
    # standard fallback must not fail when it is absent.
    training_target_downsampled = None
    if inputs.get("input_latents_downsampled") is not None:
        noise_downsampled = torch.randn_like(inputs["input_latents_downsampled"])
        inputs["latents_downsampled"] = pipe.scheduler.add_noise(
            inputs["input_latents_downsampled"], noise_downsampled, timestep,
        )
        training_target_downsampled = pipe.scheduler.training_target(
            inputs["input_latents_downsampled"], noise_downsampled, timestep,
        )

    if callable(getattr(pipe, "foveated_training_forward", None)):
        if training_target_downsampled is None:
            raise KeyError(
                "input_latents_downsampled is required when the pipeline "
                "supports foveated training"
            )

        height = inputs["height"]
        width = inputs["width"]
        batch_size, _, channels = inputs["latents"].shape
        h, w = height // 16, width // 16

        foveation_mask = _resolve_foveation_mask(pipe, inputs, h, w)

        lr_factor = inputs.get("lr_downsample_factor", 2)
        if lr_factor < 1 or h % lr_factor != 0 or w % lr_factor != 0:
            raise ValueError(
                f"Token grid {h}x{w} is not divisible by "
                f"lr_downsample_factor={lr_factor}."
            )
        n_per_block = lr_factor * lr_factor
        h_d, w_d = h // lr_factor, w // lr_factor

        # Build the mixed-resolution training target.
        # Group the HR tokens by block: [B, h_d, w_d, n_per_block, C].
        training_target_merged = training_target.view(batch_size, h, w, channels)
        training_target_downsampled = training_target_downsampled.view(
            batch_size, h_d, w_d, channels,
        )
        training_target_merged = training_target_merged.view(
            batch_size, h_d, lr_factor, w_d, lr_factor, channels,
        )
        training_target_merged = training_target_merged.permute(0, 1, 3, 2, 4, 5).reshape(
            batch_size, h_d, w_d, n_per_block, channels,
        )

        # A block is high-res as soon as any of its cells lies in the fovea.
        mask_blocks = foveation_mask.view(h_d, lr_factor, w_d, lr_factor)
        mask_blocks = mask_blocks.permute(0, 2, 1, 3).reshape(h_d, w_d, n_per_block)
        is_high_res_block = mask_blocks.sum(dim=-1) > 0
        is_low_res_block = ~is_high_res_block

        # Low-res blocks keep one token: slot 0 receives the LR target ...
        training_target_merged[:, is_low_res_block, 0, :] = (
            training_target_downsampled[:, is_low_res_block, :].to(training_target_merged.dtype)
        )
        # ... and the remaining slots of those blocks are dropped.
        valid_tokens = torch.ones(
            h_d, w_d, n_per_block, device=training_target_merged.device, dtype=torch.bool,
        )
        valid_tokens[is_low_res_block, 1:] = False

        training_target_merged = training_target_merged.view(batch_size, -1, channels)
        valid_tokens = valid_tokens.view(-1)
        training_target = training_target_merged[:, valid_tokens, :]

        inputs["foveation_mask"] = foveation_mask
        prediction_type = inputs.get("prediction_type", "clean")
        noise_pred = pipe.foveated_training_forward(
            inputs, timestep, timestep_id, prediction_type,
            lr_downsample_factor=lr_factor,
        )
    else:
        # Fall back to standard flow matching if the pipeline doesn't support foveation.
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * pipe.scheduler.training_weight(timestep)
    return loss