Todokete committed on
Commit 9468637 · verified · 1 Parent(s): cd22ae8

Upload generate.py

Files changed (1)
  1. custom_generate/generate.py +214 -0
custom_generate/generate.py ADDED
@@ -0,0 +1,214 @@
import torch
import torch.nn as nn
import random
import logging
import copy
from typing import Union, List, Optional

from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerationMixin, GenerateDecoderOnlyOutput

logger = logging.getLogger(__name__)


class XTCLogitsWarper(LogitsProcessor):
    """
    LogitsWarper that implements Exclude Top Choices (XTC).

    With probability `probability`, XTC removes the most probable tokens from the
    distribution whenever at least one alternative at or above `threshold` remains,
    nudging sampling away from the model's most predictable continuations.
    """

    def __init__(
        self,
        threshold: float,
        probability: float,
        protected_token_ids: Optional[List[int]] = None,
        filter_value: float = -float("Inf"),
    ):
        self.threshold = threshold
        self.probability = probability
        self.filter_value = filter_value
        self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Apply XTC only on a random fraction of steps, controlled by `probability`
        if self.probability <= 0.0 or random.random() >= self.probability:
            return scores

        # Sort scores descending
        sorted_logits, sorted_indices = torch.sort(scores, descending=True)
        probs = sorted_logits.softmax(dim=-1)

        # Create a mask for removal
        sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool)

        # XTC logic: remove position i if the next-most-probable token still clears the threshold
        sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold

        # Scatter back to original vocabulary indices
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        # Safety: if any protected token (e.g. EOS) would be removed, abort XTC for this step
        if self.protected_token_ids:
            for pid in self.protected_token_ids:
                if indices_to_remove[:, pid].any():
                    return scores

        # Apply the filter
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
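
# Worked example of the masking rule above (illustrative numbers, not model
# output): with threshold=0.1 and sorted probs [0.50, 0.30, 0.15, 0.05],
# probs[..., 1:] >= 0.1 yields [True, True, False], so the top two tokens are
# masked and sampling renormalizes over [0.15, 0.05]; the least likely token
# that still clears the threshold always survives, so the distribution can
# never become empty.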


def _xtc_decoding(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool = False,
    streamer: Optional[BaseStreamer] = None,
    **model_kwargs,
) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
    """
    Custom decoding loop that ensures XTC is applied during sampling.
    """

    # 1. Retrieve XTC params from the config (injected by the generate wrapper)
    xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
    xtc_probability = getattr(generation_config, "xtc_probability", 0.0)

    # Identify tokens to protect
    protected_ids = []
    if generation_config.eos_token_id is not None:
        if isinstance(generation_config.eos_token_id, list):
            protected_ids.extend(generation_config.eos_token_id)
        else:
            protected_ids.append(generation_config.eos_token_id)

    # Check for custom protected tokens injected via config
    custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
    if custom_protected:
        protected_ids.extend(custom_protected)
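
    # (Illustrative: passing xtc_protected_tokens=[newline_id] would shield
    # newline tokens the same way EOS is shielded above; the concrete ids are
    # tokenizer-specific and not part of this file.)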

    # 2. Inject XTC into the LogitsProcessorList
    if xtc_probability > 0:
        xtc_warper = XTCLogitsWarper(
            threshold=xtc_threshold,
            probability=xtc_probability,
            protected_token_ids=protected_ids,
        )
        logits_processor.append(xtc_warper)

    # 3. Initialization
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

    # Sampling is always on in this loop; the multinomial draw below never falls back to greedy
    do_sample = True

    # Init output tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Track finished sequences
    batch_size, cur_length = input_ids.shape[:2]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)

    this_peer_finished = False

    # 4. Decoding loop
    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Update model kwargs (cache, positions) before a possible skip, so the
        # state stays in sync across GPUs when this peer has already finished
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )
        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # Apply logits processors (XTC happens here)
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store scores
        if return_dict_in_generate and output_scores:
            scores += (next_token_scores,)
        if return_dict_in_generate and output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
            )
        if return_dict_in_generate and output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
            )

        # Sample (multinomial)
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # Finished sequences keep emitting the pad token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Update inputs
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        return GenerateDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
            past_key_values=model_kwargs.get("past_key_values"),
        )
    else:
        return input_ids


def generate(model, *args, **kwargs):
    """
    Wrapper function that prepares parameters and calls the internal decoding loop.
    """
    # 1. Extract XTC parameters from kwargs with .pop() so they are removed from
    #    kwargs and never forwarded as model kwargs (avoids the "unused
    #    model_kwargs" validation error)
    xtc_probability = kwargs.pop("xtc_probability", 0.0)
    xtc_threshold = kwargs.pop("xtc_threshold", 0.1)
    xtc_protected_tokens = kwargs.pop("xtc_protected_tokens", None)

    # 2. Prepare the GenerationConfig, falling back to the model's default when
    #    none was passed
    generation_config = kwargs.get("generation_config")
    if generation_config is None:
        generation_config = copy.deepcopy(model.generation_config)

    # Force sampling (XTC is a no-op under greedy decoding)
    generation_config.do_sample = True

    # 3. Inject the XTC params as dynamic attributes on the config object
    generation_config.xtc_probability = xtc_probability
    generation_config.xtc_threshold = xtc_threshold
    generation_config.xtc_protected_tokens = xtc_protected_tokens

    # Update kwargs with the modified config
    kwargs["generation_config"] = generation_config

    # 4. Call standard generation, routing the decoding loop to `_xtc_decoding`
    return GenerationMixin.generate(
        model, *args, custom_generate=_xtc_decoding, **kwargs
    )
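
A minimal usage sketch, not part of the commit: it assumes the repo is cloned locally so the file is importable as custom_generate.generate, that the installed transformers version accepts a callable custom_generate, and it uses "gpt2" purely as a placeholder checkpoint.

from transformers import AutoModelForCausalLM, AutoTokenizer
from custom_generate.generate import generate

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Once upon a time", return_tensors="pt")
output_ids = generate(
    model,
    **inputs,
    max_new_tokens=50,
    xtc_probability=0.5,  # apply XTC on roughly half of the decoding steps
    xtc_threshold=0.1,    # tokens below this probability are never viable alternatives
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Because the wrapper pops the xtc_* kwargs before delegating to GenerationMixin.generate, the remaining arguments behave exactly like a normal model.generate call.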