log
Browse files
custom_generate/generate.py
CHANGED
|
@@ -45,6 +45,7 @@ def generate(model, input_ids, generation_config=None, n_sigma:float=1.0, **kwar
|
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 46 |
cur_length = input_ids.shape[1]
|
| 47 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
|
|
|
| 48 |
|
| 49 |
while cur_length < max_length:
|
| 50 |
logits = model(input_ids).logits
|
|
|
|
| 45 |
generation_config = generation_config or model.generation_config # default to the model generation config
|
| 46 |
cur_length = input_ids.shape[1]
|
| 47 |
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
| 48 |
+
print(f"Starting generation with max_length: {max_length}, current length: {cur_length}")
|
| 49 |
|
| 50 |
while cur_length < max_length:
|
| 51 |
logits = model(input_ids).logits
|