# -*- coding: utf-8 -*-
"""
Generation-related utility functions.
"""
import math
import random
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn.functional as F


def add_gumbel_noise(logits, temperature):
    """
    Add Gumbel noise to logits for categorical sampling.

    According to arXiv:2409.02908, for MDMs a low-precision Gumbel-Max
    improves the perplexity score but degrades generation quality, so
    float64 is used here.
    """
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise
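

# Usage sketch (hypothetical shapes, not part of the original module): taking
# the argmax of exp(logits) / Gumbel-noise is equivalent to Gumbel-Max
# sampling with the noise scaled by `temperature`:
#
#   logits = torch.randn(2, 16, 32000)            # (batch, seq, vocab)
#   x0 = add_gumbel_noise(logits, temperature=0.7).argmax(dim=-1)  # token ids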


def cosine_schedule(t: torch.Tensor) -> torch.Tensor:
    """Cosine schedule: m(t) = cos(π/2 · t), Eq. (3) of the MaskGIT paper."""
    return torch.cos(0.5 * math.pi * t)
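

# Illustrative values (assumed convention: t runs from 0 to 1 over the reverse
# process, so m(t) is the fraction of tokens still masked at time t):
#
#   t = torch.linspace(0.0, 1.0, steps=5)
#   cosine_schedule(t)  # approx. tensor([1.0000, 0.9239, 0.7071, 0.3827, 0.0000])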


def gumbel_noise(t: torch.Tensor, *, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Return i.i.d. Gumbel(0, 1) noise with the same shape as t."""
    if generator is None:
        u = torch.rand_like(t)
    else:
        u = torch.rand(t.shape, device=t.device, dtype=t.dtype, generator=generator)
    return -torch.log(-torch.log(u + 1e-20) + 1e-20)


def gumbel_max_sample(logits: torch.Tensor, tau: float = 1.0, *, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Sample from Categorical(logits) via the Gumbel-Max trick; tau=0 falls back to greedy argmax."""
    if tau == 0.0:
        return logits.argmax(dim=-1)
    g = gumbel_noise(logits, generator=generator)
    return (logits / tau + g).argmax(dim=-1)
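

# Usage sketch (hypothetical shapes): draw token ids from unnormalized logits;
# passing a fixed generator makes the draw reproducible.
#
#   gen = torch.Generator().manual_seed(0)
#   ids = gumbel_max_sample(torch.randn(4, 1000), tau=1.0, generator=gen)  # (4,)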


def mask_by_random_topk(
    mask_len: torch.Tensor,     # (B,) number of tokens to keep masked
    probs: torch.Tensor,        # (B, L) sampled token probabilities
    *,
    temperature: float = 1.0,
    generator: Optional[torch.Generator] = None,
) -> torch.BoolTensor:
    """Return a Boolean mask; True means the position *stays masked* for the next step."""
    g = gumbel_noise(probs, generator=generator)
    confidence = torch.log(probs.clamp_min(1e-20)) + temperature * g  # higher = more confident
    sorted_conf = torch.sort(confidence, dim=-1).values               # ascending
    k = mask_len.long().unsqueeze(1).clamp_(0, probs.size(1) - 1)     # (B, 1)
    cut_off = torch.gather(sorted_conf, 1, k)                         # (B, 1)
    return confidence < cut_off                                       # (B, L)
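

# Sketch of a MaskGIT-style remasking step (assumed usage): after sampling
# tokens, keep the mask_len least-confident positions masked for the next
# iteration. Ties aside, exactly mask_len positions come back True per row:
#
#   probs = torch.rand(2, 8)                        # sampled token probabilities
#   keep_masked = mask_by_random_topk(torch.tensor([3, 5]), probs)
#   keep_masked.sum(dim=-1)                         # tensor([3, 5])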


def get_num_transfer_tokens(mask_index, steps):
    """
    In the reverse process, the interval [0, 1] is discretized uniformly into
    `steps` intervals. Since LLaDA uses a linear noise schedule (Eq. (8)),
    the expected number of tokens transferred at each step should be the same.
    This function precomputes the number of tokens to transfer at every step,
    spreading any remainder over the first few steps.
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens
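

# Example (hypothetical inputs): 10 masked tokens spread over 4 steps gives
# per-step counts [3, 3, 2, 2]; each row always sums to its mask count.
#
#   mask_index = torch.zeros(1, 16, dtype=torch.bool)
#   mask_index[0, :10] = True
#   get_num_transfer_tokens(mask_index, steps=4)    # tensor([[3, 3, 2, 2]])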


def setup_seed(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)  # numpy seeding added; the original only seeded random/torch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
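

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original module):
    # seed everything, then exercise each helper on toy shapes.
    setup_seed(42)
    logits = torch.randn(2, 8, 100)                                   # (B, L, V)
    x0 = add_gumbel_noise(logits, temperature=0.5).argmax(dim=-1)     # (B, L) token ids
    probs = torch.softmax(logits, dim=-1).gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    keep_masked = mask_by_random_topk(torch.tensor([4, 2]), probs)    # (B, L) bool
    mask_index = torch.ones(2, 8, dtype=torch.bool)
    print(x0.shape, keep_masked.sum(dim=-1), get_num_transfer_tokens(mask_index, steps=3))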