English
naveensp commited on
Commit
df83555
·
verified ·
1 Parent(s): dacc037

Delete beam_search.py

Browse files
Files changed (1) hide show
  1. beam_search.py +0 -1078
beam_search.py DELETED
@@ -1,1078 +0,0 @@
1
- """
2
- This is a self-contained and flexible beam search implementation adapted from
3
- AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
4
- """
5
-
6
- import copy
7
- import warnings
8
- from abc import abstractmethod
9
- from inspect import signature
10
- from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
11
-
12
- import torch
13
-
14
- __all__ = [
15
- "Sampler",
16
- "DeterministicSampler",
17
- "MultinomialSampler",
18
- "TopKSampler",
19
- "TopPSampler",
20
- "GumbelSampler",
21
- "FinalSequenceScorer",
22
- "SequenceLogProbabilityScorer",
23
- "LengthNormalizedSequenceLogProbabilityScorer",
24
- "Constraint",
25
- "RepeatedNGramBlockingConstraint",
26
- "BeamSearch",
27
- ]
28
-
29
- StateType = Dict[str, torch.Tensor]
30
- StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]
31
- StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]
32
-
33
- StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
34
- """
35
- The type of step function that can be passed to [`BeamSearch.search`](#search).
36
-
37
- This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
38
- or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
39
- """
40
-
41
- ConstraintStateType = List[List[Dict[str, Any]]]
42
-
43
-
44
- class Sampler:
45
- """
46
- An abstract class that can be used to sample candidates (either nodes or beams)
47
- within `BeamSearch`.
48
-
49
- A `Sampler` just has three methods, `init_state()`, `sample_nodes()` and `sample_beams()`.
50
-
51
- `init_state()` takes three arguments:
52
-
53
- - a tensor of starting log probs with shape `(batch_size, num_classes)`,
54
- - the batch size, an int,
55
- - and the number of classes, also an int.
56
-
57
- It returns a state dictionary with any state tensors needed for subsequent
58
- calls to `sample_nodes()` and `sample_beams()`.
59
-
60
- By default this method just returns an empty dictionary.
61
-
62
- Both `sample_nodes()` and `sample_beams()` should take three arguments:
63
-
64
- - tensor of normalized log probabilities with shape `(batch_size, num_examples)`,
65
- - an integer representing the number of samples to take for each example in the batch,
66
- - and a state dictionary which could contain any tensors needed for the `Sampler` to keep
67
- track of state.
68
-
69
- For `sample_nodes()`, `num_examples = num_classes`, but for `sample_beams`,
70
- `num_examples = beam_size * per_node_beam_size`.
71
-
72
- The return value should be a tuple containing:
73
-
74
- - a tensor of log probabilities of the sampled examples with shape `(batch_size, num_samples)`,
75
- - a tensor of indices of the sampled examples with shape `(batch_size, num_samples)`,
76
- - and the updated state dictionary.
77
-
78
- A default implementation of `sample_beams` is provided, which just deterministically
79
- picks the `k` examples with highest log probability.
80
- """
81
-
82
- def init_state(
83
- self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
84
- ) -> StateType:
85
- del start_class_log_probabilities, batch_size, num_classes
86
- return {}
87
-
88
- @abstractmethod
89
- def sample_nodes(
90
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
91
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
92
- raise NotImplementedError
93
-
94
- def sample_beams(
95
- self, log_probs: torch.Tensor, beam_size: int, state: StateType
96
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
97
- del state
98
- selected_log_probs, selected_indices = torch.topk(log_probs, beam_size, dim=-1)
99
- return selected_log_probs, selected_indices, {}
100
-
101
-
102
- class DeterministicSampler(Sampler):
103
- """
104
- A `Sampler` that just deterministically returns the `k` nodes or beams with highest
105
- log probability.
106
- """
107
-
108
- def sample_nodes(
109
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
110
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
111
- del state
112
- selected_log_probs, selected_indices = torch.topk(log_probs, per_node_beam_size, dim=-1)
113
- return selected_log_probs, selected_indices, {}
114
-
115
-
116
- class MultinomialSampler(Sampler):
117
- """
118
- A `Sampler` which samples nodes from the given multinomial distribution. Beams are sampled
119
- in the default, deterministic way.
120
-
121
- :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
122
- above 1.0 produces a flatter probability distribution.
123
- :param with_replacement: Whether to sample with replacement.
124
-
125
- """
126
-
127
- def __init__(
128
- self,
129
- temperature: float = 1.0,
130
- with_replacement: bool = False,
131
- ) -> None:
132
- self.temperature = temperature
133
- self.with_replacement = with_replacement
134
-
135
- def sample_nodes(
136
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
137
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
138
- if self.temperature != 1.0:
139
- _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
140
- else:
141
- _probabilities = log_probs.exp()
142
-
143
- selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)
144
-
145
- return torch.gather(log_probs, 1, selected_indices), selected_indices, state
146
-
147
-
148
- class TopKSampler(Sampler):
149
- """
150
- A `Sampler` which redistributes the probability mass function for nodes among the
151
- top `k` choices, then samples from that subset after re-normalizing the probabilities.
152
-
153
- Beams are sampled in the default, deterministic way.
154
-
155
- :param k: The number of top choices to be selected from.
156
- :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
157
- above 1.0 produces a flatter probability distribution.
158
- :param with_replacement: If set to `True`, samples will be selected with replacement from the top k choices.
159
- """
160
-
161
- def __init__(
162
- self,
163
- k: int = 1,
164
- temperature: float = 1.0,
165
- with_replacement: bool = False,
166
- ):
167
- self.k = k
168
- self.temperature = temperature or 1.0
169
- self.with_replacement = with_replacement
170
-
171
- def sample_nodes(
172
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
173
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
174
- if not per_node_beam_size <= self.k <= log_probs.size()[1]:
175
- raise ValueError(
176
- "k must be a postive integer no less than per_node_beam_size and no greater than vocabulary size"
177
- )
178
-
179
- # shape (both): (batch_size, k)
180
- top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)
181
-
182
- # Apply temperature if necessary.
183
- # shape: (batch_size, k)
184
- if self.temperature != 1.0:
185
- top_k_log_probs = top_k_log_probs / self.temperature
186
-
187
- # Re-normalize the subset.
188
- # shape: (batch_size, k)
189
- normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)
190
-
191
- # Sample from the re-normalized subset.
192
- # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
193
- # shape: (batch_size, per_node_beam_size)
194
- sampled_indices = torch.multinomial(
195
- normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
196
- )
197
-
198
- # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
199
- # shape: (batch_size, per_node_beam_size)
200
- indices = top_k_indices.gather(-1, sampled_indices)
201
-
202
- return log_probs.gather(1, indices), indices, state
203
-
204
-
205
- class TopPSampler(Sampler):
206
- """
207
- A `Sampler` which redistributes the probability mass function for nodes among
208
- the top choices with a cumulative probability of at least `p`, then samples from that subset
209
- after re-normalizing the probabilities.
210
-
211
- Beams are sampled in the default, deterministic way.
212
-
213
- :param p:
214
- The cumulative probability cutoff threshold. A higher value of `p` will result in more possible
215
- examples to sample from. If `with_replacement` is `False` and the number of possible samples is
216
- insufficient to sample without replacement from when calling `sample_nodes`, then the top
217
- `per_node_beam_size` examples will be chosen.
218
- :param temperature:
219
- A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
220
- above 1.0 produces a flatter probability distribution.
221
- :param with_replacement:
222
- If set to `True`, samples will be selected with replacement from the top choices.
223
-
224
- """
225
-
226
- def __init__(
227
- self,
228
- p: float = 0.9,
229
- temperature: float = 1.0,
230
- with_replacement: bool = False,
231
- ):
232
- if p < 0.0 or p > 1.0:
233
- raise ValueError("p must be a positive float no greater than 1.0")
234
- self.p = p
235
- self.temperature = temperature or 1.0
236
- self.with_replacement = with_replacement
237
-
238
- def sample_nodes(
239
- self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
240
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
241
- if not per_node_beam_size <= log_probs.size()[1]:
242
- raise ValueError("per_node_beam_size cannot be greater than vocabulary size")
243
-
244
- # First apply temperature coefficient:
245
- if self.temperature != 1.0:
246
- _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
247
- else:
248
- _log_probs = log_probs
249
-
250
- # Sort the probabilities in descending order to then find cumulative sum
251
- log_probs_descending, sorting_indices = torch.sort(_log_probs, descending=True)
252
-
253
- # shape: (batch_size, num_classes)
254
- probabilities_descending = log_probs_descending.exp()
255
- probabilities_summed = torch.cumsum(probabilities_descending, dim=-1)
256
-
257
- # Create a mask for filtering out probabilities that don't make the top `p`.
258
- # shape: (batch_size, num_classes)
259
- exclusion_mask = probabilities_summed >= self.p
260
-
261
- # We want to include the first index where probabilities_summed >= p, so we shift over one.
262
- exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
263
- exclusion_mask[..., 0] = False
264
-
265
- # Make sure there's at least `per_node_beam_size` options to be selected.
266
- if not self.with_replacement:
267
- exclusion_mask[..., :per_node_beam_size] = False
268
-
269
- log_probs_descending[exclusion_mask] = torch.finfo(log_probs.dtype).min
270
-
271
- # Now re-normalize the included log probs.
272
- # shape: (batch_size, num_classes)
273
- filtered_probabilities = torch.nn.functional.softmax(log_probs_descending, dim=-1)
274
-
275
- # Sample from the re-normalized subset.
276
- # NOTE: These indices are not indices into `log_probs`, they are indices into `log_probs_descending`.
277
- # shape: (batch_size, per_node_beam_size)
278
- sampled_indices = torch.multinomial(
279
- filtered_probabilities, per_node_beam_size, replacement=self.with_replacement
280
- )
281
-
282
- # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
283
- # shape: (batch_size, per_node_beam_size)
284
- selected_indices = sorting_indices.gather(-1, sampled_indices)
285
-
286
- # Return (selected log probabilities, selected classes)
287
- # shape (both): (batch_size, per_node_beam_size)
288
- return torch.gather(log_probs, 1, selected_indices), selected_indices, state
289
-
290
-
291
- class GumbelSampler(Sampler):
292
- """
293
- A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
294
- [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
295
- Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2019]
296
- (https://api.semanticscholar.org/CorpusID:76662039).
297
-
298
- :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
299
- above 1.0 produces a flatter probability distribution.
300
- """
301
-
302
- def __init__(self, temperature: float = 1.0):
303
- self.temperature = temperature
304
-
305
- def init_state(
306
- self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
307
- ) -> StateType:
308
- # shape: (batch_size, num_classes)
309
- zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))
310
-
311
- # shape: (batch_size, num_classes)
312
- G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)
313
-
314
- return {"G_phi_S": G_phi_S}
315
-
316
- def sample_nodes(
317
- self,
318
- log_probs: torch.Tensor,
319
- per_node_beam_size: int,
320
- state: StateType,
321
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
322
- # First apply temperature coefficient:
323
- # shape: (batch_size * beam_size, num_classes)
324
- if self.temperature != 1.0:
325
- _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
326
- else:
327
- _log_probs = log_probs
328
-
329
- # shape: (group_size,)
330
- phi_S = state["phi_S"]
331
-
332
- # shape: (group_size, num_classes)
333
- phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)
334
-
335
- # shape: (group_size, num_classes)
336
- phi_S_new = phi_S + _log_probs
337
-
338
- # shape: (group_size, 1)
339
- G_phi_S = state["G_phi_S"].unsqueeze(-1)
340
-
341
- # shape: (group_size, num_classes)
342
- G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)
343
-
344
- # Replace NaNs with very negative number.
345
- # shape: (group_size, num_classes)
346
- # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min
347
-
348
- # shape (both): (group_size, per_node_beam_size)
349
- top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)
350
-
351
- # shape: (group_size, per_node_beam_size)
352
- top_log_probs = log_probs.gather(1, top_indices)
353
-
354
- return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}
355
-
356
- def sample_beams(
357
- self,
358
- log_probs: torch.Tensor,
359
- beam_size: int,
360
- state: StateType,
361
- ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
362
- """
363
- Returns the beams with the highest perturbed log probabilities.
364
- """
365
- # shape (log_probs): (batch_size, beam_size * per_node_beam_size)
366
-
367
- batch_size = log_probs.size()[0]
368
-
369
- # shape: (batch_size * beam_size, per_node_beam_size)
370
- G_phi_S = state["G_phi_S"]
371
-
372
- # shape: (batch_size, beam_size * per_node_beam_size)
373
- G_phi_S = G_phi_S.reshape_as(log_probs)
374
-
375
- # shape (both): (batch_size, beam_size)
376
- G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)
377
-
378
- # shape: (batch_size, beam_size)
379
- selected_log_probs = log_probs.gather(1, selected_indices)
380
-
381
- # Now sort the selected beams by their true log prob.
382
- # shape (all): (batch_size, beam_size)
383
- selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
384
- selected_indices = selected_indices.gather(1, sort_indices)
385
- G_phi_S_new = G_phi_S_new.gather(1, sort_indices)
386
-
387
- # shape: (batch_size * beam_size,)
388
- G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)
389
-
390
- # shape: (batch_size * beam_size,)
391
- phi_S = selected_log_probs.reshape(batch_size * beam_size)
392
-
393
- return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}
394
-
395
- def gumbel(self, phi) -> torch.Tensor:
396
- """
397
- Sample `Gumbel(phi)`.
398
-
399
- `phi` should have shape `(batch_size, num_classes)`.
400
- """
401
- return -torch.log(-torch.log(torch.rand_like(phi))) + phi
402
-
403
- def gumbel_with_max(self, phi, T) -> torch.Tensor:
404
- """
405
- Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.
406
-
407
- `phi` should have shape `(batch_size, num_classes)` and `T` should have
408
- shape `(batch_size, 1)`.
409
- """
410
- # Shape: (batch_size, num_classes)
411
- G_phi = self.gumbel(phi)
412
-
413
- # Now we find the maximum from these samples.
414
- # Shape: (batch_size, )
415
- Z, _ = G_phi.max(dim=-1)
416
-
417
- # Shape: (batch_size, num_classes)
418
- v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))
419
-
420
- # Shape: (batch_size, num_classes)
421
- return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
422
-
423
-
424
- class FinalSequenceScorer:
425
- """
426
- An abstract class that can be used to score the final generated sequences found
427
- by beam search. Given the predicted sequences and the corresponding log probabilities of
428
- those sequences, the class calculates and returns the final score of the sequences.
429
-
430
- The default scorer, `SequenceLogProbabilityScorer`, scores the sequences using the sum of the log probabilities of
431
- the sequence, which is passed as input.
432
- """
433
-
434
- @abstractmethod
435
- def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
436
- """
437
- Score the final predictions found by beam search.
438
- Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.
439
-
440
- :param predictions: A tensor containing the initial predictions with shape `(batch_size, beam_size, max_steps)`.
441
- :param log_probabilities: A tensor containing the log probabilities of the sequence, defined as the sum
442
- of the log probabilities per token, with shape `(batch_size, beam_size)`.
443
- :param end_index: The index of the end symbol.
444
-
445
- """
446
- raise NotImplementedError
447
-
448
-
449
- class SequenceLogProbabilityScorer(FinalSequenceScorer):
450
- """
451
- A :class:`FinalSequenceScorer` which scores the sequences by the sum of the log probabilities
452
- across the sequence's tokens.
453
- """
454
-
455
- def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
456
- del predictions, end_index
457
- # The sum of the sequence log probabilities is the input parameter, so just
458
- # return it.
459
- return log_probabilities
460
-
461
-
462
- class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
463
- """
464
- A :class:`FinalSequenceScorer` which scores the sequences by the average log probability of the
465
- tokens in the sequence. It optionally includes a length penalty which promotes
466
- or demotes sequences based on their lengths. The final score for a sequence will
467
- be `(sequence_log_probability) / (sequence_length ** length_penalty)`. The sequence length
468
- here includes the end token.
469
-
470
- :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
471
- A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
472
- """
473
-
474
- def __init__(self, length_penalty: float = 1.0):
475
- super().__init__()
476
- self.length_penalty = length_penalty
477
-
478
- def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
479
- # shape: (batch_size, beam_size)
480
- lengths = (predictions != end_index).long().sum(dim=2)
481
-
482
- # If the sequence ended during beam search, the `log_probabilities` will include
483
- # the transition to the end token. Therefore, in such situations, `lengths` is
484
- # actually off by 1. This corrects for that.
485
- # shape: (batch_size, beam_size)
486
- is_end_token = predictions[:, :, -1] == end_index
487
- lengths += is_end_token.long()
488
-
489
- # shape: (batch_size, beam_size)
490
- average_log_probs = log_probabilities / (lengths**self.length_penalty)
491
- return average_log_probs
492
-
493
-
494
- class Constraint:
495
- """
496
- An abstract class that can be used to enforce constraints on the output predictions
497
- by manipulating the class log probabilities during beam search.
498
-
499
- A `Constraint` just has three methods that need to be implemented by subclasses:
500
- `init_state()`, `apply()` and `_update_state()`.
501
-
502
- `init_state()` takes one argument:
503
-
504
- - the batch size, an int
505
-
506
- It returns a constraint state, which is a nested list of dictionaries, with any state needed for subsequent
507
- calls to `apply()` and `update_state()`. The length of the outer list should be equal to `batch_size`.
508
- Each inner list should be of length 1.
509
-
510
- `apply()` takes two arguments:
511
-
512
- - the constraint state, which is a nested list of dictionaries. The length of the outer list is `batch_size`
513
- and the length of each inner list is `beam_size` except on the first time `apply()` is called when it is 1.
514
- - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains the
515
- log probabilities for the classes during search. The first time `apply()` is called, `beam_size = 1`.
516
-
517
- The `apply()` method should return new `class_log_probabilities` that enforce the constraint
518
- for this step of beam search. For instance, it may prevent a specific class from being selected by setting
519
- the corresponding log probability to a negligible value such as `float("-inf")` or
520
- `torch.finfo(class_log_probabilities.dtype).min`.
521
-
522
- `_update_state()` takes two arguments:
523
-
524
- - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]` contains the
525
- copied state for the parent of `last_prediction[i, j]`. It is unique to that batch and beam, so it can be
526
- directly edited in-place without affecting the others.
527
- - last_prediction, a tensor of shape `(batch_size, beam_size)` containing the predictions from the last
528
- step of beam search.
529
-
530
- The `_update_state()` function should return a new constraint state, a nested list of dictionaries of
531
- length `batch_size` and inner list of length `beam_size`, one for each of the predictions in `last_prediction`.
532
-
533
- """
534
-
535
- @abstractmethod
536
- def init_state(
537
- self,
538
- batch_size: int,
539
- ) -> ConstraintStateType:
540
- raise NotImplementedError
541
-
542
- @abstractmethod
543
- def apply(
544
- self,
545
- state: ConstraintStateType,
546
- class_log_probabilities: torch.Tensor,
547
- ) -> torch.Tensor:
548
- raise NotImplementedError
549
-
550
- @staticmethod
551
- def _copy_state(
552
- state: ConstraintStateType,
553
- batch_size: int,
554
- beam_size: int,
555
- last_backpointer: Optional[torch.Tensor] = None,
556
- ) -> ConstraintStateType:
557
- """
558
- Copies the `state`. This method copies the data in `state` using `copy.deepcopy()`. If this
559
- is not appropriate for your constraint, you will need to implement the copying yourself.
560
- """
561
- new_state = []
562
- for i in range(batch_size):
563
- batch_state = []
564
- for j in range(beam_size):
565
- if last_backpointer is None:
566
- # This is the first prediction, so the backpointer is 0
567
- backpointer = 0
568
- else:
569
- backpointer = last_backpointer[i, j].item()
570
- batch_state.append(copy.deepcopy(state[i][backpointer])) # type: ignore
571
- new_state.append(batch_state)
572
- return new_state
573
-
574
- def update_state(
575
- self,
576
- state: ConstraintStateType,
577
- last_prediction: torch.Tensor,
578
- last_backpointer: Optional[torch.Tensor] = None,
579
- ) -> ConstraintStateType:
580
- batch_size, beam_size = last_prediction.size()
581
- new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
582
- return self._update_state(new_state, last_prediction)
583
-
584
- @abstractmethod
585
- def _update_state(
586
- self,
587
- state: ConstraintStateType,
588
- last_prediction: torch.Tensor,
589
- ) -> ConstraintStateType:
590
- raise NotImplementedError
591
-
592
-
593
- class RepeatedNGramBlockingConstraint(Constraint):
594
- def __init__(self, ngram_size: int, **kwargs) -> None:
595
- super().__init__(**kwargs)
596
- self.ngram_size = ngram_size
597
-
598
- def init_state(
599
- self,
600
- batch_size: int,
601
- ) -> ConstraintStateType:
602
- return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]
603
-
604
- def apply(
605
- self,
606
- state: ConstraintStateType,
607
- class_log_probabilities: torch.Tensor,
608
- ) -> torch.Tensor:
609
- for i, batch in enumerate(state):
610
- for j, beam in enumerate(batch):
611
- current_prefix = tuple(beam["current_prefix"])
612
- seen_ngrams = beam["seen_ngrams"]
613
- try:
614
- disallowed_indices = seen_ngrams[current_prefix]
615
- class_log_probabilities[i, j, disallowed_indices] = torch.finfo(
616
- class_log_probabilities.dtype
617
- ).min
618
- except KeyError:
619
- # We have not seen this prefix before, so there is no index
620
- # that needs to be blocked
621
- pass
622
- return class_log_probabilities
623
-
624
- def _update_state(
625
- self,
626
- state: ConstraintStateType,
627
- last_prediction: torch.Tensor,
628
- ) -> ConstraintStateType:
629
- for i, batch in enumerate(state):
630
- for j, beam in enumerate(batch):
631
- prediction = last_prediction[i, j].item()
632
- prefix = beam["current_prefix"]
633
- seen_ngrams = beam["seen_ngrams"]
634
-
635
- if len(prefix) == self.ngram_size - 1:
636
- # This is a new ngram that we have to remember
637
- if tuple(prefix) not in seen_ngrams:
638
- seen_ngrams[tuple(prefix)] = []
639
- seen_ngrams[tuple(prefix)].append(prediction)
640
-
641
- # Create the new prefix, removing the oldest index if the prefix
642
- # is too long
643
- prefix.append(prediction)
644
- if len(prefix) == self.ngram_size:
645
- prefix.pop(0)
646
- return state
647
-
648
-
649
- class BeamSearch:
650
- """
651
- Implements the beam search algorithm for decoding the most likely sequences.
652
-
653
- :param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
654
-
655
- :param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
656
- of the predicted sequences.
657
-
658
- :param beam_size: The width of the beam used.
659
-
660
- :param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
661
- If not given, this just defaults to `beam_size`. Setting this parameter
662
- to a number smaller than `beam_size` may give better results, as it can introduce
663
- more diversity into the search. See
664
- [*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
665
- (https://api.semanticscholar.org/CorpusID:2229477).
666
-
667
- :param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
668
- If not specified, `DeterministicSampler` will be used, which just takes the
669
- `per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
670
-
671
- Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
672
- [Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
673
-
674
- :param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
675
- the predicted sequences. This does not include the start or end tokens. If `None`,
676
- no minimum is enforced.
677
-
678
- :param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
679
- The output from this module is what is returned by the `search` method. If not
680
- specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
681
- by the sum of the token log probabilities.
682
-
683
- :param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
684
- provided, no constraints will be enforced.
685
-
686
- """
687
-
688
- def __init__(
689
- self,
690
- end_index: int,
691
- *,
692
- max_steps: int = 50,
693
- beam_size: int = 10,
694
- per_node_beam_size: Optional[int] = None,
695
- sampler: Optional[Sampler] = None,
696
- min_steps: Optional[int] = None,
697
- final_sequence_scorer: Optional[FinalSequenceScorer] = None,
698
- constraints: Optional[List[Constraint]] = None,
699
- ) -> None:
700
- if not max_steps > 0:
701
- raise ValueError("max_steps must be positive")
702
- if not beam_size > 0:
703
- raise ValueError("beam_size must be positive")
704
- if per_node_beam_size is not None and not per_node_beam_size > 0:
705
- raise ValueError("per_node_beam_size must be positive")
706
- if min_steps is not None:
707
- if not min_steps >= 0:
708
- raise ValueError("min_steps must be non-negative")
709
- if not min_steps <= max_steps:
710
- raise ValueError("min_steps must be less than or equal to max_steps")
711
-
712
- self._end_index = end_index
713
- self.max_steps = max_steps
714
- self.beam_size = beam_size
715
- self.per_node_beam_size = per_node_beam_size or beam_size
716
- self.sampler = sampler or DeterministicSampler()
717
- self.min_steps = min_steps or 0
718
- self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
719
- self.constraints = constraints or []
720
-
721
- @staticmethod
722
- def _reconstruct_sequences(predictions, backpointers):
723
- # Reconstruct the sequences.
724
- # shape: [(batch_size, beam_size, 1)]
725
- reconstructed_predictions = [predictions[-1].unsqueeze(2)]
726
-
727
- if not backpointers:
728
- return reconstructed_predictions
729
-
730
- # shape: (batch_size, beam_size)
731
- cur_backpointers = backpointers[-1]
732
-
733
- for timestep in range(len(predictions) - 2, 0, -1):
734
- # shape: (batch_size, beam_size, 1)
735
- cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
736
-
737
- reconstructed_predictions.append(cur_preds)
738
-
739
- # shape: (batch_size, beam_size)
740
- cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
741
-
742
- # shape: (batch_size, beam_size, 1)
743
- final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
744
-
745
- reconstructed_predictions.append(final_preds)
746
-
747
- return reconstructed_predictions
748
-
749
- def search(
750
- self,
751
- start_predictions: torch.Tensor,
752
- start_state: StateType,
753
- step: StepFunctionType,
754
- ) -> Tuple[torch.Tensor, torch.Tensor]:
755
- """
756
- Given a starting state and a step function, apply beam search to find the
757
- most likely target sequences.
758
-
759
- Returns a tuple of `(predictions, final_scores)`, where `predictions`
760
- has shape `(batch_size, beam_size, max_steps)` and `final_scores`
761
- has shape `(batch_size, beam_size)`.
762
-
763
- .. note::
764
- If your step function returns `-inf` for some log probabilities
765
- (like if you're using a masked log-softmax) then some of the "best"
766
- sequences returned may also have `-inf` log probability. Specifically
767
- this happens when the beam size is larger than the number of actions
768
- with finite log probability (non-zero probability) returned by the step function.
769
- Therefore if you're using a mask you may want to check the results from `search`
770
- and potentially discard sequences with non-finite log probability.
771
-
772
- :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
773
- Usually the initial predictions are just the index of the "start" token
774
- in the target vocabulary.
775
-
776
- :param start_state: The initial state passed to the `step` function. Each value of the state dict
777
- should be a tensor of shape `(batch_size, *)`, where `*` means any other
778
- number of dimensions.
779
-
780
- :param step: A function that is responsible for computing the next most likely tokens,
781
- given the current state and the predictions from the last time step.
782
- The function should accept two or three arguments:
783
-
784
- - a tensor of shape `(group_size,)` or representing the index of the predicted
785
- tokens from the last time step,
786
- - the current state, a `StateType`, and
787
- - optionally, the timestep, an `int`.
788
-
789
- The `group_size` will be `batch_size * beam_size`, except in the initial
790
- step, for which it will just be `batch_size`.
791
-
792
- The function is expected to return a tuple, where the first element
793
- is a tensor of shape `(group_size, vocab_size)` containing
794
- the log probabilities of the tokens for the next step, and the second
795
- element is the updated state. The tensor in the state should have shape
796
- `(group_size, *)`, where `*` means any other number of dimensions.
797
-
798
- """
799
- step_signature = signature(step)
800
- if len(step_signature.parameters) < 3:
801
- # If the step function we're given does not take the time step argument, wrap it
802
- # in one that does.
803
- old_step = cast(StepFunctionTypeNoTimestep, step)
804
-
805
- def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
806
- del time_step
807
- return old_step(last_predictions, state)
808
-
809
- return self._search(start_predictions, start_state, new_step)
810
- else:
811
- return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))
812
-
813
def _search(
    self,
    start_predictions: torch.Tensor,
    start_state: StateType,
    step: StepFunctionTypeWithTimestep,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Beam search driver; `step` is always invoked as `step(last_predictions, state, timestep)`.

    The first token is chosen from a single `step` call on `start_predictions`. Each later
    iteration expands every beam into `per_node_beam_size` candidates and keeps `beam_size`
    of them. Constraints, the `min_steps` limit and the rule that a finished beam keeps
    emitting the end token are applied to the log probabilities before sampling.

    Returns `(predictions, final_scores)`, ordered along the beam dimension by descending
    final score. When `beam_size == 1` and every first prediction is the end token, a
    warning is issued and the first-step predictions and log probabilities are returned
    directly.

    Raises `ValueError` if `per_node_beam_size` exceeds the number of classes returned by `step`.
    """
    batch_size = start_predictions.size(0)

    # One (batch_size, beam_size) tensor of token ids per generated step. The start
    # symbols are implicit and are not stored.
    step_tokens: List[torch.Tensor] = []

    # One (batch_size, beam_size) tensor per step after the first. Entry n is the index of
    # the parent prediction, i.e. step_tokens[t-1][i][n], that the beam was extended from.
    parent_indices: List[torch.Tensor] = []

    constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]

    # The first timestep is handled outside the loop: it goes from one decoder input per
    # batch element to the top `beam_size` outputs, whereas every loop iteration goes from
    # `beam_size` beams to `beam_size * per_node_beam_size` candidates and back down to
    # `beam_size`.
    # shape: (batch_size, num_classes)
    first_log_probs, state = step(start_predictions, start_state, 0)

    num_classes = first_log_probs.size(1)

    # `per_node_beam_size` candidates cannot be drawn from fewer than that many classes.
    if self.per_node_beam_size > num_classes:
        raise ValueError(
            f"Vocab size ({num_classes:d}) too small "
            f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
            f"Please decrease beam_size or per_node_beam_size."
        )

    sampler_state = self.sampler.init_state(first_log_probs, batch_size, num_classes)

    # Apply all constraints; they operate on a (batch_size, beam, num_classes) layout.
    if self.constraints:
        # shape: (batch_size, 1, num_classes)
        constrained = first_log_probs.unsqueeze(1)
        for constraint, constraint_state in zip(self.constraints, constraint_states):
            constrained = constraint.apply(constraint_state, constrained)
        first_log_probs = constrained.squeeze(1)

    # With any min_steps requirement the end symbol must not be picked first.
    if self.min_steps >= 1:
        first_log_probs[:, self._end_index] = torch.finfo(first_log_probs.dtype).min

    # Initial predicted classes and their log probabilities.
    # shape (both): (batch_size, beam_size)
    first_beam_log_probs, first_tokens, sampler_state = self.sampler.sample_beams(
        first_log_probs, self.beam_size, sampler_state
    )

    if self.beam_size == 1 and (first_tokens == self._end_index).all():
        warnings.warn(
            "Empty sequences predicted. You may want to increase the beam size or ensure "
            "your step function is working properly.",
            RuntimeWarning,
        )
        return first_tokens.unsqueeze(-1), first_beam_log_probs

    # Running log probability of every beam.
    # shape: (batch_size, beam_size)
    last_log_probabilities = first_beam_log_probs

    step_tokens.append(first_tokens)

    flat_size = batch_size * self.beam_size
    candidates_per_batch = self.beam_size * self.per_node_beam_size

    # Distribution that forces the end token: minimum finite value everywhere, 0 at the end index.
    # shape: (batch_size * beam_size, num_classes)
    log_probs_after_end = first_log_probs.new_full(
        (flat_size, num_classes),
        torch.finfo(first_log_probs.dtype).min,
    )
    log_probs_after_end[:, self._end_index] = 0.0

    # Give every beam of a batch element the same starting state.
    self._update_initial_state(state, batch_size)

    constraint_states = [
        constraint.update_state(constraint_state, first_tokens)
        for constraint, constraint_state in zip(self.constraints, constraint_states)
    ]

    for timestep in range(self.max_steps - 1):
        # shape: (batch_size * beam_size,)
        last_predictions = step_tokens[-1].reshape(flat_size)

        # Nothing left to extend once every beam has emitted the end token.
        if (last_predictions == self._end_index).all():
            break

        # Advance the model: next-class log probs plus the updated state.
        # shape: (batch_size * beam_size, num_classes)
        class_log_probabilities, state = step(last_predictions, state, timestep + 1)

        # Apply all constraints.
        if self.constraints:
            # shape: (batch_size, beam_size, num_classes)
            per_beam_log_probs = class_log_probabilities.view(batch_size, self.beam_size, -1)
            for constraint, constraint_state in zip(self.constraints, constraint_states):
                per_beam_log_probs = constraint.apply(constraint_state, per_beam_log_probs)
            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities = per_beam_log_probs.view(flat_size, -1)

        # Iteration `timestep` produces token number `timestep + 2` (0-indexed loop, and the
        # first token was generated before the loop). Block the end index while the search
        # is not yet allowed to terminate.
        if timestep + 2 <= self.min_steps:
            class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min

        # Beams that predicted the end token at the previous step get the one-hot
        # "end token only" distribution so that they predict it again.
        # shape (both): (batch_size * beam_size, num_classes)
        already_ended = last_predictions.unsqueeze(-1).expand(flat_size, num_classes) == self._end_index
        cleaned_log_probabilities = torch.where(already_ended, log_probs_after_end, class_log_probabilities)

        # shape (both): (batch_size * beam_size, per_node_beam_size)
        node_log_probs, node_classes, sampler_state = self.sampler.sample_nodes(
            cleaned_log_probabilities, self.per_node_beam_size, sampler_state
        )

        # Repeat each beam's running log probability across its candidates so it can be
        # added to this step's log probs.
        # shape: (batch_size * beam_size, per_node_beam_size)
        carried_log_probs = (
            last_log_probabilities.unsqueeze(2)
            .expand(batch_size, self.beam_size, self.per_node_beam_size)
            .reshape(flat_size, self.per_node_beam_size)
        )

        # shape: (batch_size, beam_size * per_node_beam_size)
        candidate_scores = (node_log_probs + carried_log_probs).reshape(batch_size, candidates_per_batch)

        # shape: (batch_size, beam_size * per_node_beam_size)
        candidate_classes = node_classes.reshape(batch_size, candidates_per_batch)

        # Keep only `beam_size` candidates per batch element.
        # shape (both): (batch_size, beam_size)
        kept_log_probs, kept_indices, sampler_state = self.sampler.sample_beams(
            candidate_scores, self.beam_size, sampler_state
        )

        # Look up the classes belonging to the kept candidates.
        # shape: (batch_size, beam_size)
        kept_classes = candidate_classes.gather(1, kept_indices)

        step_tokens.append(kept_classes)

        # shape: (batch_size, beam_size)
        last_log_probabilities = kept_log_probs

        # Candidates sharing an ancestor sit next to each other in the
        # `beam_size * per_node_beam_size` dimension, so truncating division by
        # `per_node_beam_size` recovers the ancestor's beam index.
        # shape: (batch_size, beam_size)
        backpointer = torch.divide(kept_indices, self.per_node_beam_size, rounding_mode="trunc")
        parent_indices.append(backpointer)

        # Keep only the state rows belonging to the ancestors selected this iteration.
        self._update_state(state, backpointer)

        constraint_states = [
            constraint.update_state(constraint_state, kept_classes, last_backpointer=backpointer)
            for constraint, constraint_state in zip(self.constraints, constraint_states)
        ]

    # Warn about "-inf" log probabilities only when no constraints are in use (negligible
    # log probabilities are expected with constraints).
    if not self.constraints:
        lowest_finite = torch.finfo(last_log_probabilities.dtype).min
        if not torch.isfinite(last_log_probabilities).all() or (last_log_probabilities == lowest_finite).any():
            warnings.warn(
                "Negligible log probabilities encountered ('-inf' or equivalent). "
                "Some final sequences may not make sense. "
                "This can happen when the beam size is larger than the number of valid (non-zero "
                "probability) transitions that the step function produces.",
                RuntimeWarning,
            )

    reconstructed_predictions = self._reconstruct_sequences(step_tokens, parent_indices)

    # The reconstruction runs from the last step backwards, so flip it before concatenating.
    # shape: (batch_size, beam_size, max_steps)
    all_predictions = torch.cat(reconstructed_predictions[::-1], 2)

    # shape: (batch_size, beam_size)
    final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)

    # Order beams so that the best scoring sequence sits at index 0.
    sorted_final_scores, sorted_indices = final_scores.sort(dim=1, descending=True)
    sorted_all_predictions = all_predictions.gather(1, sorted_indices.unsqueeze(-1).expand_as(all_predictions))

    return sorted_all_predictions, sorted_final_scores
1046
-
1047
def _update_initial_state(self, state: StateType, batch_size: int):
    """
    Tile every tensor in `state`, in place, from `(batch_size, *)` to
    `(batch_size * beam_size, *)` so that each batch row is repeated once per beam.

    Entries whose value is `None` are left untouched.
    """
    flat_size = batch_size * self.beam_size

    for name, tensor in state.items():
        if tensor is None:
            continue
        _, *trailing_dims = tensor.shape
        # Insert a beam axis, broadcast along it, then fold it into the batch axis.
        # shape: (batch_size, beam_size, *)
        tiled = tensor.unsqueeze(1).expand(batch_size, self.beam_size, *trailing_dims)
        # shape: (batch_size * beam_size, *)
        state[name] = tiled.reshape(flat_size, *trailing_dims)
1061
-
1062
def _update_state(self, state: StateType, backpointer: torch.Tensor):
    """
    Reorder every tensor in `state`, in place, to follow the beams selected this step.

    `backpointer` is viewed as `(batch_size, beam_size)`; after the call, row
    `b * beam_size + k` of each state tensor holds what used to be the row of beam
    `backpointer[b, k]` within batch element `b`. Entries whose value is `None` are skipped.
    """
    batch_size = backpointer.size(0)
    flat_size = batch_size * self.beam_size

    for name, tensor in state.items():
        if tensor is None:
            continue
        _, *trailing_dims = tensor.shape
        per_beam_shape = (batch_size, self.beam_size, *trailing_dims)
        # Broadcast the ancestor indices over all trailing dimensions so that `gather`
        # copies whole state rows rather than single elements.
        # shape: (batch_size, beam_size, *)
        gather_index = backpointer.view(batch_size, self.beam_size, *([1] * len(trailing_dims))).expand(
            *per_beam_shape
        )
        # shape: (batch_size, beam_size, *)
        per_beam_state = tensor.reshape(*per_beam_shape)
        # shape: (batch_size * beam_size, *)
        state[name] = per_beam_state.gather(1, gather_index).reshape(flat_size, *trailing_dims)