Todokete committed
Commit 055a45c · verified · 1 Parent(s): 1973671

Update custom_generate/generate.py

Files changed (1)
  1. custom_generate/generate.py +214 -92
custom_generate/generate.py CHANGED
@@ -1,101 +1,223 @@
  import torch
  import random

- def generate(model, input_ids, generation_config=None, **kwargs):
-     print("✨ using XTC (Exclude Top Choices) generation ✨")
-
-     generation_config = generation_config or model.generation_config
-
-     # Setup generation parameters
-     cur_length = input_ids.shape[1]
-     max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
-
-     # Extract XTC parameters from config or kwargs, with defaults
-     # Note: You can pass these in the generate() call or set them in generation_config
-     xtc_threshold = kwargs.get("xtc_threshold", getattr(generation_config, "xtc_threshold", 0.1))
-     xtc_probability = kwargs.get("xtc_probability", getattr(generation_config, "xtc_probability", 0.0))
-     temperature = kwargs.get("temperature", getattr(generation_config, "temperature", 1.0))
-
-     # Identify special tokens to protect (EOS is critical, Newline is preferred if known)
-     eos_token_id = generation_config.eos_token_id
-     pad_token_id = generation_config.pad_token_id
-
-     # Try to find newline token if possible, though hard without direct tokenizer access in this scope.
-     # Users can pass `protected_token_ids` in kwargs to be specific.
-     protected_token_ids = kwargs.get("protected_token_ids", [])
-     if eos_token_id is not None:
-         protected_token_ids.append(eos_token_id)
-
-     # Basic handling for left_padding (from original example)
-     left_padding = kwargs.get("left_padding", None)
-     if left_padding is not None:
-         if not isinstance(left_padding, int) or left_padding < 0:
-             raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
-
-         pad_token = kwargs.get("pad_token", None) or pad_token_id
-         if pad_token is None:
-             raise ValueError("pad_token is not defined")
-
-         batch_size = input_ids.shape[0]
-         pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
-         input_ids = torch.cat((pad_tensor, input_ids), dim=1)
-         cur_length = input_ids.shape[1]
-
-     # Sampling Loop
-     while cur_length < max_length:
-         with torch.no_grad():
-             outputs = model(input_ids)
-         next_token_logits = outputs.logits[:, -1, :]
-
-         # 1. Apply Temperature
-         if temperature != 1.0 and temperature > 0:
-             next_token_logits = next_token_logits / temperature
-
-         # 2. Apply XTC (Exclude Top Choices)
-         # Logic ported from text-generation-webui sampler_hijack.py
-         if xtc_probability > 0.0 and random.random() < xtc_probability:
-             # Calculate probabilities for sorting
-             # We sort descending to find the "Top" choices
-             sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
-             sorted_probs = sorted_logits.softmax(dim=-1)
-
-             # Identify tokens to remove
-             sorted_indices_to_remove = torch.full_like(sorted_probs, False, dtype=torch.bool)
-
-             # XTC Logic:
-             # If the *next* token in the sorted list is above threshold,
-             # it means the current token is a "Top Choice" that can be excluded,
-             # because we still have good alternatives remaining.
-             sorted_indices_to_remove[..., :-1] = sorted_probs[..., 1:] >= xtc_threshold
-
-             # Scatter back to original indices
-             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-
-             # Safety Check: Don't remove if special tokens (EOS, etc) are targeted
-             # If any protected token is in the removal mask, we abort XTC for this step
-             should_abort = False
-             if protected_token_ids:
-                 # Check if any protected token is marked for removal
-                 # We convert list to tensor for indexing
-                 protected_tensor = torch.tensor(protected_token_ids, device=input_ids.device)
-                 # We check if any of the columns corresponding to protected tokens are True
-                 if indices_to_remove[:, protected_tensor].any():
-                     should_abort = True
-
-             if not should_abort:
-                 next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
-
-         # 3. Sample
-         # Using multinomial sampling (softmax + random selection)
-         probs = torch.softmax(next_token_logits, dim=-1)
-         next_tokens = torch.multinomial(probs, num_samples=1)
-
-         # Update input_ids
-         input_ids = torch.cat((input_ids, next_tokens), dim=-1)
-         cur_length += 1
-
-         # Stop if all batch items have hit EOS (optional optimization)
-         if eos_token_id is not None and (next_tokens == eos_token_id).all():
-             break
-
-     return input_ids

  import torch
+ import torch.nn as nn
  import random
+ import logging
+ from typing import Union, List, Optional
+ from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
+ from transformers.generation.utils import GenerationMixin, GenerateDecoderOnlyOutput

+ logger = logging.getLogger(__name__)

+ class XTCLogitsWarper(LogitsProcessor):
+     """
+     LogitsWarper that implements Exclude Top Choices (XTC).
+     Based on the implementation from text-generation-webui.
+     """
+     def __init__(self, threshold: float, probability: float, protected_token_ids: Optional[List[int]] = None, filter_value: float = -float("Inf")):
+         self.threshold = threshold
+         self.probability = probability
+         self.filter_value = filter_value
+         self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+         # If the probability is 0 or the random roll fails, do nothing
+         if self.probability <= 0.0 or random.random() >= self.probability:
+             return scores
+
+         # Sort scores in descending order and convert to probabilities
+         sorted_logits, sorted_indices = torch.sort(scores, descending=True)
+         probs = sorted_logits.softmax(dim=-1)
+
+         # Create a mask for removal
+         sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool)
+
+         # XTC logic:
+         # If the *next* token in the sorted list is above the threshold,
+         # then the current token is a "top choice" that can be skipped,
+         # because a viable alternative remains.
+         # This keeps the "tail" of the distribution but trims the "head" when the head is redundant.
+         sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
+
+         # Scatter back to the original (unsorted) indices
+         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+         # Safety: if any protected token (EOS, newline) would be removed, skip XTC for this step
+         if self.protected_token_ids:
+             # Check whether any of the columns corresponding to protected IDs are marked True.
+             # A plain Python loop is fine here because the set of protected tokens is small.
+             protected_safe = True
+             for pid in self.protected_token_ids:
+                 if indices_to_remove[:, pid].any():
+                     protected_safe = False
+                     break
+
+             if not protected_safe:
+                 return scores
+
+         # Apply the filter
+         scores = scores.masked_fill(indices_to_remove, self.filter_value)
+         return scores
+
+ def _xtc_decoding(
+     model,
+     input_ids: torch.LongTensor,
+     logits_processor: LogitsProcessorList,
+     stopping_criteria: StoppingCriteriaList,
+     generation_config: GenerationConfig,
+     synced_gpus: bool = False,
+     streamer: Optional["BaseStreamer"] = None,
+     **model_kwargs,
+ ) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
+     """
+     Custom decoding loop that ensures XTC is applied during sampling.
+     """
+     # 1. Set up the XTC configuration
+     xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
+     xtc_probability = getattr(generation_config, "xtc_probability", 0.0)
+
+     # Identify tokens to protect (EOS and newlines are standard for XTC)
+     protected_ids = []
+     if generation_config.eos_token_id is not None:
+         if isinstance(generation_config.eos_token_id, list):
+             protected_ids.extend(generation_config.eos_token_id)
+         else:
+             protected_ids.append(generation_config.eos_token_id)
+
+     # Newline detection would require tokenizer access, which this scope does not have,
+     # so we cannot reliably find "\n" on our own. Users who want to protect newlines
+     # (or any other tokens) can pass their IDs via `xtc_protected_tokens` in the generation config.
+     custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
+     if custom_protected:
+         protected_ids.extend(custom_protected)
+
+     # 2. Inject XTC into the LogitsProcessorList.
+     # It is appended after the processors the user requested (repetition penalty, temperature,
+     # top_p, ...), so it warps the already-processed scores, and it runs before the sampling
+     # step inside the loop below.
+     if xtc_probability > 0:
+         xtc_warper = XTCLogitsWarper(
+             threshold=xtc_threshold,
+             probability=xtc_probability,
+             protected_token_ids=protected_ids,
+         )
+         logits_processor.append(xtc_warper)
+
+     # 3. Initialization (standard transformers logic)
+     pad_token_id = generation_config._pad_token_tensor
+     output_attentions = generation_config.output_attentions
+     output_hidden_states = generation_config.output_hidden_states
+     output_scores = generation_config.output_scores
+     return_dict_in_generate = generation_config.return_dict_in_generate
+     has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+
+     # We always sample, because XTC combined with greedy search (argmax) is contradictory:
+     # XTC removes the top tokens, so greedy would simply pick the next-best one, which is pure degradation.
+     do_sample = True
+
+     # Init output tuples
+     scores = () if (return_dict_in_generate and output_scores) else None
+     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+     decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+     # Track finished sequences
+     batch_size, cur_length = input_ids.shape[:2]
+     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+     model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
+
+     this_peer_finished = False
+
+     # 4. The decoding loop
+     while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+         # Prepare inputs
+         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+         # Forward pass
+         outputs = model(
+             **model_inputs,
+             return_dict=True,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+
+         if synced_gpus and this_peer_finished:
+             continue  # don't waste resources on a peer that already finished
+
+         # Logits for the last position only
+         next_token_logits = outputs.logits[:, -1, :]
+
+         # Apply logits processors (this includes RepetitionPenalty, Temperature, TopP, AND XTC)
+         next_token_scores = logits_processor(input_ids, next_token_logits)
+
+         # Store scores / attentions / hidden states if requested
+         if return_dict_in_generate and output_scores:
+             scores += (next_token_scores,)
+         if return_dict_in_generate and output_attentions:
+             decoder_attentions += ((outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,))
+         if return_dict_in_generate and output_hidden_states:
+             decoder_hidden_states += ((outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,))
+
+         # Sample (multinomial)
+         # XTC has already masked out the top tokens where applicable, so we sample from the remainder.
+         probs = nn.functional.softmax(next_token_scores, dim=-1)
+         next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+         # Keep finished sequences padded so batching stays consistent
+         if has_eos_stopping_criteria:
+             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+         # Update inputs for the next step
+         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+
+         # Streamer interaction
+         if streamer is not None:
+             streamer.put(next_tokens.cpu())
+
+         # Update model kwargs (cache, attention mask, etc.)
+         model_kwargs = model._update_model_kwargs_for_generation(
+             outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
+         )
+
+         # Update stopping criteria
+         unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+         this_peer_finished = unfinished_sequences.max() == 0
+
+     if streamer is not None:
+         streamer.end()
+
+     # 5. Return Formatted Output
+     if return_dict_in_generate:
+         return GenerateDecoderOnlyOutput(
+             sequences=input_ids,
+             scores=scores,
+             attentions=decoder_attentions,
+             hidden_states=decoder_hidden_states,
+             past_key_values=model_kwargs.get("past_key_values"),
+         )
+     else:
+         return input_ids
+
+
+ def generate(model, *args, **kwargs):
+     """
+     Custom generate function integrating XTC (Exclude Top Choices).
+
+     Arguments in `kwargs` or `generation_config` to control XTC:
+         xtc_probability (float): Probability to perform the XTC check (0.0 to 1.0). Default 0.0 (disabled).
+         xtc_threshold (float): The threshold for defining a "top choice". Default 0.1.
+         xtc_protected_tokens (List[int]): Optional list of specific token IDs to prevent XTC from removing (e.g., newlines).
+     """
+     # XTC is effectively a sampler, so we should ensure do_sample is True in the config
+     if "generation_config" in kwargs:
+         kwargs["generation_config"].do_sample = True
+     elif "do_sample" not in kwargs:
+         kwargs["do_sample"] = True
+
+     # Delegate to the standard GenerationMixin, injecting our custom decoding loop
+     generation_outputs = GenerationMixin.generate(
+         model, *args, custom_generate=_xtc_decoding, **kwargs
+     )
+     return generation_outputs
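
For intuition on the threshold: if the sorted probabilities at a step are [0.40, 0.30, 0.12, 0.08, ...] and xtc_threshold is 0.1, the first two tokens are masked (each is followed by an alternative with probability >= 0.1), while the 0.12 token survives and leads the remaining distribution that gets sampled.

A minimal usage sketch, not part of the commit: it assumes the file is importable locally as `custom_generate.generate`, uses a placeholder model id, and relies on a transformers version where `GenerationConfig` keeps unknown keyword arguments (such as `xtc_probability`) as attributes and `GenerationMixin.generate` accepts a callable `custom_generate`, which is what this module expects.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from custom_generate.generate import generate  # the file updated in this commit

model_id = "meta-llama/Llama-3.2-1B"  # placeholder; any causal LM should work
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Extra attributes (xtc_probability, xtc_threshold, xtc_protected_tokens) are read back
# inside _xtc_decoding via getattr(generation_config, ...).
gen_config = GenerationConfig(
    do_sample=True,
    max_new_tokens=128,
    temperature=1.0,
    xtc_probability=0.5,  # roll XTC on roughly half of the decoding steps
    xtc_threshold=0.1,    # a token is a "top choice" if its next-best alternative has prob >= 0.1
    # Heuristic newline protection; adjust for the tokenizer actually in use
    xtc_protected_tokens=tokenizer.encode("\n", add_special_tokens=False),
)

inputs = tokenizer("Once upon a time", return_tensors="pt")
output_ids = generate(model, inputs.input_ids, generation_config=gen_config)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

When the file lives in a Hub repository instead, the same settings can usually be reached without the local import via `model.generate(..., custom_generate="<repo-id>", trust_remote_code=True)`, with `<repo-id>` replaced by the actual repository name.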