# Source: uploaded by bchao1 ("Upload foveated_diffusion Gradio demo", commit 606581d).
"""WandB-aware ModelLogger with foveated validation visualizations.
Subclasses upstream `diffsynth.diffusion.ModelLogger` and adds:
- WandB initialization, loss / metrics / image logging
- Per-validation-step foveated image generation at four mask positions
- Optional target / noise-prediction / L2-error visualization at three timesteps
Token-AE specific code from the fork is removed in this release.
"""
import os
from typing import Any, Callable, List, Optional
import numpy as np
import torch
from accelerate import Accelerator
from PIL import Image
from diffsynth.diffusion import ModelLogger
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
def _create_foveation_mask(h: int, w: int, center: tuple, r: float, device):
    """Return a boolean ``[h, w]`` mask that is True inside a circle.

    Args:
        h, w: grid size (rows, columns), e.g. the latent grid.
        center: ``(x, y)`` offset; the circle center is placed at
            ``((x + 0.5) * w, (y + 0.5) * h)`` in cell-index coordinates, so
            ``(0, 0)`` is the middle of the grid.
        r: radius as a fraction of half the grid diagonal.
        device: torch device for the returned tensor.
    """
    half_diagonal = ((h ** 2 + w ** 2) ** 0.5) / 2.0
    radius_sq = (r * half_diagonal) ** 2
    center_x = (center[0] + 0.5) * w
    center_y = (center[1] + 0.5) * h
    rows = torch.arange(h, device=device, dtype=torch.float32)
    cols = torch.arange(w, device=device, dtype=torch.float32)
    # Broadcast [1, w] + [h, 1] instead of materialising a meshgrid.
    dist_sq = (cols - center_x).pow(2)[None, :] + (rows - center_y).pow(2)[:, None]
    return dist_sq <= radius_sq
def _unpatchify_for_vis(latents: torch.Tensor, latent_height: int, latent_width: int):
    """Turn packed ``[B, H*W, C]`` FLUX2 latents into an image-like tensor for visualization.

    The token axis is reshaped to ``[latent_height, latent_width]``, so its
    length must equal ``latent_height * latent_width``.

    * If ``C`` is divisible by 4, each token's channels are split as
      ``(c_base, 2, 2)``. The trailing 2x2 is spread back out spatially, giving
      ``[B, c_base, 2 * latent_height, 2 * latent_width]``.
    * Otherwise the token grid is returned channels-first, ``[B, C, H, W]``.

    At most the first three channels are kept; the result is float32 on CPU.
    """
    batch, _, channels = latents.shape
    grid = latents.reshape(batch, latent_height, latent_width, channels)
    if channels % 4:
        # Not divisible into 2x2 patches: just move channels to the front.
        return grid.permute(0, 3, 1, 2)[:, : min(3, channels)].float().cpu()
    base = channels // 4
    # [B, H, W, c, p1, p2] -> [B, c, H, p1, W, p2] -> [B, c, 2H, 2W]
    pixels = (
        grid.reshape(batch, latent_height, latent_width, base, 2, 2)
        .permute(0, 3, 1, 4, 2, 5)
        .reshape(batch, base, 2 * latent_height, 2 * latent_width)
    )
    return pixels[:, : min(3, base)].float().cpu()
def _tensor_to_pil_vis(x: torch.Tensor, vmin: float, vmax: float):
    """Convert a ``[B, C, H, W]`` tensor to a list of ``B`` PIL images on a fixed scale.

    Values are clamped to ``[vmin, vmax]``, mapped linearly onto 0..255 and
    truncated to uint8. ``C == 3`` gives RGB images. ``C == 1`` gives grayscale
    ("L") images: the channel axis is squeezed because ``Image.fromarray``
    rejects ``(H, W, 1)`` arrays. The tensor may live on any device and may
    require grad; it is detached and copied to CPU first.
    """
    x = x.detach().cpu().clamp(vmin, vmax)
    x = (x - vmin) / (vmax - vmin + 1e-8)
    arr = (x * 255).clamp(0, 255).byte().permute(0, 2, 3, 1).numpy()
    if arr.shape[-1] == 1:
        # PIL has no mode for (H, W, 1) uint8; drop the singleton channel -> "L".
        arr = arr[..., 0]
    return [Image.fromarray(arr[i]) for i in range(arr.shape[0])]
class WandbModelLogger(ModelLogger):
    """ModelLogger extended with WandB logging and foveated validation viz.

    Besides the upstream checkpoint hooks, this logger:

    * starts a WandB run on the main process (``init_wandb``);
    * logs the training loss and numeric per-step metrics;
    * every ``validation_steps`` steps, and at epoch / training end, renders the
      validation prompts. For foveated pipelines, each prompt is rendered at four
      fixed mask positions. The masks themselves are logged once, and
      target / prediction / L2-error visualizations are logged as well.

    ``_wandb_initialized`` is True only in the process where ``wandb.init`` ran,
    so the ``log_*`` helpers are no-ops on other ranks.
    """

    # (name, (x, y) center offset) pairs shared by the mask preview and the
    # foveated generation grid, so both always show the same positions.
    _FOVEATION_POSITIONS = (
        ("left", (-0.3, 0)),
        ("right", (0.3, 0)),
        ("top", (0, -0.3)),
        ("bottom", (0, 0.3)),
    )
    # Foveation radius as a fraction of the latent grid's half-diagonal.
    _FOVEATION_RADIUS = 0.25
    # Divisor from pixel height/width to the latent grid used for masks/unpatchify.
    _LATENT_DOWNSAMPLE = 16

    def __init__(
        self,
        output_path: str,
        project_name: str = "diffsynth-training",
        run_name: Optional[str] = None,
        config: Optional[dict] = None,
        remove_prefix_in_ckpt: Optional[str] = None,
        state_dict_converter: Callable = lambda x: x,
        validation_prompts: Optional[List[str]] = None,
        validation_steps: int = 500,
        log_image_steps: int = 500,
        num_validation_images: int = 4,
        validation_kwargs: Optional[dict] = None,
    ):
        """Store WandB and validation settings.

        Args:
            output_path, remove_prefix_in_ckpt, state_dict_converter: forwarded
                unchanged to ``ModelLogger.__init__``.
            project_name, run_name, config: passed to ``wandb.init``.
            validation_prompts: prompts rendered during validation; validation
                is skipped when this is empty.
            validation_steps: validate every N ``on_step_end`` calls (<= 0
                disables step-based validation).
            log_image_steps: stored but not read by this class.
            num_validation_images: max images logged per prompt when the
                non-foveated pipeline returns a list/tuple.
            validation_kwargs: extra keyword arguments for the pipeline call.
                ``height``/``width`` (default 1024), ``seed`` (default 42),
                ``prediction_type`` and ``lr_downsample_factor`` are also read
                by the foveated visualization.

        Raises:
            ImportError: if ``wandb`` is not installed.
        """
        super().__init__(output_path, remove_prefix_in_ckpt, state_dict_converter)
        if not WANDB_AVAILABLE:
            raise ImportError("wandb is not installed. Install it with: pip install wandb")
        self.project_name = project_name
        self.run_name = run_name
        self.config = config or {}
        self.validation_prompts = validation_prompts or []
        self.validation_steps = validation_steps
        self.log_image_steps = log_image_steps
        self.num_validation_images = num_validation_images
        self.validation_kwargs = validation_kwargs or {}
        # True only in the process where wandb.init() actually ran.
        self._wandb_initialized = False
        # Set once the foveation-mask preview has been logged.
        self._foveation_masks_logged = False

    # -----------------------------------------------------------------
    # WandB plumbing
    # -----------------------------------------------------------------
    def init_wandb(self, accelerator: Accelerator):
        """Start the WandB run on the main process; later calls are no-ops.

        The flag is only set where ``wandb.init`` ran, so non-main ranks never
        reach ``wandb.log`` through the logging helpers.
        """
        if self._wandb_initialized or not accelerator.is_main_process:
            return
        wandb.init(
            project=self.project_name,
            name=self.run_name,
            config=self.config,
            resume="allow",
        )
        self._wandb_initialized = True

    def log_loss(self, loss: float, step: Optional[int] = None):
        """Log ``train/loss`` at ``step`` (default: ``self.num_steps``); no-op without a run."""
        if not self._wandb_initialized:
            return
        wandb.log({"train/loss": loss}, step=step if step is not None else self.num_steps)

    def log_metrics(self, metrics: dict, step: Optional[int] = None):
        """Log a metrics dict at ``step`` (default: ``self.num_steps``); no-op without a run."""
        if not self._wandb_initialized:
            return
        wandb.log(metrics, step=step if step is not None else self.num_steps)

    def log_images(self, images: List[Any], captions: Optional[List[str]] = None,
                   key: str = "validation", step: Optional[int] = None):
        """Log ``images`` under ``key`` as ``wandb.Image``s; no-op without a run.

        Captions default to ``"Image {i}"``. ``step`` defaults to ``self.num_steps``.
        """
        if not self._wandb_initialized:
            return
        if captions is None:
            captions = [f"Image {i}" for i in range(len(images))]
        wandb_images = [wandb.Image(img, caption=cap) for img, cap in zip(images, captions)]
        wandb.log({key: wandb_images}, step=step if step is not None else self.num_steps)

    # -----------------------------------------------------------------
    # Foveation-aware validation
    # -----------------------------------------------------------------
    def _is_foveated_pipeline(self, pipe):
        """True when ``pipe`` exposes a truthy ``is_foveated_pipeline`` attribute."""
        return bool(getattr(pipe, "is_foveated_pipeline", False))

    def _log_foveation_masks_once(self, pipe, height: int, width: int):
        """Log the four fixed foveation masks the first time it is called.

        Masks are drawn on the ``height // 16`` x ``width // 16`` latent grid.
        Later calls, and calls without an active WandB run, do nothing.
        """
        if not self._wandb_initialized or getattr(self, "_foveation_masks_logged", False):
            return
        h, w = height // self._LATENT_DOWNSAMPLE, width // self._LATENT_DOWNSAMPLE
        device = getattr(pipe, "device", "cuda")
        mask_images = []
        for name, center in self._FOVEATION_POSITIONS:
            mask = _create_foveation_mask(h, w, center, r=self._FOVEATION_RADIUS, device=device)
            mask_np = (mask.float().cpu().numpy() * 255).astype(np.uint8)
            mask_images.append(
                wandb.Image(Image.fromarray(mask_np), caption=f"foveation_mask_{name}")
            )
        wandb.log({"validation/foveation_masks": mask_images}, step=self.num_steps)
        self._foveation_masks_logged = True

    def _visualize_target_noise_pred(self, pipe, height: int, width: int):
        """At 25%/50%/75% timesteps, visualize training_target, noise_pred, and L2 error.

        The method:

        1. Runs ``pipe.units`` on the first validation prompt.
        2. Noises ``pipe.fixed_clean_latent`` with seeded noise.
        3. Calls ``pipe.foveated_training_forward`` with a centered foveation mask.
        4. Logs the unpatchified training target and prediction (sharing one
           value range per timestep) and their per-pixel L2 error, computed on
           the visualized channels.

        Any failure is reported as a printed warning, not raised.

        The global torch / CUDA / NumPy RNG states are seeded for
        reproducibility and restored afterwards, so validation does not reset
        the training process's random stream.
        """
        if not self._wandb_initialized or not hasattr(pipe, "foveated_training_forward"):
            return
        latent_height = height // self._LATENT_DOWNSAMPLE
        latent_width = width // self._LATENT_DOWNSAMPLE
        device = getattr(pipe, "device", "cuda")
        dtype = getattr(pipe, "torch_dtype", torch.bfloat16)
        seed = self.validation_kwargs.get("seed", 42)
        prediction_type = self.validation_kwargs.get("prediction_type", "clean")
        lr_downsample_factor = self.validation_kwargs.get("lr_downsample_factor", 2)

        # Snapshot global RNG state so the seeding below does not leak into training.
        cpu_rng_state = torch.get_rng_state()
        cuda_rng_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        np_rng_state = np.random.get_state()
        try:
            torch.manual_seed(seed)
            if cuda_rng_states is not None:
                torch.cuda.manual_seed_all(seed)
            np.random.seed(seed)

            prompt = self.validation_prompts[0] if self.validation_prompts else ""
            # NOTE(review): this leaves pipe.scheduler on a 1000-step training
            # schedule; confirm that later inference/training re-set what they need.
            pipe.scheduler.set_timesteps(1000, training=True)
            ts = pipe.scheduler.timesteps
            timestep_ids = [int(len(ts) * frac) for frac in (0.25, 0.50, 0.75)]
            labels = ["t0", "t_mid", "t_last"]

            inputs_shared_base = {
                "height": height, "width": width, "prompt": prompt, "cfg_scale": 1.0,
                "embedded_guidance": 1.0, "input_image": None, "rand_device": str(device),
                "seed": seed,
                **{k: v for k, v in self.validation_kwargs.items()
                   if k not in ("height", "width")},
            }
            inputs_posi, inputs_nega = {"prompt": prompt}, {"negative_prompt": ""}
            for unit in pipe.units:
                inputs_shared_base, inputs_posi, inputs_nega = pipe.unit_runner(
                    unit, pipe, inputs_shared_base, inputs_posi, inputs_nega,
                )

            input_latents = pipe.fixed_clean_latent
            noise = torch.randn_like(input_latents)
            foveation_mask = _create_foveation_mask(
                latent_height, latent_width, (0.0, 0.0),
                r=self._FOVEATION_RADIUS, device=device,
            )

            target_images, pred_images, l2_images = [], [], []
            for tid, label in zip(timestep_ids, labels):
                timestep_id = torch.tensor([tid], device=ts.device)
                timestep = ts[timestep_id].to(dtype=dtype, device=device)
                latents = pipe.scheduler.add_noise(input_latents, noise, timestep)
                training_target = pipe.scheduler.training_target(input_latents, noise, timestep)

                inputs_shared = dict(inputs_shared_base)
                inputs_shared["latents"] = latents
                inputs_shared["input_latents"] = input_latents
                inputs_shared["foveation_mask"] = foveation_mask
                inputs = dict(inputs_shared, **inputs_posi)

                noise_pred = pipe.foveated_training_forward(
                    inputs, timestep, timestep_id, prediction_type,
                    lr_downsample_factor=lr_downsample_factor,
                )

                target_vis = _unpatchify_for_vis(training_target, latent_height, latent_width)
                pred_vis = _unpatchify_for_vis(noise_pred, latent_height, latent_width)
                l2_error = (target_vis - pred_vis).pow(2).sum(dim=1, keepdim=True).sqrt()

                # Target and prediction share one value range so they are comparable.
                vmin = float(min(target_vis.min().item(), pred_vis.min().item()))
                vmax = float(max(target_vis.max().item(), pred_vis.max().item()))
                target_images.append(wandb.Image(_tensor_to_pil_vis(target_vis, vmin, vmax)[0],
                                                 caption=f"target_{label}"))
                pred_images.append(wandb.Image(_tensor_to_pil_vis(pred_vis, vmin, vmax)[0],
                                               caption=f"noise_pred_{label}"))

                e_3ch = l2_error.expand(-1, 3, -1, -1)
                l2_max = float(e_3ch.max().item())
                l2_images.append(wandb.Image(_tensor_to_pil_vis(e_3ch, 0.0, l2_max)[0],
                                             caption=f"l2_error_{label}"))

            wandb.log({
                "validation/target": target_images,
                "validation/noise_pred": pred_images,
                "validation/l2_error": l2_images,
            }, step=self.num_steps)
        except Exception as e:
            # Best-effort visualization: never interrupt training.
            print(f"Warning: Target/noise viz failed: {e}")
        finally:
            torch.set_rng_state(cpu_rng_state)
            if cuda_rng_states is not None:
                torch.cuda.set_rng_state_all(cuda_rng_states)
            np.random.set_state(np_rng_state)

    def _generate_foveated_images(self, pipe, height: int, width: int):
        """Render each validation prompt at every fixed mask position; return (images, captions)."""
        h, w = height // self._LATENT_DOWNSAMPLE, width // self._LATENT_DOWNSAMPLE
        device = getattr(pipe, "device", "cuda")
        images_out, captions = [], []
        for prompt in self.validation_prompts:
            for pos_name, center in self._FOVEATION_POSITIONS:
                try:
                    mask = _create_foveation_mask(h, w, center, r=self._FOVEATION_RADIUS,
                                                  device=device)
                    # Per-call values go last so validation_kwargs cannot override them.
                    kwargs = {**self.validation_kwargs, "prompt": prompt, "foveation_mask": mask}
                    images = pipe(**kwargs)
                    img = images[0] if isinstance(images, (list, tuple)) else images
                    cap = f"{prompt[:40]}... [{pos_name}]" if len(prompt) > 40 \
                        else f"{prompt} [{pos_name}]"
                    images_out.append(img)
                    captions.append(cap)
                except Exception as e:
                    print(f"Warning: Foveated validation failed {pos_name}: {e}")
        return images_out, captions

    def _generate_plain_images(self, pipe):
        """Render each validation prompt with the plain pipeline; return (images, captions)."""
        images_out, captions = [], []
        for prompt in self.validation_prompts:
            try:
                images = pipe(**{**self.validation_kwargs, "prompt": prompt})
                if isinstance(images, (list, tuple)):
                    for i, img in enumerate(images[: self.num_validation_images]):
                        cap = f"{prompt[:50]}... [{i}]" if len(prompt) > 50 \
                            else f"{prompt} [{i}]"
                        images_out.append(img)
                        captions.append(cap)
                else:
                    cap = prompt[:50] + "..." if len(prompt) > 50 else prompt
                    images_out.append(images)
                    captions.append(cap)
            except Exception as e:
                print(f"Warning: Validation failed: {e}")
        return images_out, captions

    def generate_validation_images(self, accelerator: Accelerator, model: torch.nn.Module):
        """Generate validation images (foveated 4-position grid if pipe is foveated).

        Runs only on the main process and only when prompts are configured and
        the unwrapped model has a ``pipe`` attribute. The pipe is switched to
        eval mode for generation and always set back to train mode afterwards,
        even if an error escapes.
        """
        if not accelerator.is_main_process or not self.validation_prompts:
            return
        unwrapped_model = accelerator.unwrap_model(model)
        if not hasattr(unwrapped_model, "pipe"):
            return
        pipe = unwrapped_model.pipe
        height = self.validation_kwargs.get("height", 1024)
        width = self.validation_kwargs.get("width", 1024)
        use_foveation = self._is_foveated_pipeline(pipe)

        pipe.eval()
        try:
            with torch.no_grad():
                if use_foveation:
                    if hasattr(pipe, "foveated_training_forward"):
                        self._visualize_target_noise_pred(pipe, height, width)
                    self._log_foveation_masks_once(pipe, height, width)
                    all_images, all_captions = self._generate_foveated_images(pipe, height, width)
                else:
                    all_images, all_captions = self._generate_plain_images(pipe)
        finally:
            pipe.train()

        if all_images:
            self.log_images(all_images, all_captions, key="validation/generated")

    # -----------------------------------------------------------------
    # ModelLogger hooks
    # -----------------------------------------------------------------
    @staticmethod
    def _scalar_metrics(values: dict) -> dict:
        """Keep numeric entries (including one-element tensors), prefixed with ``train/``."""
        metrics = {}
        for name, value in values.items():
            if isinstance(value, torch.Tensor) and value.numel() == 1:
                value = value.item()
            if isinstance(value, (int, float)):
                metrics[f"train/{name}"] = value
        return metrics

    def on_training_start(self, accelerator: Accelerator, model: torch.nn.Module):
        """Start WandB and log a step-0 baseline of validation images."""
        self.init_wandb(accelerator)
        if accelerator.is_main_process:
            print("Generating baseline validation images (step 0)...")
        self.generate_validation_images(accelerator, model)

    def on_step_end(
        self,
        accelerator: Accelerator,
        model: torch.nn.Module,
        save_steps: Optional[int] = None,
        loss: Optional[float] = None,
        **kwargs,
    ):
        """Advance the step counter, log loss/metrics, checkpoint and validate on schedule.

        Extra keyword arguments are logged as ``train/<name>`` when numeric (one-element
        tensors are converted with ``.item()``); anything else is ignored.
        """
        self.init_wandb(accelerator)
        self.num_steps += 1
        if accelerator.is_main_process:
            if loss is not None:
                self.log_loss(loss)
            metrics = self._scalar_metrics(kwargs)
            if metrics:
                self.log_metrics(metrics)
        if save_steps is not None and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
        if self.validation_steps > 0 and self.num_steps % self.validation_steps == 0:
            self.generate_validation_images(accelerator, model)

    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id: int):
        """Log the epoch index, run the upstream epoch hook, then validate."""
        self.init_wandb(accelerator)
        if accelerator.is_main_process:
            self.log_metrics({"epoch": epoch_id})
        super().on_epoch_end(accelerator, model, epoch_id)
        self.generate_validation_images(accelerator, model)

    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module,
                        save_steps: Optional[int] = None):
        """Run the upstream end hook, log final validation images and close the WandB run."""
        super().on_training_end(accelerator, model, save_steps)
        self.generate_validation_images(accelerator, model)
        if accelerator.is_main_process and self._wandb_initialized:
            wandb.finish()
            # The run is closed; further log_* calls must become no-ops.
            self._wandb_initialized = False