yagizdevre commited on
Commit
550c31b
·
verified ·
1 Parent(s): 4f58476

Update modeling_ministu.py

Browse files
Files changed (1) hide show
  1. modeling_ministu.py +97 -0
modeling_ministu.py CHANGED
@@ -138,3 +138,100 @@ class MiniSTU(PreTrainedModel):
138
  torch.nn.init.zeros_(module.c_attn.bias)
139
  if module.c_proj.bias is not None:
140
  torch.nn.init.zeros_(module.c_proj.bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  torch.nn.init.zeros_(module.c_attn.bias)
139
  if module.c_proj.bias is not None:
140
  torch.nn.init.zeros_(module.c_proj.bias)
141
+
142
@staticmethod
def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: int = 50,
    top_p: float = 0.95,
    filter_value: float = float("-inf"),
):
    """Mask out low-probability logits in place using top-k and/or nucleus (top-p) rules.

    Args:
        logits: (batch_size, vocab_size) unnormalized scores; modified in place.
        top_k: keep only the k highest-scoring tokens per row (0 disables).
        top_p: keep the smallest set of tokens whose cumulative probability
            exceeds top_p (values outside (0, 1) disable nucleus filtering).
        filter_value: value written into masked-out positions.

    Returns:
        The same logits tensor with filtered entries set to ``filter_value``.
    """
    if top_k > 0:
        # Per-row k-th largest value; anything strictly below it is dropped.
        kth_value = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[:, -1, None]
        logits[logits < kth_value] = filter_value

    if 0 < top_p < 1.0:
        desc_logits, desc_idx = torch.sort(logits, dim=-1, descending=True)
        cum_probs = F.softmax(desc_logits, dim=-1).cumsum(dim=-1)

        # Tokens past the nucleus: cumulative mass already above top_p.
        drop_sorted = cum_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and the best token is never dropped.
        drop_sorted = drop_sorted.roll(shifts=1, dims=-1)
        drop_sorted[:, 0] = False

        # Map the sorted-order mask back to vocabulary order (desc_idx is a
        # permutation per row, so every position is written).
        drop_mask = torch.zeros_like(drop_sorted).scatter(1, desc_idx, drop_sorted)
        logits[drop_mask] = filter_value

    return logits
177
def generate(
    self,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    eos_token_id: int = None,
    pad_token_id: int = 0,
    **kwargs,
):
    """
    Naive token-by-token generation loop using top-k/top-p filtering and temperature.

    Args:
        input_ids (torch.LongTensor): shape (batch_size, sequence_length).
        max_new_tokens (int): max number of tokens to generate (beyond input_ids length).
        temperature (float): sampling temperature (>=0); 0 falls back to greedy argmax.
        top_k (int): Top-K sampling cutoff.
        top_p (float): Nucleus sampling cutoff.
        eos_token_id (int): If set, a sequence stops once it produces this token;
            generation halts when every sequence in the batch has finished.
        pad_token_id (int): Token written into sequences that have already
            finished while others are still generating.
        kwargs: Unused arguments (like num_beams) for compatibility.

    Returns:
        torch.LongTensor: shape (batch_size, sequence_length + generated_tokens).
    """
    generated_ids = input_ids.clone()
    # Per-sequence flag: True once a sequence has emitted eos_token_id.
    finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)

    # Inference only: no autograd graph should be built for the sampling loop.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Full forward pass each step (no KV cache); take the last position.
            outputs = self.forward(generated_ids)
            logits = outputs.logits[:, -1, :]  # (batch_size, vocab_size)

            if temperature <= 0:
                # temperature == 0 would divide by zero and NaN the softmax;
                # treat it (and negative values) as greedy decoding.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                if temperature != 1.0:
                    logits = logits / temperature
                logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                probabilities = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)  # (batch_size, 1)

            if eos_token_id is not None:
                # Sequences that already finished keep emitting pad tokens so
                # the batch stays rectangular.
                next_token = torch.where(
                    finished.unsqueeze(1),
                    torch.full_like(next_token, pad_token_id),
                    next_token,
                )
                finished = finished | (next_token.squeeze(1) == eos_token_id)

            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            if eos_token_id is not None and finished.all():
                break

    return generated_ids