Todokete committed
Commit cd22ae8 · verified · 1 Parent(s): d84b1bf

Delete custom_generate/generate.py

Files changed (1)
  1. custom_generate/generate.py +0 -217
custom_generate/generate.py DELETED
@@ -1,217 +0,0 @@
- import torch
- import torch.nn as nn
- import random
- import logging
- from typing import Union, List, Optional
- from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
- from transformers.generation.utils import GenerationMixin, GenerateDecoderOnlyOutput
-
- logger = logging.getLogger(__name__)
-
- class XTCLogitsWarper(LogitsProcessor):
-     """
-     LogitsWarper that implements Exclude Top Choices (XTC).
-     Based on the implementation from text-generation-webui.
-     """
-     def __init__(self, threshold: float, probability: float, protected_token_ids: Optional[List[int]] = None, filter_value: float = -float("Inf")):
-         self.threshold = threshold
-         self.probability = probability
-         self.filter_value = filter_value
-         self.protected_token_ids = set(protected_token_ids) if protected_token_ids is not None else set()
-
-     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-         # If XTC is disabled or the random roll fails, do nothing
-         if self.probability <= 0.0 or random.random() >= self.probability:
-             return scores
-
-         # Sort scores descending
-         sorted_logits, sorted_indices = torch.sort(scores, descending=True)
-         probs = sorted_logits.softmax(dim=-1)
-
-         # Create a mask for removal
-         sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool)
-
-         # XTC logic:
-         # A token is removed if the *next* token in the sorted list is still above
-         # the threshold, i.e. we remove every "top choice" except the least probable
-         # one that clears the threshold. This trims the redundant "head" of the
-         # distribution while keeping the "tail" intact.
-         sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold
-
-         # Scatter back to original indices
-         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-
-         # Safety: if any protected token (EOS, newline) would be removed, skip XTC
-         # for this step entirely. The protected set is typically tiny, so a simple
-         # loop is cheaper than building an index tensor for boolean indexing.
-         if self.protected_token_ids:
-             protected_safe = True
-             for pid in self.protected_token_ids:
-                 if indices_to_remove[:, pid].any():
-                     protected_safe = False
-                     break
-
-             if not protected_safe:
-                 return scores
-
-         # Apply the filter
-         scores = scores.masked_fill(indices_to_remove, self.filter_value)
-         return scores
-
- def _xtc_decoding(
-     model,
-     input_ids: torch.LongTensor,
-     logits_processor: LogitsProcessorList,
-     stopping_criteria: StoppingCriteriaList,
-     generation_config: GenerationConfig,
-     synced_gpus: bool = False,
-     streamer: Optional["BaseStreamer"] = None,
-     **model_kwargs,
- ) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
-     """
-     Custom decoding loop that ensures XTC is applied during sampling.
-     """
-
-     # 1. Set up the XTC configuration
-     xtc_threshold = getattr(generation_config, "xtc_threshold", 0.1)
-     xtc_probability = getattr(generation_config, "xtc_probability", 0.0)
-
-     # Identify tokens to protect (EOS and newlines are standard for XTC)
-     protected_ids = []
-     if generation_config.eos_token_id is not None:
-         if isinstance(generation_config.eos_token_id, list):
-             protected_ids.extend(generation_config.eos_token_id)
-         else:
-             protected_ids.append(generation_config.eos_token_id)
-
-     # Newline tokens should normally be protected too, but their IDs are
-     # tokenizer-specific (e.g. 13 on Llama/Mistral) and this loop has no
-     # tokenizer access, so reliable detection is impossible here. Users can
-     # pass `xtc_protected_tokens` in the generation config instead.
-     custom_protected = getattr(generation_config, "xtc_protected_tokens", None)
-     if custom_protected:
-         protected_ids.extend(custom_protected)
-
-     # 2. Inject XTC into the LogitsProcessorList
-     # It is appended before the sampling step (which happens inside the loop).
-     # Note: if the user passed temperature/top_p, those warpers are already in
-     # `logits_processor`, so XTC runs after them on the rescaled distribution.
-     if xtc_probability > 0:
-         xtc_warper = XTCLogitsWarper(
-             threshold=xtc_threshold,
-             probability=xtc_probability,
-             protected_token_ids=protected_ids,
-         )
-         logits_processor.append(xtc_warper)
-
-     # 3. Initialization (standard transformers logic)
-     pad_token_id = generation_config._pad_token_tensor
-     output_attentions = generation_config.output_attentions
-     output_hidden_states = generation_config.output_hidden_states
-     output_scores = generation_config.output_scores
-     return_dict_in_generate = generation_config.return_dict_in_generate
-     has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
-
-     # We always sample: XTC under greedy search (argmax) is contradictory, since
-     # removing the top tokens would just make argmax pick the next-best token,
-     # which is pure degradation rather than added diversity.
-     do_sample = True
-
-     # Init output tuples
-     scores = () if (return_dict_in_generate and output_scores) else None
-     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-     cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-     decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
-     # Track finished sequences
-     batch_size, cur_length = input_ids.shape[:2]
-     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-     model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
-
-     this_peer_finished = False
-
-     # 4. The decoding loop
-     while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-         # Prepare inputs
-         model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-         # Forward pass
-         outputs = model(
-             **model_inputs,
-             return_dict=True,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-         )
-
-         if synced_gpus and this_peer_finished:
-             continue  # don't waste resources on finished peers
-
-         # Logits for the last position only
-         next_token_logits = outputs.logits[:, -1, :]
-
-         # Apply logits processors (repetition penalty, temperature, top_p, AND XTC)
-         next_token_scores = logits_processor(input_ids, next_token_logits)
-
-         # Store scores/attentions/hidden states if requested
-         if return_dict_in_generate and output_scores:
-             scores += (next_token_scores,)
-         if return_dict_in_generate and output_attentions:
-             decoder_attentions += ((outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,))
-         if return_dict_in_generate and output_hidden_states:
-             decoder_hidden_states += ((outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,))
-
-         # Sample (multinomial): XTC has already masked the unwanted top tokens to
-         # -inf, so softmax zeroes them out and we sample from the remainder.
-         probs = nn.functional.softmax(next_token_scores, dim=-1)
-         next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-
-         # Keep emitting pad tokens for sequences that already hit EOS
-         if has_eos_stopping_criteria:
-             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-
-         # Update inputs for the next step
-         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-         # Streamer interaction
-         if streamer is not None:
-             streamer.put(next_tokens.cpu())
-
-         # Update model kwargs (cache, attention mask, etc.)
-         model_kwargs = model._update_model_kwargs_for_generation(
-             outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
-         )
-
-         # Update stopping criteria
-         unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-         this_peer_finished = unfinished_sequences.max() == 0
-
-     if streamer is not None:
-         streamer.end()
-
-     # 5. Return formatted output
-     if return_dict_in_generate:
-         return GenerateDecoderOnlyOutput(
-             sequences=input_ids,
-             scores=scores,
-             attentions=decoder_attentions,
-             hidden_states=decoder_hidden_states,
-             past_key_values=model_kwargs.get("past_key_values"),
-         )
-     else:
-         return input_ids
-
- def generate(model, *args, **kwargs):
-     """
-     Custom generate function integrating XTC (Exclude Top Choices).
-
-     Arguments in `kwargs` or `generation_config` to control XTC:
-         xtc_probability (float): Probability of applying the XTC check at each step (0.0 to 1.0). Default 0.0 (disabled).
-         xtc_threshold (float): The probability threshold for defining a "top choice". Default 0.1.
-         xtc_protected_tokens (List[int]): Optional list of token IDs that XTC must never remove (e.g., newlines).
-     """
-     # Delegate to the standard GenerationMixin, injecting our custom decoding loop
-     generation_outputs = GenerationMixin.generate(
-         model, *args, custom_generate=_xtc_decoding, **kwargs
-     )
-     return generation_outputs
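
For reference: the deleted file followed the Hub custom-generate convention, where transformers loads a repo's custom_generate/generate.py and dispatches to its generate function. The XTCLogitsWarper defined above can also be used on its own, without the custom decoding loop, through the standard logits_processor argument of model.generate. A minimal sketch, assuming the class above is in scope; the checkpoint name, threshold, and probability values are illustrative, not prescribed by this repo:

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
inputs = tok("Once upon a time", return_tensors="pt")

# Protect EOS and the newline token so XTC never strips sequence boundaries.
# The [-1] index is a heuristic for SentencePiece-style tokenizers.
newline_id = tok.encode("\n", add_special_tokens=False)[-1]
xtc = XTCLogitsWarper(
    threshold=0.1,      # a sorted token is removed if its runner-up still clears this
    probability=0.5,    # roll XTC on roughly half of the decoding steps
    protected_token_ids=[tok.eos_token_id, newline_id],
)

out = model.generate(
    **inputs,
    do_sample=True,     # XTC only makes sense when sampling
    max_new_tokens=64,
    logits_processor=LogitsProcessorList([xtc]),
)
print(tok.decode(out[0], skip_special_tokens=True))

Used this way, temperature/top_p warpers configured via generate kwargs run in the default processor list, and the appended XTC warper filters the distribution before the multinomial draw, mirroring what the deleted decoding loop did.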