Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| import torch | |
| import torch.nn.functional as F | |
| # ============================================ | |
| # Noise Schedulers (how to compute z from x0 and noise) | |
| # ============================================ | |
def add_noise(x0, noise, t, config, cond_seq_mask=None):
    """Build the noisy flow-matching input ``z`` from clean data and noise.

    Computes ``z = t * x0 + (1 - t) * noise * config.denoiser_noise_scale`` and,
    when ``cond_seq_mask`` is given, blends ``x0`` back in where the mask is set
    so that conditioning tokens stay clean.

    Args:
        x0: Clean data tensor. ``t`` is reshaped to ``(-1, 1, 1)``, so a 3-D
            tensor with the batch on dim 0 is assumed -- TODO confirm.
        noise: Noise tensor broadcastable against ``x0``.
        t: Interpolation time, one value per batch element.
        config: Object exposing ``denoiser_noise_scale``.
        cond_seq_mask: Optional mask selecting conditioning positions. May be
            bool, integer or float, and may have fewer dims than ``z``
            (trailing singleton dims are appended, as ``restore_cond`` does).
            Fractional float masks blend linearly between ``x0`` and ``z``.

    Returns:
        The interpolated tensor ``z``.
    """
    t_expanded = t.reshape(-1, 1, 1)
    z = t_expanded * x0 + (1 - t_expanded) * noise * config.denoiser_noise_scale
    if cond_seq_mask is None:
        return z

    mask = cond_seq_mask
    # ``1 - mask`` is not defined for bool tensors; promote non-float masks so
    # the arithmetic blend below works for every mask dtype.
    if not mask.is_floating_point():
        mask = mask.to(z.dtype)
    # Align a (batch, seq) style mask with z by appending singleton dims.
    while mask.dim() < z.dim():
        mask = mask.unsqueeze(-1)
    return mask * x0 + (1 - mask) * z
| # ============================================ | |
| # Time Schedulers (how to sample t) | |
| # ============================================ | |
def sample_timesteps(
    batch_size: int,
    P_mean: float = -0.8,
    P_std: float = 0.8,
    time_schedule: str = 'logit_normal',
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
):
    """Draw ``batch_size`` random timesteps for the requested schedule.

    Args:
        batch_size: Number of timesteps to draw.
        P_mean: Mean of the Gaussian that is pushed through the sigmoid
            (``'logit_normal'`` only).
        P_std: Standard deviation of that Gaussian (``'logit_normal'`` only).
        time_schedule: ``'logit_normal'`` or ``'uniform'``.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        1-D tensor of ``batch_size`` timesteps in [0, 1].

    Raises:
        ValueError: If ``time_schedule`` is not a known schedule name.
    """
    if time_schedule not in ('logit_normal', 'uniform'):
        raise ValueError(f"Unknown time_schedule: {time_schedule}")

    shape = (batch_size,)
    if time_schedule == 'uniform':
        return torch.rand(shape, dtype=dtype, device=device)

    # Logit-normal: a Gaussian sample squashed into (0, 1) by the sigmoid.
    gaussian = torch.randn(shape, dtype=dtype, device=device)
    return torch.sigmoid(gaussian * P_std + P_mean)
def get_sampling_steps(
    n_steps: int, time_schedule: str = "logit_normal",
    P_mean: float = -0.8, P_std: float = 0.8,
    device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the increasing grid of ``n_steps + 1`` time values for one sampling run.

    - "uniform": deterministic, evenly spaced values from 0 to 1.
    - "logit_normal": random; ``n_steps - 1`` logit-normal draws are sorted and
      placed between fixed 0 and 1 endpoints.

    Raises:
        ValueError: If ``time_schedule`` is not a known schedule name.
    """
    if time_schedule == "logit_normal":
        interior = sample_timesteps(
            batch_size=n_steps - 1,
            P_mean=P_mean, P_std=P_std, time_schedule=time_schedule,
            device=device, dtype=dtype,
        )
        interior, _ = torch.sort(interior)
        # Endpoints live on whatever device the samples ended up on.
        first = torch.zeros((1,), dtype=dtype, device=interior.device)
        last = torch.ones((1,), dtype=dtype, device=interior.device)
        return torch.cat((first, interior, last), dim=0)

    if time_schedule == "uniform":
        return torch.linspace(0.0, 1.0, n_steps + 1, dtype=dtype, device=device)

    raise ValueError(f"Unknown time_schedule: {time_schedule}")
| # ============================================ | |
| # CFG Scale Sampling (how to sample cfg scale) | |
| # ============================================ | |
def sample_cfg_scale(batch_size, cfg_min=0.0, cfg_max=3.0,
                     dtype=torch.float32, device=None):
    """Sample ``batch_size`` CFG scales in [cfg_min, cfg_max].

    The quantity ``1 + scale`` is drawn log-uniformly between ``1 + cfg_min``
    and ``1 + cfg_max``; shifting by one lets ``cfg_min`` be zero.
    """
    uniform = torch.rand((batch_size,), dtype=dtype, device=device)
    shifted_lo = float(1.0 + cfg_min)
    shifted_hi = float(1.0 + cfg_max)
    # log(hi / lo), computed as a tensor in the requested dtype.
    log_span = torch.log(
        torch.tensor(shifted_hi / shifted_lo, dtype=dtype, device=uniform.device)
    )
    return shifted_lo * torch.exp(uniform * log_span) - 1.0
| # ============================================ | |
| # Conditioning helpers (preserve clean tokens during sampling) | |
| # ============================================ | |
def restore_cond(z_updated, cond_seq, cond_seq_mask):
    """Restore clean conditioning tokens in ``z_updated`` after a denoising step.

    Positions where ``cond_seq_mask > 0`` take their value from ``cond_seq``;
    all other positions keep the value from ``z_updated``.

    Args:
        z_updated: Tensor produced by the sampler step.
        cond_seq: Tensor of clean conditioning values, or ``None``.
        cond_seq_mask: Mask of conditioning positions, or ``None``. A mask with
            fewer dims than the data gets trailing singleton dims appended.

    Returns:
        The merged tensor, or ``z_updated`` itself when there is no
        conditioning (``cond_seq`` or ``cond_seq_mask`` is ``None``).
    """
    # Nothing to restore without conditioning; this matches how restore_vx
    # treats cond_seq=None.
    if cond_seq is None or cond_seq_mask is None:
        return z_updated

    mask = cond_seq_mask
    target_ndim = max(z_updated.dim(), cond_seq.dim())
    while mask.dim() < target_ndim:
        mask = mask.unsqueeze(-1)
    return torch.where(mask > 0, cond_seq, z_updated)
def restore_vx(v, x, cond_seq, cond_seq_mask):
    """Pin conditioning positions in a ``(v, x)`` prediction pair.

    At masked positions ``x`` is replaced by the clean ``cond_seq`` and ``v`` by
    zero (conditioning tokens do not move). With ``cond_seq=None`` both tensors
    are returned untouched.

    Returns:
        Tuple ``(v, x)`` in that order.
    """
    if cond_seq is None:
        return v, x

    x_pinned = restore_cond(x, cond_seq, cond_seq_mask)
    v_pinned = restore_cond(v, torch.zeros_like(cond_seq), cond_seq_mask)
    return v_pinned, x_pinned
| # ============================================ | |
| # Flow-matching forward passes (with optional self-cond / CFG) | |
| # ============================================ | |
def net_out_to_v_x(net_out, z, t, t_eps=5e-2):
    """Turn an x-prediction network output into ``(v, x)``.

    If the model returned a tuple, only its first element (the denoised output)
    is used; the remaining entries (decoder logits) are dropped here.

    The velocity is ``(x - z) / max(1 - t, t_eps)``; the clamp keeps the
    division finite as ``t`` approaches 1.

    Args:
        net_out: Predicted clean tensor, or a tuple whose first item is it.
        z: Current noisy state.
        t: Per-sample time, reshaped to ``(-1, 1, 1)`` for broadcasting.
        t_eps: Lower bound applied to ``1 - t``.

    Returns:
        Tuple ``(v, x)``.
    """
    x = net_out[0] if isinstance(net_out, tuple) else net_out
    remaining = 1.0 - t.reshape(-1, 1, 1)
    v = (x - z) / remaining.clamp(min=t_eps)
    return v, x
def _forward_sample_self_cond(
    model, z, t_batch, x_pred_prev, config,
    self_cond_cfg_scale, cond_seq, cond_seq_mask,
):
    """Run the model for one sampling step, handling self-conditioning.

    Three modes, selected from ``config``:

    1. ``config.num_self_cond_cfg_tokens > 0``: a single pass on
       ``cat([z, prev], dim=-1)`` with the guidance scale handed to the model
       as a per-sample tensor (``self_cond_cfg_scale=`` keyword).
    2. ``config.self_cond_prob == 0``: a single pass on ``z`` alone.
    3. Otherwise: up to two passes -- one with an empty previous prediction and
       one with ``x_pred_prev`` -- combined as
       ``empty + scale * (with_prev - empty)``. Passes whose result would not
       be used (scale 0, scale 1, or no previous prediction) are skipped.

    The "empty" previous prediction is zeros with the clean conditioning tokens
    written in. Every returned ``(v, x)`` has its conditioning positions
    restored via ``restore_vx``.

    Returns:
        Tuple ``(v, x)``.
    """
    t_eps = config.t_eps
    self_cond_prob = config.self_cond_prob

    def _predict(net_input, **model_kwargs):
        # One deterministic model call -> (v, x) with cond positions restored.
        raw = model(net_input, t_batch, deterministic=True, **model_kwargs)
        v, x = net_out_to_v_x(raw, z, t_batch, t_eps)
        return restore_vx(v, x, cond_seq=cond_seq, cond_seq_mask=cond_seq_mask)

    def _empty_prev():
        # Stand-in for "no previous prediction": zeros plus clean cond tokens.
        return restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)

    # Mode 1: guidance scale is a model input, so one pass is enough.
    if config.num_self_cond_cfg_tokens > 0:
        prev = _empty_prev() if x_pred_prev is None else x_pred_prev
        stacked = torch.cat([z, prev], dim=-1)
        scale_per_sample = torch.full(
            (z.shape[0],), float(self_cond_cfg_scale),
            dtype=z.dtype, device=z.device,
        )
        return _predict(stacked, self_cond_cfg_scale=scale_per_sample)

    # Mode 2: no self-conditioning at all.
    if self_cond_prob == 0:
        return _predict(z)

    # Mode 3: explicit guidance between the "empty" and "with previous" passes.
    v_empty = x_empty = None
    if self_cond_cfg_scale != 1 or x_pred_prev is None:
        v_empty, x_empty = _predict(torch.cat([z, _empty_prev()], dim=-1))
    if self_cond_cfg_scale == 0.0 or x_pred_prev is None:
        return v_empty, x_empty

    v_prev, x_prev = _predict(torch.cat([z, x_pred_prev], dim=-1))
    if self_cond_cfg_scale == 1:
        return v_prev, x_prev

    v_mix = v_empty + self_cond_cfg_scale * (v_prev - v_empty)
    x_mix = x_empty + self_cond_cfg_scale * (x_prev - x_empty)
    return restore_vx(v_mix, x_mix, cond_seq=cond_seq, cond_seq_mask=cond_seq_mask)
def _forward_sample(
    model, z, t_batch, x_pred_prev, config,
    cfg_scale, self_cond_cfg_scale, cond_seq, cond_seq_mask,
):
    """Forward pass with optional self-conditioning and classifier-free guidance.

    A conditional pass is always run. When guidance is active, a second pass is
    run with the conditioning positions of ``z`` (and of ``x_pred_prev``)
    zeroed and an all-zero ``cond_seq``, and the two are combined as
    ``uncond + cfg_scale * (cond - uncond)``. The combined result then has the
    real conditioning tokens restored.

    The second pass is skipped when ``cfg_scale == 1.0`` (the formula reduces
    to the conditional result) or when ``cond_seq`` is ``None`` (there is no
    conditioning to remove).

    Returns:
        Tuple ``(v, x)``.
    """
    v_cond, x_cond = _forward_sample_self_cond(
        model, z, t_batch, x_pred_prev, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    # Without conditioning there is no unconditional variant to build, and
    # torch.zeros_like(None) below would raise.
    if cfg_scale == 1.0 or cond_seq is None:
        return v_cond, x_cond

    # Unconditional forward: the cond prefix is zeroed in the input, in the
    # self-cond state, and in the sequence restored after the pass.
    z_uncond = restore_cond(z, torch.zeros_like(z), cond_seq_mask)
    x_pred_prev_uncond = (
        None if x_pred_prev is None
        else restore_cond(x_pred_prev, torch.zeros_like(x_pred_prev), cond_seq_mask)
    )
    v_uncond, x_uncond = _forward_sample_self_cond(
        model, z_uncond, t_batch, x_pred_prev_uncond, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=torch.zeros_like(cond_seq), cond_seq_mask=cond_seq_mask,
    )

    v_out = v_uncond + cfg_scale * (v_cond - v_uncond)
    x_out = x_uncond + cfg_scale * (x_cond - x_uncond)
    return restore_vx(v_out, x_out, cond_seq, cond_seq_mask)
def _ode_step(
    model, z, t, t_next, x_pred_prev,
    config, cfg_scale, self_cond_cfg_scale,
    cond_seq, cond_seq_mask,
):
    """Advance ``z`` from time ``t`` to ``t_next`` with one Euler step.

    The scalar ``t`` is broadcast to one value per batch element, the model is
    evaluated through ``_forward_sample``, and the state is moved along the
    predicted velocity by ``t_next - t``.

    Returns:
        Tuple ``(z_next, x_pred)``: the updated state and the model's clean
        prediction for this step.
    """
    batch = z.shape[0]
    t_batch = torch.full((batch,), float(t), dtype=z.dtype, device=z.device)
    velocity, x_pred = _forward_sample(
        model=model, z=z, t_batch=t_batch, x_pred_prev=x_pred_prev,
        config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    dt = t_next - t
    z_next = z + dt * velocity
    return z_next, x_pred
def _sde_step(
    model, z, t, t_next, x_pred_prev,
    config, cfg_scale, self_cond_cfg_scale,
    cond_seq, cond_seq_mask, gamma, generator,
):
    """Per-step SDE-style sampler with hybrid (t-and-step) noise scaling.

    The state is first pushed back toward noise, then advanced to ``t_next``
    with one Euler step:

    * ``h = t_next - t`` and ``alpha = 1 - gamma * h`` clamped to [0, 1];
      ``alpha`` is the signal-preservation fraction, constant in ``t``.
    * ``t_back = alpha * t`` and ``z_back = alpha * z + (1 - alpha) * eps``,
      where ``eps`` is Gaussian noise times ``config.denoiser_noise_scale``;
      conditioning positions of ``z_back`` are reset to ``cond_seq``.
    * The model is evaluated at ``(z_back, t_back)`` and the new state is
      ``z_back + (t_next - t_back) * v_pred``.

    ``gamma=0`` gives ``alpha=1`` and degenerates to a plain ODE step.
    Uniform-N-step equivalence with the old multiplicative ``gamma_old``:
    ``gamma_hybrid = gamma_old * N``.

    Args:
        gamma: Noise re-injection strength.
        generator: Optional ``torch.Generator``. When given, noise is drawn on
            the generator's own device and then moved to ``z.device``, so the
            generator is honoured wherever ``z`` lives. With ``None`` the
            default RNG of ``z.device`` is used.

    Returns:
        Tuple ``(z_next, x_pred)``.
    """
    h = float(t_next - t)
    alpha = max(0.0, min(1.0, 1.0 - gamma * h))
    t_back = alpha * float(t)

    if generator is None:
        eps = torch.randn(z.shape, dtype=z.dtype, device=z.device)
    else:
        # torch.randn requires the generator and output device to match, so
        # sample where the generator lives and transfer afterwards.
        eps = torch.randn(
            z.shape, generator=generator, dtype=z.dtype, device=generator.device,
        ).to(z.device)
    eps = eps * config.denoiser_noise_scale

    z_back = restore_cond(alpha * z + (1.0 - alpha) * eps, cond_seq, cond_seq_mask)
    t_batch = torch.full((z.shape[0],), t_back, dtype=z.dtype, device=z.device)
    v_pred, x_pred = _forward_sample(
        model=model, z=z_back, t_batch=t_batch, x_pred_prev=x_pred_prev,
        config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    return z_back + (t_next - t_back) * v_pred, x_pred