import torch
import torch.nn as nn
import random
import logging
import copy
from typing import Union, List, Optional
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import GenerationMixin, GenerateDecoderOnlyOutput

logger = logging.getLogger(__name__)

class XTCLogitsWarper(LogitsProcessor):
    """
    LogitsWarper that implements Exclude Top Choices (XTC).
    """
    def __init__(self, threshold: float, probability: float, protected_token_ids: Optional[List[int]] = None, filter_value: float = -float("Inf")):
        self.threshold = threshold
        self.probability = probability
        self.filter_value = filter_value
        self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.probability <= 0.0 or random.random() >= self.probability:
            return scores

        # Sort scores descending
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Boolean mask over the sorted vocabulary (True = remove)
        sorted_indices_to_remove = torch.zeros_like(probs, dtype=torch.bool)

        # XTC: mark token i for removal if the next-ranked token is still above
        # the threshold, i.e. drop all "top choices" except the least likely one
        sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold

        # Scatter back to original indices
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # Safety: if any protected token (e.g. EOS) would be filtered anywhere
        # in the batch, abort XTC entirely for this step
        if self.protected_token_ids:
            for pid in self.protected_token_ids:
                if indices_to_remove[:, pid].any():
                    return scores

        # Apply the filter
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
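
# Illustrative sketch only (not used by the sampler): shows the warper acting
# on a single hand-picked logits row. All values below are made up.
def _demo_xtc_warper():
    # Logits chosen so softmax gives roughly [0.45, 0.35, 0.15, 0.05]
    logits = torch.tensor([[-0.80, -1.05, -1.90, -3.00]])
    warper = XTCLogitsWarper(threshold=0.3, probability=1.0)
    filtered = warper(torch.empty(1, 0, dtype=torch.long), logits)
    # Token 0 (p ~ 0.45) is masked to -inf because token 1 (p ~ 0.35) also
    # clears the threshold; token 1 survives as the least likely "top choice",
    # and tokens 2 and 3 were never candidates for removal.
    return filtered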

def _xtc_decoding(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
    """
    Custom decoding loop that ensures XTC is applied during sampling.
    """
    
    # 1. Retrieve XTC params from the config (injected by the generate wrapper)
    xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
    xtc_probability = getattr(generation_config, "xtc_probability", 0.0)
    
    # Identify tokens to protect
    protected_ids = []
    if generation_config.eos_token_id is not None:
        if isinstance(generation_config.eos_token_id, list):
            protected_ids.extend(generation_config.eos_token_id)
        else:
            protected_ids.append(generation_config.eos_token_id)
            
    # Check for custom protected tokens injected via config
    custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
    if custom_protected:
        protected_ids.extend(custom_protected)

    # 2. Inject XTC into the LogitsProcessorList
    if xtc_probability > 0:
        xtc_warper = XTCLogitsWarper(
            threshold=xtc_threshold,
            probability=xtc_probability,
            protected_token_ids=protected_ids
        )
        logits_processor.append(xtc_warper)

    # 3. Initialization
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
    
    # This loop always samples via torch.multinomial, so no do_sample flag is
    # consulted here; the wrapper also forces generation_config.do_sample = True.

    # Init output tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Track finished sequences
    batch_size, cur_length = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)

    this_peer_finished = False

    # 4. Decoding Loop
    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # Apply Logits Processors (XTC happens here)
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store scores
        if return_dict_in_generate and output_scores:
            scores += (next_token_scores,)
        if return_dict_in_generate and output_attentions:
            decoder_attentions += ((outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,))
        if return_dict_in_generate and output_hidden_states:
            decoder_hidden_states += ((outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,))

        # Sample (Multinomial)
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # EOS check
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Update inputs
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        return input_ids

def generate(model, *args, **kwargs):
    """
    Wrapper function that prepares parameters and calls the internal decoding loop.
    """
    # 1. Extract XTC parameters from kwargs using .pop()
    # This prevents the "unused model_kwargs" warning because they are removed from kwargs
    xtc_probability = kwargs.pop("xtc_probability", 0.0)
    xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
    xtc_protected_tokens = kwargs.pop("xtc_protected_tokens", None)

    # 2. Prepare GenerationConfig
    # We must handle the case where generation_config is None or not present
    generation_config = kwargs.get("generation_config", None)

    if generation_config is None:
        # If no config was passed, copy the model's default
        generation_config = copy.deepcopy(model.generation_config)
            
    # Force sampling (XTC doesn't work with greedy)
    generation_config.do_sample = True
    
    # 3. Inject XTC params into the config object so the decoding loop can
    # read them back (GenerationConfig tolerates extra attributes)
    generation_config.xtc_probability = xtc_probability
    generation_config.xtc_threshold = xtc_threshold
    generation_config.xtc_protected_tokens = xtc_protected_tokens
    
    # Update kwargs with the modified config
    kwargs["generation_config"] = generation_config

    # 4. Call standard generation, which will route to `custom_generate` (_xtc_decoding)
    # We pass _xtc_decoding as the function to execute
    return GenerationMixin.generate(
        model, *args, custom_generate=_xtc_decoding, **kwargs
    )
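
# Usage sketch (hypothetical checkpoint and XTC values, for illustration only;
# requires a transformers version whose `generate` accepts a callable
# `custom_generate`):
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tok("Once upon a time", return_tensors="pt")
    out = generate(
        lm, **inputs, max_new_tokens=50, xtc_probability=0.5, xtc_threshold=0.1
    )
    print(tok.decode(out[0], skip_special_tokens=True))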