Commit
·
880ed7e
1
Parent(s):
e516101
Initial commit
Browse files
custom_generate/generate.py
CHANGED
|
@@ -103,9 +103,9 @@ def ancestral_sampling(model_kwargs, model, eos_token_ids, pad_token_id, bos_tok
|
|
| 103 |
active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
|
| 104 |
lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
|
| 105 |
# Modified log probabilities of the sequences
|
| 106 |
-
scores = torch.zeros((batch_size, max_new_tokens), dtype=
|
| 107 |
# Unfiltered sequence log probabilities (T=1, no sampling modifications)
|
| 108 |
-
logps = torch.zeros((batch_size, max_new_tokens), dtype=
|
| 109 |
|
| 110 |
for i in range(max_new_tokens):
|
| 111 |
# Get the next token probabilities and update the KV cache
|
|
@@ -153,6 +153,7 @@ def generate(model, **kwargs):
|
|
| 153 |
"""
|
| 154 |
generation_config = model.generation_config
|
| 155 |
max_new_tokens = kwargs.get('max_new_tokens', generation_config.max_new_tokens)
|
|
|
|
| 156 |
do_sample = kwargs.get('do_sample', True)
|
| 157 |
eos_token_ids = kwargs.get('eos_token_ids', generation_config.eos_token_id)
|
| 158 |
if eos_token_ids is None:
|
|
|
|
| 103 |
active_seqs = input_ids.new_ones((batch_size, 1), dtype=torch.bool)
|
| 104 |
lens = torch.full((batch_size,), max_prompts_len, dtype=torch.long, device=input_ids.device)
|
| 105 |
# Modified log probabilities of the sequences
|
| 106 |
+
scores = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
|
| 107 |
# Unfiltered sequence log probabilities (T=1, no sampling modifications)
|
| 108 |
+
logps = torch.zeros((batch_size, max_new_tokens), dtype=model.dtype)
|
| 109 |
|
| 110 |
for i in range(max_new_tokens):
|
| 111 |
# Get the next token probabilities and update the KV cache
|
|
|
|
| 153 |
"""
|
| 154 |
generation_config = model.generation_config
|
| 155 |
max_new_tokens = kwargs.get('max_new_tokens', generation_config.max_new_tokens)
|
| 156 |
+
max_new_tokens = 512 if max_new_tokens is None else max_new_tokens
|
| 157 |
do_sample = kwargs.get('do_sample', True)
|
| 158 |
eos_token_ids = kwargs.get('eos_token_ids', generation_config.eos_token_id)
|
| 159 |
if eos_token_ids is None:
|