Spaces:
Runtime error
Runtime error
| """WandB-aware ModelLogger with foveated validation visualizations. | |
| Subclasses upstream `diffsynth.diffusion.ModelLogger` and adds: | |
| - WandB initialization, loss / metrics / image logging | |
| - Per-validation-step foveated image generation at four mask positions | |
| - Optional target / noise-prediction / L2-error visualization at three timesteps | |
| Token-AE specific code from the fork is removed in this release. | |
| """ | |
| import os | |
| from typing import Any, Callable, List, Optional | |
| import numpy as np | |
| import torch | |
| from accelerate import Accelerator | |
| from PIL import Image | |
| from diffsynth.diffusion import ModelLogger | |
| try: | |
| import wandb | |
| WANDB_AVAILABLE = True | |
| except ImportError: | |
| WANDB_AVAILABLE = False | |
def _create_foveation_mask(h: int, w: int, center: tuple, r: float, device):
    """Build a boolean disc mask over an ``h`` x ``w`` grid.

    Args:
        h: Number of rows of the grid.
        w: Number of columns of the grid.
        center: ``(x, y)`` offset of the disc centre relative to the middle of
            the grid, as a fraction of the width / height (``(0, 0)`` puts the
            centre at ``(0.5 * w, 0.5 * h)``).
        r: Radius as a fraction of half the grid diagonal.
        device: Torch device on which the mask is created.

    Returns:
        A ``[h, w]`` bool tensor that is True where the squared distance to
        the centre is at most the squared radius.
    """
    center_x = (center[0] + 0.5) * w
    center_y = (center[1] + 0.5) * h
    grid_diagonal = (h ** 2 + w ** 2) ** 0.5
    radius_px = r * (grid_diagonal / 2.0)

    # Column / row index vectors, shaped so that they broadcast to [h, w].
    cols = torch.arange(w, device=device, dtype=torch.float32).unsqueeze(0)
    rows = torch.arange(h, device=device, dtype=torch.float32).unsqueeze(1)

    dist_sq = (cols - center_x) ** 2 + (rows - center_y) ** 2
    return dist_sq <= radius_px ** 2
def _unpatchify_for_vis(latents: torch.Tensor, latent_height: int, latent_width: int):
    """Turn packed ``[B, H*W, C]`` latents into a small image-like tensor.

    The token axis is unfolded into a ``latent_height`` x ``latent_width``
    grid. When ``C`` is divisible by 4, the channel axis is additionally split
    into ``(C // 4, 2, 2)`` and the two size-2 axes are interleaved with the
    spatial axes, doubling the spatial resolution. Otherwise the grid is kept
    at latent resolution.

    Returns:
        A float32 CPU tensor with at most the first three channels:
        ``[B, min(3, C // 4), 2H, 2W]`` in the divisible case, else
        ``[B, min(3, C), H, W]``.
    """
    batch, _, channels = latents.shape

    # [B, H*W, C] -> [B, H, W, C] -> [B, C, H, W]
    grid = latents.reshape(batch, latent_height, latent_width, channels)
    grid = grid.permute(0, 3, 1, 2)

    if channels % 4 == 0:
        base_channels = channels // 4
        # [B, C, H, W] -> [B, c, 2, 2, H, W] -> [B, c, H, 2, W, 2] -> [B, c, 2H, 2W]
        grid = grid.reshape(batch, base_channels, 2, 2, latent_height, latent_width)
        grid = grid.permute(0, 1, 4, 2, 5, 3)
        grid = grid.reshape(batch, base_channels, latent_height * 2, latent_width * 2)
        keep = min(3, base_channels)
    else:
        keep = min(3, channels)

    return grid[:, :keep].float().cpu()
def _tensor_to_pil_vis(x: torch.Tensor, vmin: float, vmax: float):
    """Render a channel-first batch as PIL images on a fixed value scale.

    Values are clamped to ``[vmin, vmax]``, mapped linearly to ``[0, 255]``
    (a small epsilon keeps the division finite when ``vmax == vmin``) and
    truncated to ``uint8``. The channel axis is moved last before conversion.

    Args:
        x: Tensor laid out as ``[B, C, H, W]``. ``.numpy()`` is called on it
            directly, so it is expected to already live on the CPU.
        vmin: Value mapped to 0.
        vmax: Value mapped to (approximately) 255.

    Returns:
        A list with one ``PIL.Image`` per batch element.
    """
    value_range = vmax - vmin + 1e-8
    normalized = (x.clamp(vmin, vmax) - vmin) / value_range
    pixels = (normalized * 255).clamp(0, 255).byte()
    frames = pixels.permute(0, 2, 3, 1).numpy()  # [B, H, W, C]
    return [Image.fromarray(frame) for frame in frames]
class WandbModelLogger(ModelLogger):
    """ModelLogger extended with WandB logging and foveated validation viz.

    On top of the inherited checkpoint handling this class:

    - initialises a WandB run on the main process,
    - logs the training loss, scalar metrics and images,
    - periodically runs the wrapped pipeline on ``validation_prompts`` and
      logs the generated images (at four foveation-mask positions when the
      pipeline reports ``is_foveated_pipeline``),
    - optionally logs target / prediction / L2-error maps at three timesteps.

    All WandB logging uses ``self.num_steps`` as the step; this class
    increments that counter in :meth:`on_step_end`.
    """

    # Foveation centres used for validation, as (name, (x, y)) offsets from
    # the image middle in fractions of width / height.
    _FOVEATION_POSITIONS = (
        ("left", (-0.3, 0)),
        ("right", (0.3, 0)),
        ("top", (0, -0.3)),
        ("bottom", (0, 0.3)),
    )
    # Mask radius passed to ``_create_foveation_mask`` for every validation mask.
    _FOVEATION_RADIUS = 0.25
    # Pixel -> latent-grid downscale used when sizing masks and visualizations.
    _LATENT_DOWNSCALE = 16

    def __init__(
        self,
        output_path: str,
        project_name: str = "diffsynth-training",
        run_name: Optional[str] = None,
        config: Optional[dict] = None,
        remove_prefix_in_ckpt: Optional[str] = None,
        state_dict_converter: Callable = lambda x: x,
        validation_prompts: Optional[List[str]] = None,
        validation_steps: int = 500,
        log_image_steps: int = 500,
        num_validation_images: int = 4,
        validation_kwargs: Optional[dict] = None,
    ):
        """Store the logging / validation configuration.

        Args:
            output_path: Forwarded to ``ModelLogger.__init__``.
            project_name: WandB project name.
            run_name: Optional WandB run name.
            config: Optional dict stored as the WandB run config.
            remove_prefix_in_ckpt: Forwarded to ``ModelLogger.__init__``.
            state_dict_converter: Forwarded to ``ModelLogger.__init__``.
            validation_prompts: Prompts rendered at every validation.
            validation_steps: Validate every this many steps (``<= 0`` disables
                step-based validation).
            log_image_steps: Stored on the instance; not read by this class.
            num_validation_images: Max images kept per prompt for
                non-foveated pipelines.
            validation_kwargs: Extra keyword arguments forwarded to the
                pipeline call; also read for ``height``, ``width``, ``seed``,
                ``prediction_type`` and ``lr_downsample_factor``.

        Raises:
            ImportError: If ``wandb`` could not be imported.
        """
        # Fail fast, before the base class gets a chance to do any work.
        if not WANDB_AVAILABLE:
            raise ImportError("wandb is not installed. Install it with: pip install wandb")
        super().__init__(output_path, remove_prefix_in_ckpt, state_dict_converter)
        self.project_name = project_name
        self.run_name = run_name
        self.config = config or {}
        self.validation_prompts = validation_prompts or []
        self.validation_steps = validation_steps
        self.log_image_steps = log_image_steps
        self.num_validation_images = num_validation_images
        self.validation_kwargs = validation_kwargs or {}
        self._wandb_initialized = False
        self._foveation_masks_logged = False

    # -----------------------------------------------------------------
    # WandB plumbing
    # -----------------------------------------------------------------
    def init_wandb(self, accelerator: Accelerator):
        """Start the WandB run on the main process (idempotent).

        The initialised flag is only set where ``wandb.init`` actually ran, so
        the ``log_*`` helpers stay no-ops on every other process.
        """
        if self._wandb_initialized or not accelerator.is_main_process:
            return
        wandb.init(
            project=self.project_name,
            name=self.run_name,
            config=self.config,
            resume="allow",
        )
        self._wandb_initialized = True

    def log_loss(self, loss: float, step: Optional[int] = None):
        """Log ``loss`` as ``train/loss`` (no-op before WandB is initialised)."""
        if not self._wandb_initialized:
            return
        wandb.log({"train/loss": loss}, step=step if step is not None else self.num_steps)

    def log_metrics(self, metrics: dict, step: Optional[int] = None):
        """Log a metrics dict as-is (no-op before WandB is initialised)."""
        if not self._wandb_initialized:
            return
        wandb.log(metrics, step=step if step is not None else self.num_steps)

    def log_images(self, images: List[Any], captions: Optional[List[str]] = None,
                   key: str = "validation"):
        """Log ``images`` under ``key`` at the current step.

        When ``captions`` is omitted, ``"Image <i>"`` captions are generated.
        Images and captions are paired with ``zip``, so the shorter of the
        two sequences determines how many images are logged.
        """
        if not self._wandb_initialized:
            return
        if captions is None:
            captions = [f"Image {i}" for i in range(len(images))]
        wandb_images = [wandb.Image(img, caption=cap) for img, cap in zip(images, captions)]
        wandb.log({key: wandb_images}, step=self.num_steps)

    # -----------------------------------------------------------------
    # Foveation-aware validation
    # -----------------------------------------------------------------
    def _is_foveated_pipeline(self, pipe) -> bool:
        """Return True when ``pipe`` has a truthy ``is_foveated_pipeline`` attribute."""
        return bool(getattr(pipe, "is_foveated_pipeline", False))

    def _log_foveation_masks_once(self, pipe, height: int, width: int):
        """Visualize four foveation masks (left/right/top/bottom) at fixed radius.

        The masks depend only on the validation resolution, so they are logged
        a single time per logger instance; later calls return immediately.
        """
        if getattr(self, "_foveation_masks_logged", False) or not self._wandb_initialized:
            return
        h = height // self._LATENT_DOWNSCALE
        w = width // self._LATENT_DOWNSCALE
        device = getattr(pipe, "device", "cuda")
        mask_images = []
        for name, center in self._FOVEATION_POSITIONS:
            mask = _create_foveation_mask(h, w, center, r=self._FOVEATION_RADIUS, device=device)
            mask_np = mask.float().cpu().numpy() * 255
            mask_pil = Image.fromarray(mask_np.astype(np.uint8))
            mask_images.append(wandb.Image(mask_pil, caption=f"foveation_mask_{name}"))
        if mask_images:
            wandb.log({"validation/foveation_masks": mask_images}, step=self.num_steps)
        self._foveation_masks_logged = True

    def _visualize_target_noise_pred(self, pipe, height: int, width: int):
        """At 25%/50%/75% timesteps, visualize training_target, noise_pred, and L2 error.

        Best-effort: any exception is reported with a printed warning and
        swallowed so that validation never aborts training. The visualization
        reseeds the global torch / CUDA / NumPy RNGs for reproducibility, so
        their previous states are saved here and restored afterwards; the
        training RNG stream is therefore unaffected by validation.
        """
        if not self._wandb_initialized or not hasattr(pipe, "foveated_training_forward"):
            return
        numpy_state = np.random.get_state()
        try:
            # fork_rng snapshots the CPU and listed CUDA generators and restores
            # them on exit, even if the body raises.
            with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
                self._render_target_noise_pred(pipe, height, width)
        except Exception as e:
            print(f"Warning: Target/noise viz failed: {e}")
        finally:
            np.random.set_state(numpy_state)

    def _render_target_noise_pred(self, pipe, height: int, width: int):
        """Compute and log the target / prediction / L2-error images.

        Uses ``pipe.scheduler``, ``pipe.units``, ``pipe.unit_runner``,
        ``pipe.fixed_clean_latent`` and ``pipe.foveated_training_forward``;
        errors from any of them propagate to the caller.
        """
        latent_height = height // self._LATENT_DOWNSCALE
        latent_width = width // self._LATENT_DOWNSCALE
        device = getattr(pipe, "device", "cuda")
        dtype = getattr(pipe, "torch_dtype", torch.bfloat16)

        prompt = self.validation_prompts[0] if self.validation_prompts else ""
        seed = self.validation_kwargs.get("seed", 42)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)

        pipe.scheduler.set_timesteps(1000, training=True)
        ts = pipe.scheduler.timesteps
        # Indices at 25% / 50% / 75% of the schedule; the caption labels are
        # kept as-is even though they read "t0" / "t_mid" / "t_last".
        timestep_ids = [int(len(ts) * 0.25), int(len(ts) * 0.50), int(len(ts) * 0.75)]
        labels = ["t0", "t_mid", "t_last"]

        # Explicit defaults first; validation_kwargs (minus height/width) override them.
        inputs_shared_base = {
            "height": height, "width": width, "prompt": prompt, "cfg_scale": 1.0,
            "embedded_guidance": 1.0, "input_image": None, "rand_device": str(device),
            "seed": seed,
            **{k: v for k, v in self.validation_kwargs.items()
               if k not in ("height", "width")},
        }
        inputs_posi, inputs_nega = {"prompt": prompt}, {"negative_prompt": ""}
        for unit in pipe.units:
            inputs_shared_base, inputs_posi, inputs_nega = pipe.unit_runner(
                unit, pipe, inputs_shared_base, inputs_posi, inputs_nega,
            )

        input_latents = pipe.fixed_clean_latent
        noise = torch.randn_like(input_latents)
        foveation_mask = _create_foveation_mask(
            latent_height, latent_width, (0.0, 0.0),
            r=self._FOVEATION_RADIUS, device=device,
        )
        # Loop-invariant options for the forward call.
        prediction_type = self.validation_kwargs.get("prediction_type", "clean")
        lr_downsample_factor = self.validation_kwargs.get("lr_downsample_factor", 2)

        all_target_vis, all_pred_vis, all_l2_error = [], [], []
        for tid, label in zip(timestep_ids, labels):
            timestep_id = torch.tensor([tid], device=ts.device)
            timestep = ts[timestep_id].to(dtype=dtype, device=device)
            latents = pipe.scheduler.add_noise(input_latents, noise, timestep)
            training_target = pipe.scheduler.training_target(input_latents, noise, timestep)

            inputs_shared = dict(inputs_shared_base)
            inputs_shared["latents"] = latents
            inputs_shared["input_latents"] = input_latents
            inputs_shared["foveation_mask"] = foveation_mask
            inputs = dict(inputs_shared, **inputs_posi)

            noise_pred = pipe.foveated_training_forward(
                inputs, timestep, timestep_id, prediction_type,
                lr_downsample_factor=lr_downsample_factor,
            )
            target_vis = _unpatchify_for_vis(training_target, latent_height, latent_width)
            pred_vis = _unpatchify_for_vis(noise_pred, latent_height, latent_width)
            # Per-pixel Euclidean distance across the visualized channels.
            l2_error = (target_vis - pred_vis).pow(2).sum(dim=1, keepdim=True).sqrt()
            all_target_vis.append((label, target_vis))
            all_pred_vis.append((label, pred_vis))
            all_l2_error.append((label, l2_error))

        target_images, pred_images = [], []
        for (label, t), (_, p) in zip(all_target_vis, all_pred_vis):
            # Shared scale so target and prediction are visually comparable.
            vmin = float(min(t.min().item(), p.min().item()))
            vmax = float(max(t.max().item(), p.max().item()))
            target_images.append(wandb.Image(_tensor_to_pil_vis(t, vmin, vmax)[0],
                                             caption=f"target_{label}"))
            pred_images.append(wandb.Image(_tensor_to_pil_vis(p, vmin, vmax)[0],
                                           caption=f"noise_pred_{label}"))

        l2_images = []
        for label, e in all_l2_error:
            e_3ch = e.expand(-1, 3, -1, -1)  # replicate the single channel to RGB
            l2_max = float(e_3ch.max().item())
            l2_images.append(wandb.Image(_tensor_to_pil_vis(e_3ch, 0.0, l2_max)[0],
                                         caption=f"l2_error_{label}"))

        wandb.log({
            "validation/target": target_images,
            "validation/noise_pred": pred_images,
            "validation/l2_error": l2_images,
        }, step=self.num_steps)

    def _generate_foveated_images(self, pipe, height: int, width: int) -> tuple:
        """Run the foveated pipeline once per prompt and mask position.

        Returns ``(images, captions)``; a failing prompt/position pair is
        reported with a printed warning and skipped.
        """
        if hasattr(pipe, "foveated_training_forward"):
            self._visualize_target_noise_pred(pipe, height, width)
        self._log_foveation_masks_once(pipe, height, width)

        h = height // self._LATENT_DOWNSCALE
        w = width // self._LATENT_DOWNSCALE
        device = getattr(pipe, "device", "cuda")
        images, captions = [], []
        for prompt in self.validation_prompts:
            for pos_name, center in self._FOVEATION_POSITIONS:
                try:
                    mask = _create_foveation_mask(
                        h, w, center, r=self._FOVEATION_RADIUS, device=device,
                    )
                    kwargs = {"prompt": prompt, "foveation_mask": mask,
                              **self.validation_kwargs}
                    result = pipe(**kwargs)
                    img = result[0] if isinstance(result, (list, tuple)) else result
                    cap = f"{prompt[:40]}... [{pos_name}]" if len(prompt) > 40 \
                        else f"{prompt} [{pos_name}]"
                    images.append(img)
                    captions.append(cap)
                except Exception as e:
                    print(f"Warning: Foveated validation failed {pos_name}: {e}")
        return images, captions

    def _generate_plain_images(self, pipe) -> tuple:
        """Run a non-foveated pipeline once per prompt.

        Returns ``(images, captions)``; when the pipeline returns a list or
        tuple, at most ``num_validation_images`` entries are kept per prompt.
        A failing prompt is reported with a printed warning and skipped.
        """
        images, captions = [], []
        for prompt in self.validation_prompts:
            try:
                result = pipe(prompt=prompt, **self.validation_kwargs)
                if isinstance(result, (list, tuple)):
                    for i, img in enumerate(result[: self.num_validation_images]):
                        cap = f"{prompt[:50]}... [{i}]" if len(prompt) > 50 \
                            else f"{prompt} [{i}]"
                        images.append(img)
                        captions.append(cap)
                else:
                    cap = prompt[:50] + "..." if len(prompt) > 50 else prompt
                    images.append(result)
                    captions.append(cap)
            except Exception as e:
                print(f"Warning: Validation failed: {e}")
        return images, captions

    def generate_validation_images(self, accelerator: Accelerator, model: torch.nn.Module):
        """Generate validation images (foveated 4-position grid if pipe is foveated).

        Does nothing on non-main processes, when there are no validation
        prompts, or when the unwrapped model has no ``pipe`` attribute.
        """
        if not accelerator.is_main_process or not self.validation_prompts:
            return
        unwrapped_model = accelerator.unwrap_model(model)
        if not hasattr(unwrapped_model, "pipe"):
            return
        pipe = unwrapped_model.pipe
        height = self.validation_kwargs.get("height", 1024)
        width = self.validation_kwargs.get("width", 1024)

        pipe.eval()
        try:
            with torch.no_grad():
                if self._is_foveated_pipeline(pipe):
                    all_images, all_captions = self._generate_foveated_images(pipe, height, width)
                else:
                    all_images, all_captions = self._generate_plain_images(pipe)
        finally:
            # Always leave eval mode, even if validation raised.
            # NOTE(review): pipe.train() presumably switches every sub-module to
            # train mode, including any that were in eval before validation —
            # confirm this is intended for frozen components.
            pipe.train()

        if all_images:
            self.log_images(all_images, all_captions, key="validation/generated")

    # -----------------------------------------------------------------
    # ModelLogger hooks
    # -----------------------------------------------------------------
    def on_training_start(self, accelerator: Accelerator, model: torch.nn.Module):
        """Initialise WandB and log baseline validation images before training."""
        self.init_wandb(accelerator)
        if accelerator.is_main_process:
            print("Generating baseline validation images (step 0)...")
        self.generate_validation_images(accelerator, model)

    def on_step_end(
        self,
        accelerator: Accelerator,
        model: torch.nn.Module,
        save_steps: Optional[int] = None,
        loss: Optional[float] = None,
        **kwargs,
    ):
        """Advance the step counter, then log, checkpoint and validate as configured.

        Args:
            accelerator: Used for the main-process checks and passed on.
            model: Model to checkpoint / validate.
            save_steps: Save a checkpoint every this many steps; ``None`` or
                ``0`` disables step checkpoints.
            loss: Optional scalar loss, logged as ``train/loss``.
            **kwargs: Extra values; ``int`` / ``float`` entries are logged as
                ``train/<name>``, everything else is ignored.
        """
        self.init_wandb(accelerator)
        self.num_steps += 1
        if accelerator.is_main_process:
            if loss is not None:
                self.log_loss(loss)
            metrics = {f"train/{k}": v for k, v in kwargs.items()
                       if isinstance(v, (int, float))}
            if metrics:
                self.log_metrics(metrics)
        # A falsy save_steps (None or 0) disables saving and avoids a modulo by zero.
        if save_steps and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
        if self.validation_steps and self.validation_steps > 0 \
                and self.num_steps % self.validation_steps == 0:
            self.generate_validation_images(accelerator, model)

    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id: int):
        """Log the epoch index, run the base-class epoch hook, then validate."""
        self.init_wandb(accelerator)
        if accelerator.is_main_process:
            self.log_metrics({"epoch": epoch_id})
        super().on_epoch_end(accelerator, model, epoch_id)
        self.generate_validation_images(accelerator, model)

    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module,
                        save_steps: Optional[int] = None):
        """Run the base-class end hook, log final validation images, close the run."""
        super().on_training_end(accelerator, model, save_steps)
        self.generate_validation_images(accelerator, model)
        if accelerator.is_main_process and self._wandb_initialized:
            wandb.finish()