ssalb committed · Commit 7c0d92c · 1 Parent(s): 16746e5
Update space with latest code and dependencies on Mon Jan 6 09:01:27 UTC 2025
story_beam_search/beam_search.py
CHANGED
@@ -8,8 +8,8 @@ from story_beam_search.scoring import StoryScorer
 
 @dataclass
 class BeamSearchConfig:
-    num_beams: int =
-    num_return_sequences: int =
+    num_beams: int = 4
+    num_return_sequences: int = 2
     max_length: int = 100
     no_repeat_ngram_size: int = 2
     temperature: float = 0.8
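A minimal runnable sketch of the config this hunk edits, with the new defaults filled in; the num_iterations field is an assumption inferred from the self.config.num_iterations reference in the next hunk, since it does not appear in the lines shown here:

from dataclasses import dataclass

@dataclass
class BeamSearchConfig:
    num_beams: int = 4                # beams explored at the token level
    num_return_sequences: int = 2     # candidate continuations returned per prompt
    max_length: int = 100             # overall new-token budget across iterations
    no_repeat_ngram_size: int = 2
    temperature: float = 0.8
    num_iterations: int = 4           # assumed field; read as self.config.num_iterations below

config = BeamSearchConfig()
print(config.num_beams, config.num_return_sequences)  # 4 2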
@@ -100,12 +100,16 @@ class BeamSearchGenerator:
         attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
 
         # Calculate continuation length
+        # We want this length, times num_iterations, to be roughly the max_length set by the user.
         continuation_length = (
             max_length + self.config.max_length // self.config.num_iterations
         )
 
         # Generate all continuations in one pass
         with torch.no_grad():
+            # Technically speaking, this generation also uses beam search at the token level;
+            # in this case, though, I'm using it to generate multiple sequences at once and evaluate them
+            # not by token probability, but by my custom metrics.
             outputs = self.model.generate(
                 input_ids=input_ids_batch,
                 attention_mask=attention_mask_batch,
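To make the new comment concrete, a small worked example of the length arithmetic; the numbers, and the reading of the local max_length as the current sequence length, are assumptions, since its definition falls outside this hunk:

# Illustrative values only.
current_length = 20         # assumed meaning of the local `max_length` variable
config_max_length = 100     # BeamSearchConfig.max_length, the user's overall budget
num_iterations = 4          # assumed BeamSearchConfig.num_iterations

continuation_length = current_length + config_max_length // num_iterations
print(continuation_length)  # 45: generation may add up to 100 // 4 = 25 tokens this pass,
                            # so num_iterations passes add roughly max_length tokens in total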
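For context on the generate call these comments annotate, a standalone sketch of how the config fields plausibly map onto transformers generation kwargs; the model choice and every kwarg past attention_mask are assumptions, not the file's actual call:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" is a stand-in model, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tokenizer(["Once upon a time"], return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        max_length=45,              # continuation_length from the hunk above
        num_beams=4,                # token-level beam search, per the comment
        num_return_sequences=2,     # multiple candidates to re-rank afterwards
        no_repeat_ngram_size=2,
        temperature=0.8,
        do_sample=True,             # assumed: temperature only takes effect when sampling
    )
# Each row of `outputs` is one candidate continuation; the surrounding code then
# scores candidates with StoryScorer instead of token log-probability.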