Spaces:
Sleeping
Sleeping
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from configs.config import Config, SamplingConfig | |
| from utils.sampling_utils import restore_cond, _ode_step, _sde_step | |
| # ============================================ | |
| # Generation utilities | |
| # ============================================ | |
def mask_after_eos(predicted_ids: torch.Tensor, eos_token_id: int, pad_token_id: int) -> torch.Tensor:
    """Replace the first EOS token and everything after it with padding.

    The scan runs along dim 1, independently for every row of dim 0.

    Args:
        predicted_ids: Token ids; dim 1 is treated as the sequence axis.
        eos_token_id: Id that marks end-of-sequence.
        pad_token_id: Id written at and after the first EOS of each row.

    Returns:
        A tensor shaped like ``predicted_ids`` in which positions before the
        first EOS are unchanged and all remaining positions hold
        ``pad_token_id``. Rows without an EOS are returned unchanged.
    """
    # Running count of EOS hits; it becomes non-zero at the first EOS and
    # stays non-zero for the rest of the row.
    eos_seen_so_far = predicted_ids.eq(eos_token_id).to(torch.int32).cumsum(dim=1)
    at_or_after_eos = eos_seen_so_far != 0
    padding = predicted_ids.new_full(predicted_ids.shape, pad_token_id)
    return torch.where(at_or_after_eos, padding, predicted_ids)
def shift_left(x: torch.Tensor, shift_per_sample: torch.Tensor, pad_value=0, axis: int = 1) -> torch.Tensor:
    """Shift each sample left along the sequence axis; pad emptied positions.

    Output position ``j`` of sample ``b`` is read from input position
    ``j + shift_per_sample[b]``. Any output position whose source index falls
    outside ``[0, seq_len)`` is filled with ``pad_value``; this covers both
    the tail emptied by a positive shift and the head emptied by a negative
    (rightward) shift.

    Args:
        x: Tensor with the batch on dim 0 and at least one further dim.
        shift_per_sample: 1-D tensor with one shift per batch row; it is cast
            to ``torch.long`` and moved to ``x``'s device.
        pad_value: Fill value for positions that have no source element.
        axis: Sequence axis to shift along. Negative values count from the
            end. Must not resolve to 0 (the batch axis).

    Returns:
        The shifted tensor, with the same dim order as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than two dims or ``axis`` resolves to 0.
    """
    if x.dim() < 2:
        raise ValueError("x must have at least batch and sequence dimensions")
    if axis < 0:
        axis = x.dim() + axis
    if axis == 0:
        raise ValueError("axis=0 is the batch axis and cannot be shifted")

    # Work with the sequence axis on dim 1; restored before returning.
    moved = axis != 1
    if moved:
        x = x.movedim(axis, 1)

    seq_len = x.shape[1]
    shifts = shift_per_sample.to(device=x.device, dtype=torch.long)
    src_idx = shifts[:, None] + torch.arange(seq_len, device=x.device)[None, :]

    # Check BOTH bounds before clamping: a negative source index would
    # otherwise be clamped to 0 and silently replicate the first element
    # instead of being padded.
    valid = (src_idx >= 0) & (src_idx < seq_len)
    src_idx = src_idx.clamp(0, seq_len - 1)

    # Broadcast the (batch, seq) index and mask over any trailing feature
    # dims. For a 2-D input there are none and this is a no-op.
    trailing = tuple(x.shape[2:])
    bcast_shape = tuple(src_idx.shape) + (1,) * len(trailing)
    idx = src_idx.view(bcast_shape).expand(-1, -1, *trailing)
    keep = valid.view(bcast_shape).expand(-1, -1, *trailing)

    gathered = torch.gather(x, 1, idx)
    shifted = torch.where(keep, gathered, torch.full_like(gathered, pad_value))

    if moved:
        shifted = shifted.movedim(1, axis)
    return shifted
| # ============================================ | |
| # Single-batch sampling (PyTorch) | |
| # ============================================ | |
def _generate_samples_single_batch(
    model: nn.Module,
    generator: torch.Generator,
    z: torch.Tensor,
    t_steps: torch.Tensor,
    cond_seq: Optional[torch.Tensor],
    cond_seq_mask: Optional[torch.Tensor],
    config: Config,
    sampling_config: SamplingConfig,
    cfg_scale: float,
    self_cond_cfg_scale: float,
) -> torch.Tensor:
    """Generate samples for a single batch (PyTorch Euler / SDE rollout).

    Steps ``z`` through consecutive pairs of ``t_steps``. Every step except
    the last uses the method selected by ``sampling_config.sampling_method``
    (``"sde"`` or ``"ode"``); the last step always uses the ODE update.

    Args:
        model: Network forwarded to the step functions.
        generator: RNG passed to the SDE step.
        z: Initial latent of shape ``(batch, max_length, d_model)``.
        t_steps: Time schedule indexed along dim 0; each entry is converted
            to a Python scalar. Needs at least two entries.
        cond_seq: Conditioning sequence. When ``None``, all-zero ``cond_seq``
            and ``cond_seq_mask`` tensors are used (any mask passed in is
            replaced).
        cond_seq_mask: Mask accompanying ``cond_seq``.
        config: Model config; ``use_bf16`` (default ``True``) enables bf16
            autocast when ``z`` is on CUDA.
        sampling_config: Provides ``sampling_method`` and, optionally,
            ``sde_gamma`` (default ``0.0``).
        cfg_scale: Guidance scale forwarded to the step functions.
        self_cond_cfg_scale: Self-conditioning guidance scale forwarded to
            the step functions.

    Returns:
        The latent ``z`` after the final step.

    Raises:
        ValueError: If ``sampling_method`` is neither ``"sde"`` nor ``"ode"``.
    """
    method = sampling_config.sampling_method
    # Validate up front: the rollout loop is empty for schedules with two
    # entries, so an in-loop check would let an invalid method slip through.
    if method not in ("sde", "ode"):
        raise ValueError(f"Invalid sampling method: {method}")

    batch_size, max_length, d_model = z.shape
    if cond_seq is None:
        cond_seq = torch.zeros((batch_size, max_length, d_model), dtype=z.dtype, device=z.device)
        cond_seq_mask = torch.zeros((batch_size, max_length), dtype=z.dtype, device=z.device)

    step_kwargs = dict(
        model=model, config=config,
        cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )

    z = restore_cond(z, cond_seq, cond_seq_mask)
    x_pred = restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)

    n = t_steps.shape[0]
    # Copy the schedule to host once so the per-step .item() calls below do
    # not each trigger a device-to-host transfer.
    t_host = t_steps.detach().cpu()
    sde_gamma = getattr(sampling_config, "sde_gamma", 0.0)
    use_bf16 = bool(getattr(config, "use_bf16", True)) and z.is_cuda

    with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16):
        for i in range(n - 2):
            t = t_host[i].item()
            t_next = t_host[i + 1].item()
            if method == "sde":
                z, x_pred = _sde_step(
                    z=z, t=t, t_next=t_next, x_pred_prev=x_pred,
                    gamma=sde_gamma, generator=generator, **step_kwargs,
                )
            else:
                z, x_pred = _ode_step(z=z, t=t, t_next=t_next, x_pred_prev=x_pred, **step_kwargs)

        # Last step always with ODE.
        t = t_host[-2].item()
        t_next = t_host[-1].item()
        z, x_pred = _ode_step(z=z, t=t, t_next=t_next, x_pred_prev=x_pred, **step_kwargs)

    return z
| def _dlm_decode_batch(z: torch.Tensor, model: nn.Module, t_final_val, | |
| config, self_cond_cfg_scale: float) -> torch.Tensor: | |
| """Decode z -> tokens with the DLM decoder head.""" | |
| batch_size = z.shape[0] | |
| if isinstance(t_final_val, torch.Tensor) and t_final_val.dim() == 0: | |
| t_final = torch.full((batch_size,), t_final_val.item(), dtype=z.dtype, device=z.device) | |
| else: | |
| t_final = torch.full((batch_size,), float(t_final_val), dtype=z.dtype, device=z.device) | |
| sc_batch = ( | |
| torch.full((batch_size,), float(self_cond_cfg_scale), dtype=z.dtype, device=z.device) | |
| if config.num_self_cond_cfg_tokens > 0 else None | |
| ) | |
| z_input = torch.cat([z, torch.zeros_like(z)], dim=-1) if config.self_cond_prob > 0 else z | |
| use_bf16 = bool(getattr(config, "use_bf16", True)) and z.is_cuda | |
| with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16): | |
| _, decoder_logits = model( | |
| z_input, t_final, deterministic=True, | |
| self_cond_cfg_scale=sc_batch, | |
| decoder_step_active=True, | |
| ) | |
| return decoder_logits.argmax(dim=-1) | |
def _build_run_name(sampling_method, num_sampling_steps, cfg_scale, self_cond_cfg_scale,
                    time_schedule, sde_gamma, suffix):
    """Assemble a run name from the sampling settings.

    Layout: ``<method>-steps<N>-cfg<scale>[-sccfg<scale>]-ts_<schedule>[-gamma<g>]-<suffix>``.
    The ``-sccfg`` part appears only when ``self_cond_cfg_scale`` differs
    from 1.0, and the ``-gamma`` part only when ``sampling_method`` is
    ``"sde"``.
    """
    pieces = [f"{sampling_method}-steps{num_sampling_steps}-cfg{cfg_scale}"]
    if self_cond_cfg_scale != 1.0:
        pieces.append(f"-sccfg{self_cond_cfg_scale}")
    pieces.append(f"-ts_{time_schedule}")
    if sampling_method == "sde":
        pieces.append(f"-gamma{sde_gamma}")
    pieces.append(f"-{suffix}")
    return "".join(pieces)