Log debug info (max_length, current length, temperature) at generation start; remove per-token debug print
Browse files
custom_generate/generate.py
CHANGED
|
@@ -41,11 +41,14 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
|
|
| 41 |
and max_new_tokens.
|
| 42 |
n_sigma (float): The number of standard deviations to use for topN-sigma sampling.
|
| 43 |
**kwargs: Additional keyword arguments.
|
| 44 |
-
"""
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 46 |
cur_length = input_ids.shape[1]
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
while cur_length < max_length:
|
| 51 |
logits = model(input_ids).logits
|
|
@@ -56,6 +59,5 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
|
|
| 56 |
next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
|
| 57 |
input_ids = torch.cat((input_ids, next_tokens), dim=-1)
|
| 58 |
cur_length += 1
|
| 59 |
-
print(f"Current length: {cur_length}, Next token: {next_tokens.item()}")
|
| 60 |
|
| 61 |
return input_ids
|
|
|
|
| 41 |
and max_new_tokens.
|
| 42 |
n_sigma (float): The number of standard deviations to use for topN-sigma sampling.
|
| 43 |
**kwargs: Additional keyword arguments.
|
| 44 |
+
"""
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 46 |
cur_length = input_ids.shape[1]
|
| 47 |
+
if generation_config.max_length is None:
|
| 48 |
+
max_length = cur_length + generation_config.max_new_tokens
|
| 49 |
+
else:
|
| 50 |
+
max_length = generation_config.max_length
|
| 51 |
+
print(f"Starting generation with max_length: {max_length}, current length: {cur_length} and temperature: {generation_config.temperature}")
|
| 52 |
|
| 53 |
while cur_length < max_length:
|
| 54 |
logits = model(input_ids).logits
|
|
|
|
| 59 |
next_tokens = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
|
| 60 |
input_ids = torch.cat((input_ids, next_tokens), dim=-1)
|
| 61 |
cur_length += 1
|
|
|
|
| 62 |
|
| 63 |
return input_ids
|