Spaces:
Runtime error
Runtime error
| """Foveated flow-matching SFT loss. | |
| Ported from the fork's `diffsynth/diffusion/loss.py`. | |
| """ | |
| import torch | |
| from diffsynth.diffusion.base_pipeline import BasePipeline | |
def _create_random_foveation_mask(h, w, device, dtype=torch.float32,
                                  center_range=(-0.3, 0.3), r_range=(0.2, 0.5)):
    """Sample a circular foveation mask on an ``h`` x ``w`` grid.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        device: Device used both for the random draws and for the mask.
        dtype: Dtype of the coordinate grid the distances are computed in.
        center_range: ``(low, high)`` offset of the circle centre from the grid
            centre, as a fraction of the grid width (x) / height (y).
        r_range: ``(low, high)`` radius as a fraction of half the grid diagonal.

    Returns:
        Boolean tensor of shape ``[h, w]``; ``True`` marks the high-res (HR)
        region inside the circle, ``False`` the low-res (LR) periphery.
    """
    def _uniform(lo, hi):
        # One Python float drawn uniformly from [lo, hi) with the RNG of `device`.
        return torch.rand(1, device=device).item() * (hi - lo) + lo

    # Draw order (x centre, y centre, radius) is kept fixed so a seeded run
    # always yields the same mask.
    centre_x = (_uniform(center_range[0], center_range[1]) + 0.5) * w
    centre_y = (_uniform(center_range[0], center_range[1]) + 0.5) * h
    radius_frac = _uniform(r_range[0], r_range[1])

    half_diagonal = (h ** 2 + w ** 2) ** 0.5 / 2.0
    radius_px = radius_frac * half_diagonal

    # Squared distance to the centre, built by broadcasting a row of x-terms
    # against a column of y-terms instead of materialising a meshgrid.
    cols = torch.arange(w, device=device, dtype=dtype)
    rows = torch.arange(h, device=device, dtype=dtype)
    dx_sq = (cols - centre_x) ** 2
    dy_sq = (rows - centre_y) ** 2
    dist_sq = dx_sq.unsqueeze(0) + dy_sq.unsqueeze(1)

    inside = dist_sq <= radius_px ** 2
    return inside.to(device)
def _create_fixed_foveation_mask(h, w, device, dtype=torch.float32,
                                 center=(0.0, 0.0), r=0.5):
    """Build a deterministic circular foveation mask on an ``h`` x ``w`` grid.

    Uses the same geometry as the random variant, but with a caller-supplied
    centre and radius instead of sampled ones.

    Args:
        h: Number of grid rows.
        w: Number of grid columns.
        device: Device the mask is created on.
        dtype: Dtype of the coordinate grid the distances are computed in.
        center: ``(x, y)`` offset of the circle centre from the grid centre, as
            a fraction of the grid width / height.
        r: Radius as a fraction of half the grid diagonal.

    Returns:
        Boolean tensor of shape ``[h, w]``; ``True`` inside the circle.
    """
    offset_x, offset_y = center[0], center[1]
    centre_x = (offset_x + 0.5) * w
    centre_y = (offset_y + 0.5) * h

    half_diagonal = (h ** 2 + w ** 2) ** 0.5 / 2.0
    radius_px = r * half_diagonal

    # Row vector of x-terms plus column vector of y-terms broadcasts to [h, w].
    x_term = (torch.arange(w, device=device, dtype=dtype) - centre_x) ** 2
    y_term = (torch.arange(h, device=device, dtype=dtype) - centre_y) ** 2
    dist_sq = x_term[None, :] + y_term[:, None]

    within_circle = dist_sq <= radius_px ** 2
    return within_circle.to(device)
def _resolve_foveation_mask(pipe: BasePipeline, inputs: dict, h: int, w: int):
    """Resolve the token-grid foveation mask selected by `foveated_training_mode`.

    Modes (read from `inputs["foveated_training_mode"]`, default `"random"`):
    - `fixed`: centered circle, r=0.5
    - `random`: circle sampled on every call
    - `saliency` / `bbox`: use `inputs["foveation_mask"]` (e.g. precomputed from a
      saliency map or bounding boxes), resized to `[h, w]` with nearest-neighbour
      interpolation when its shape differs
    - anything else: falls back to the `random` behaviour

    Args:
        pipe: Pipeline supplying `device` and `torch_dtype`.
        inputs: Training inputs; only the two keys named above are read.
        h: Mask height in tokens.
        w: Mask width in tokens.

    Returns:
        A `[h, w]` mask. The `fixed` / `random` modes return the boolean tensor
        produced by the circle helpers; `saliency` / `bbox` return the supplied
        mask cast to `pipe.torch_dtype`.

    Raises:
        ValueError: If the mode is `saliency` / `bbox` and no mask was supplied,
            or the supplied mask has a leading dimension larger than 1 (e.g. a
            batch of masks) that cannot be squeezed down to 2-D.
    """
    mode = inputs.get("foveated_training_mode", "random")

    if mode == "fixed":
        return _create_fixed_foveation_mask(h, w, pipe.device, pipe.torch_dtype,
                                            center=(0.0, 0.0), r=0.5)

    if mode in ("saliency", "bbox"):
        mask = inputs.get("foveation_mask")
        if mask is None:
            raise ValueError(
                f"foveated_training_mode='{mode}' but inputs['foveation_mask'] is None."
            )
        mask = mask.to(device=pipe.device, dtype=pipe.torch_dtype)
        # Strip leading singleton dims (e.g. [1, 1, H, W] -> [H, W]).
        # squeeze(0) is a no-op on a non-singleton dim, so without this guard a
        # [B, H, W] mask with B > 1 would spin in this loop forever.
        while mask.dim() > 2:
            if mask.shape[0] != 1:
                raise ValueError(
                    "inputs['foveation_mask'] must reduce to a single 2-D mask, "
                    f"but has non-singleton leading shape {tuple(mask.shape)}."
                )
            mask = mask.squeeze(0)
        if mask.shape[0] != h or mask.shape[1] != w:
            # interpolate expects [N, C, H, W]; add and remove the two dummy dims.
            mask = torch.nn.functional.interpolate(
                mask.unsqueeze(0).unsqueeze(0), size=(h, w), mode="nearest",
            ).squeeze(0).squeeze(0)
        return mask

    # "random" and any unrecognised mode both sample a fresh mask.
    return _create_random_foveation_mask(h, w, pipe.device, pipe.torch_dtype)
def _merge_foveated_targets(target_hr, target_lr, foveation_mask, lr_factor):
    """Combine HR and LR training targets into one mixed-resolution token sequence.

    The HR grid is tiled into `lr_factor x lr_factor` blocks. A block counts as
    high-res when its mask entries sum to more than zero; every other block is
    low-res and is collapsed to a single token taken from `target_lr`.

    Args:
        target_hr: HR target of shape `[B, h, w, C]`, with `h` and `w` divisible
            by `lr_factor`.
        target_lr: LR target of shape `[B, h // lr_factor, w // lr_factor, C]`.
        foveation_mask: `[h, w]` mask, non-zero / True = HR.
        lr_factor: Block edge length in HR tokens.

    Returns:
        Tensor of shape `[B, n_tokens, C]`. Blocks appear in row-major order; an
        HR block contributes its `lr_factor ** 2` tokens (row-major inside the
        block), an LR block contributes exactly one token.
    """
    batch_size, h, w, channels = target_hr.shape
    h_d, w_d = h // lr_factor, w // lr_factor
    n_per_block = lr_factor * lr_factor

    # Regroup HR tokens block-wise: [B, h_d, w_d, n_per_block, C].
    # NOTE: when reshape can return a view this aliases target_hr's storage;
    # the caller discards target_hr afterwards, so the in-place write is safe.
    blocks = target_hr.view(batch_size, h_d, lr_factor, w_d, lr_factor, channels)
    blocks = blocks.permute(0, 1, 3, 2, 4, 5).reshape(
        batch_size, h_d, w_d, n_per_block, channels,
    )

    # Same regrouping for the mask. reshape (not view) so that a non-contiguous
    # user-supplied mask is accepted as well.
    mask_blocks = foveation_mask.reshape(h_d, lr_factor, w_d, lr_factor)
    mask_blocks = mask_blocks.permute(0, 2, 1, 3).reshape(h_d, w_d, n_per_block)
    is_high_res_block = mask_blocks.sum(dim=-1) > 0
    is_low_res_block = ~is_high_res_block

    # Slot 0 of each LR block carries the LR target; slots 1.. are dropped below.
    blocks[:, is_low_res_block, 0, :] = (
        target_lr[:, is_low_res_block, :].to(blocks.dtype)
    )
    valid_tokens = torch.ones(
        h_d, w_d, n_per_block, device=blocks.device, dtype=torch.bool,
    )
    valid_tokens[is_low_res_block, 1:] = False

    flat_blocks = blocks.view(batch_size, -1, channels)
    return flat_blocks[:, valid_tokens.view(-1), :]


def FoveatedFlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    """Foveated Flow Matching SFT objective.

    Samples one timestep, noises both the full-resolution latents
    (`input_latents`) and the downsampled latents (`input_latents_downsampled`)
    with independent noise, and asks the scheduler for the training target of
    each. If the pipeline exposes a callable `foveated_training_forward`, a
    foveation mask is resolved according to `foveated_training_mode` (fixed,
    random, or user-supplied) and a mixed-resolution target is built: all HR
    tokens for blocks touching the foveal region, and one LR token per
    `lr_factor x lr_factor` block in the periphery. The loss is the MSE between
    that target and the prediction returned by
    `pipe.foveated_training_forward(...)`, scaled by the scheduler's training
    weight. Without that method the function falls back to standard flow
    matching through `pipe.model_fn`.

    Keys read from `inputs`: `input_latents`, `input_latents_downsampled`,
    `height`, `width`, and optionally `max_timestep_boundary`,
    `min_timestep_boundary`, `lr_downsample_factor` (default 2),
    `prediction_type` (default "clean"), `foveated_training_mode`,
    `foveation_mask`. Keys written: `latents`, `latents_downsampled`,
    `foveation_mask`. `inputs` is this call's own keyword dict, so the writes
    are not visible to the caller.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If `lr_downsample_factor` is not positive or the token grid
            (`height // 16` x `width // 16`) is not divisible by it.
    """
    num_timesteps = len(pipe.scheduler.timesteps)
    max_b = int(inputs.get("max_timestep_boundary", 1) * num_timesteps)
    min_b = int(inputs.get("min_timestep_boundary", 0) * num_timesteps)
    timestep_id = torch.randint(min_b, max_b, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype,
                                                        device=pipe.device)

    # Full-resolution branch.
    noise = torch.randn_like(inputs["input_latents"])
    inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    # Downsampled branch, noised independently at the same timestep.
    noise_downsampled = torch.randn_like(inputs["input_latents_downsampled"])
    inputs["latents_downsampled"] = pipe.scheduler.add_noise(
        inputs["input_latents_downsampled"], noise_downsampled, timestep,
    )
    training_target_downsampled = pipe.scheduler.training_target(
        inputs["input_latents_downsampled"], noise_downsampled, timestep,
    )

    if callable(getattr(pipe, "foveated_training_forward", None)):
        batch_size, _, channels = inputs["latents"].shape
        # 16 image pixels per token side — presumably VAE stride x patch size;
        # TODO confirm against the pipeline.
        h, w = inputs["height"] // 16, inputs["width"] // 16
        foveation_mask = _resolve_foveation_mask(pipe, inputs, h, w)

        lr_factor = inputs.get("lr_downsample_factor", 2)
        if lr_factor < 1 or h % lr_factor != 0 or w % lr_factor != 0:
            # Fail with a clear message instead of an opaque view() shape error.
            raise ValueError(
                f"Token grid {h}x{w} is not divisible by "
                f"lr_downsample_factor={lr_factor}."
            )
        h_d, w_d = h // lr_factor, w // lr_factor

        # Build the mixed-resolution training target.
        # NOTE(review): the token order produced here must match the order in
        # which foveated_training_forward emits its prediction — confirm.
        training_target = _merge_foveated_targets(
            training_target.view(batch_size, h, w, channels),
            training_target_downsampled.view(batch_size, h_d, w_d, channels),
            foveation_mask,
            lr_factor,
        )

        inputs["foveation_mask"] = foveation_mask
        prediction_type = inputs.get("prediction_type", "clean")
        noise_pred = pipe.foveated_training_forward(
            inputs, timestep, timestep_id, prediction_type,
            lr_downsample_factor=lr_factor,
        )
    else:
        # Fall back to standard flow matching if the pipeline doesn't support foveation.
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * pipe.scheduler.training_weight(timestep)
    return loss