File size: 3,399 Bytes
9b58924
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# -*- coding: utf-8 -*-
"""

Generation related utility functions

"""
import math
import torch
import torch.nn.functional as F
import numpy as np
from typing import Callable, Optional


def add_gumbel_noise(logits, temperature):
    """
    Perturb ``logits`` with Gumbel noise so that ``argmax`` draws a sample.

    Rather than adding Gumbel noise in log space, this returns the equivalent
    multiplicative form ``exp(logits) / (-log(u)) ** temperature`` with
    ``u ~ Uniform[0, 1)``. Because log is monotone, taking ``argmax`` of the
    result selects the same index as ``argmax(logits + temperature * g)`` with
    ``g = -log(-log(u))`` standard Gumbel noise.

    According to arXiv:2409.02908, for masked diffusion models (MDM),
    low-precision Gumbel-max improves the perplexity score but reduces
    generation quality, therefore the computation is carried out in float64
    (the returned tensor is float64).

    A ``temperature`` of 0 disables the noise entirely: ``logits`` is returned
    as-is (same object, original dtype).
    """
    if temperature == 0:
        return logits
    logits64 = logits.to(torch.float64)
    uniform = torch.rand_like(logits64, dtype=torch.float64)
    denominator = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(logits64) / denominator


def cosine_schedule(t: torch.Tensor) -> torch.Tensor:
    """Cosine mask schedule m(t) = cos(π/2 · t) – MaskGit paper Eq. (3).

    Applied element-wise to ``t``: t = 0 maps to 1 and t = 1 maps to
    cos(π/2) (≈ 0 up to floating-point error).
    """
    half_pi = math.pi / 2
    return torch.cos(half_pi * t)


def gumbel_noise(t: torch.Tensor, *, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw i.i.d. standard Gumbel(0, 1) samples with the shape of ``t``.

    Uses the inverse-CDF transform ``g = -log(-log(u))`` with
    ``u ~ Uniform[0, 1)`` drawn in ``t``'s dtype and on ``t``'s device. When
    ``generator`` is given it drives the uniform draw; otherwise the global
    RNG is used. An epsilon of 1e-20 is added inside both logarithms to
    avoid log(0).

    NOTE(review): 1e-20 is below float16's smallest subnormal (~6e-8), so for
    float16 inputs the epsilon rounds away and u == 0 would yield -inf;
    consider upcasting if half-precision noise is used.
    """
    eps = 1e-20
    if generator is not None:
        uniform = torch.rand(t.shape, device=t.device, dtype=t.dtype, generator=generator)
    else:
        uniform = torch.rand_like(t)
    neg_log_u = -torch.log(uniform + eps)
    return -torch.log(neg_log_u + eps)


def gumbel_max_sample(logits: torch.Tensor, tau: float = 1.0, *, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Sample category indices from softmax(logits / τ) via the Gumbel-max trick.

    τ = 0 → greedy argmax (no noise is drawn). Otherwise returns
    ``argmax(logits / τ + g)`` over the last dimension, where ``g`` is
    i.i.d. Gumbel(0, 1) noise from ``gumbel_noise`` (drawn with
    ``generator`` when one is supplied). The last dimension of ``logits``
    is reduced away in the result.
    """
    if tau != 0.0:
        noise = gumbel_noise(logits, generator=generator)
        scaled = logits / tau
        return (scaled + noise).argmax(dim=-1)
    return logits.argmax(dim=-1)


def mask_by_random_topk(

    mask_len: torch.Tensor,     # (B,), (B, 1) or scalar: number of tokens to keep masked per row

    probs: torch.Tensor,        # (B, L) sampled token probability

    *,

    temperature: float = 1.0,

    generator: Optional[torch.Generator] = None,

) -> torch.BoolTensor:
    """Return Boolean mask – True means *stay masked* for next step.

    Each position's confidence is ``log(probs) + temperature * g`` with i.i.d.
    Gumbel noise ``g`` from ``gumbel_noise``. In every row, positions whose
    confidence is strictly below the ``mask_len``-th smallest confidence stay
    masked, so (barring ties or non-finite confidences) exactly
    ``clamp(mask_len, 0, L)`` positions per row are True. ``mask_len >= L``
    keeps the whole row masked; ``mask_len <= 0`` unmasks everything.

    Args:
        mask_len: per-row count of positions to keep masked. It is read only;
            the caller's tensor is never modified.
        probs: (B, L) probabilities of the sampled tokens; clamped to 1e-20
            before the log.
        temperature: scale of the Gumbel noise (0 → purely confidence-based).
        generator: optional RNG forwarded to ``gumbel_noise``.

    Returns:
        (B, L) boolean tensor, True where the token stays masked.
    """
    batch, length = probs.shape
    g = gumbel_noise(probs, generator=generator)
    confidence = torch.log(probs.clamp_min(1e-20)) + temperature * g  # higher = more confident
    sorted_conf = torch.sort(confidence, dim=-1).values               # ascending, (B, L)
    # +inf sentinel column: k == L selects a cut-off above every finite
    # confidence, so the whole row stays masked.
    sentinel = sorted_conf.new_full((batch, 1), float("inf"))
    sorted_conf = torch.cat([sorted_conf, sentinel], dim=-1)          # (B, L + 1)
    # Out-of-place clamp: `.long()` returns the caller's tensor itself when it
    # is already int64, so an in-place clamp_ would mutate the caller's mask_len.
    k = mask_len.long().reshape(-1, 1).clamp(0, length).expand(batch, 1)
    cut_off = torch.gather(sorted_conf, 1, k)                         # (B, 1)
    return confidence < cut_off                                       # (B, L)


def get_num_transfer_tokens(mask_index, steps):
    """
    Precompute how many masked tokens to transition at each reverse step.

    In the reverse process, the interval [0, 1] is uniformly discretized into
    ``steps`` intervals. Since LLaDA employs a linear noise schedule (as
    defined in Eq. (8)), the expected number of tokens transitioned at each
    step should be consistent. Each row's masked-token count is therefore
    split as evenly as possible: every step gets ``count // steps`` tokens
    and the first ``count % steps`` steps get one extra.

    Args:
        mask_index: tensor whose dim 1 is summed to count masked positions
            per row; presumably a (batch, seq_len) boolean mask (TODO
            confirm against callers).
        steps: number of reverse-process steps; must be positive.

    Returns:
        (batch, steps) int64 tensor (for boolean/integer masks) on
        ``mask_index.device``. Each row sums to that row's masked count.

    Raises:
        ValueError: if ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    mask_num = mask_index.sum(dim=1, keepdim=True)  # (B, 1)

    base = mask_num // steps
    remainder = mask_num % steps

    # Vectorized instead of a per-row Python loop (which forced a host sync
    # per row on GPU): step j gets one extra token iff j < remainder.
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    extra = (step_ids < remainder).to(torch.int64)                         # (B, steps)

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    return num_transfer_tokens + extra

def setup_seed(seed: int):
    """Seed the PyTorch, Python ``random`` and NumPy global RNGs.

    Seeds the PyTorch CPU generator, all CUDA device generators (silently a
    no-op when CUDA is unavailable), Python's ``random`` module, and NumPy's
    legacy global RNG (``np.random``), so runs that draw from any of them
    are reproducible.

    Args:
        seed: integer seed. NumPy only accepts [0, 2**32 - 1], so it
            receives ``seed`` reduced modulo 2**32.
    """
    import random
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    # Fold wide/negative seeds (valid for torch) into NumPy's accepted range.
    np.random.seed(int(seed) % 2**32)