log
Browse files
custom_generate/generate.py
CHANGED
|
@@ -55,5 +55,6 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
|
|
| 55 |
next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
|
| 56 |
input_ids = torch.cat((input_ids, next_tokens), dim=-1)
|
| 57 |
cur_length += 1
|
|
|
|
| 58 |
|
| 59 |
return input_ids
|
|
|
|
| 55 |
next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
|
| 56 |
input_ids = torch.cat((input_ids, next_tokens), dim=-1)
|
| 57 |
cur_length += 1
|
| 58 |
+
print(f"Current length: {cur_length}, Next token: {next_tokens.item()}")
|
| 59 |
|
| 60 |
return input_ids
|