File size: 4,311 Bytes
d04a061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import sys
import torch
import random
import numpy as np
from tqdm import tqdm
from src.utils.model_utils import _print

class UnconditionalSampler:
    """Unconditional sequence sampler for a masked (diffusion-style) language model.

    Starts from a (partially) masked token tensor and iteratively refines it:
    at each step the model proposes tokens for every position, low-confidence
    non-fixed positions are stochastically re-masked, and the rest are committed.
    """

    def __init__(self, tokenizer, model):
        """
        Args:
            tokenizer: Tokenizer exposing ``mask_token_id``.
            model: Callable taking ``input_ids``/``attention_mask`` and returning
                per-position logits; must expose a ``device`` attribute.
        """
        self.model = model
        self.tokenizer = tokenizer

        self.device = self.model.device
        self.mask_id = self.tokenizer.mask_token_id
        self.seed_everything(seed=42)

    @torch.inference_mode()
    def sample_unconditional(self, xt, num_steps, tau=0.7, kappa_fn=lambda t: t, eta=1, alpha=1., banned_token_ids=None, return_logits=None):
        """
        Stochastic remasking sampling method for iterative refinement of sequences.

        Args:
            xt (Tensor): Initial token tensor; positions equal to the mask id are
                generated, all other positions are kept fixed. Modified in place.
            num_steps (int): Number of refinement steps (must be >= 1).
            tau (float): Temperature for sampling; 0 means deterministic argmax.
            kappa_fn (callable): Unmasking schedule; maps t in (0, 1] to the
                fraction of non-fixed tokens that should be unmasked by time t.
            eta (float): Scaling factor applied to the scores of tokens that were
                already unmasked in earlier steps (remasking candidates).
            alpha (float): Weight between log-prob and negative entropy in the
                confidence score (alpha = 1 --> score = logp).
            banned_token_ids (Iterable[int] | None): Token ids never to sample.
            return_logits (bool | None): When truthy, also return the logits from
                the final refinement step.

        Returns:
            Tensor: Final sampled sequence tensor, or the tuple
            ``(sequence, final_logits)`` when ``return_logits`` is truthy.

        Raises:
            ValueError: If ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            # num_steps == 0 would divide by zero below and leave x0/logits undefined.
            raise ValueError("num_steps must be >= 1")

        dt = 1 / num_steps
        fix_mask = xt != self.mask_id  # tokens to retain; never remasked
        attention_mask = torch.ones_like(xt).to(self.device)

        for i in range(1, num_steps + 1):
            kappa_t = kappa_fn(i * dt)
            logits = self.model(input_ids=xt, attention_mask=attention_mask)
            last_mask = xt == self.mask_id  # tokens currently masked
            unmask_t = ~last_mask & ~fix_mask  # unmasked, non-fixed tokens - candidates for remasking

            x0, logp = self.stochastic_sample_from_categorical(logits, tau, banned_token_ids=banned_token_ids)  # tokens, logprobs

            # Confidence-based scoring: lower score => more likely to be remasked.
            entropy = torch.distributions.Categorical(logits=logits).entropy()
            score = alpha * logp + (1 - alpha) * -entropy  # alpha = 1 --> score = logp
            # Fixed tokens get +inf so they can never fall below the remask cutoff.
            score = score.masked_fill(fix_mask, float('inf'))

            # Re-weight tokens that were already committed in earlier steps.
            score[unmask_t] = score[unmask_t] * eta

            # How many non-fixed positions should still be masked at time t.
            num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long()
            lowest_k_mask = self.topk_lowest_masking(score, num_to_mask)

            xt[lowest_k_mask] = self.mask_id
            # Positions that were masked and survived remasking receive their sample.
            mask_2_x0 = last_mask & ~lowest_k_mask
            xt[mask_2_x0] = x0[mask_2_x0]

        # Fill any still-masked positions with the final step's samples.
        still_masked = xt == self.mask_id
        xt[still_masked] = x0[still_masked]
        # BUGFIX: the previous `return xt, logits if return_logits else xt` parsed
        # as `return xt, (logits if return_logits else xt)` and always returned a
        # tuple. Parenthesize so a bare tensor is returned when logits are not wanted.
        return (xt, logits) if return_logits else xt

    def stochastic_sample_from_categorical(self, logits, temperature, noise_scale=1.0, banned_token_ids=None):
        """
        Sample per-position tokens from categorical distributions via the
        Gumbel-max trick, with optional temperature scaling and banned tokens.

        Args:
            logits (Tensor): Unnormalized log-probabilities, vocab on the last dim.
            temperature (float): 0 -> deterministic argmax (no noise); otherwise
                logits are divided by it before the Gumbel perturbation.
            noise_scale (float): Multiplier on the Gumbel noise term.
            banned_token_ids (Iterable[int] | None): Token ids forced to -inf.

        Returns:
            (Tensor, Tensor): Sampled token ids and their log-probabilities under
            the (possibly perturbed) log-softmax distribution.
        """
        logits = logits.double()  # double precision keeps the log-log noise stable

        if banned_token_ids is not None:
            # One vectorized fill along the vocab dim instead of a Python loop.
            banned_token_mask = torch.zeros_like(logits, dtype=torch.bool)
            banned_token_mask[..., list(banned_token_ids)] = True
            logits = logits.masked_fill(banned_token_mask, float('-inf'))

        if temperature != 0:
            # Gumbel-max trick: argmax of noise-perturbed logits is a categorical sample.
            gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
            logits = logits / temperature + noise_scale * gumbel_noise
        scores, tokens = logits.log_softmax(dim=-1).max(dim=-1)

        return tokens, scores

    def topk_lowest_masking(self, scores, cutoff_len):
        """
        Build a boolean mask selecting, per row, the ``cutoff_len`` lowest scores.

        Args:
            scores (Tensor): [b, n] per-token scores.
            cutoff_len (Tensor): [b, 1] long tensor, number of tokens per row.

        Returns:
            Tensor: [b, n] mask, True where the token is among the ``cutoff_len``
            lowest-scoring entries of its row.

        Note: comparison is strict (<), so ties at the cutoff value are excluded
        and fewer than ``cutoff_len`` positions may be selected on equal scores.
        """
        sorted_scores = scores.sort(-1)[0]  # ascending values per row
        cutoff = sorted_scores.gather(dim=-1, index=cutoff_len)  # k-th lowest value
        return scores < cutoff

    def seed_everything(self, seed):
        """
        Set the seed for reproducibility across Python, NumPy and PyTorch.

        A ``None`` seed is a no-op, leaving the global RNG state untouched.
        """
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)  # if using multi-GPU
        # Trade cuDNN autotuning for deterministic kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False