now fixed
custom_generate/generate.py
@@ -297,12 +297,11 @@ def _contrastive_search(
         for i in range(top_k):
             # compute the candidate tokens by the language model and collect their hidden_states
             next_model_inputs = model.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
+            next_model_inputs['output_hidden_states'] = True

             outputs = model(
                 **next_model_inputs,
                 return_dict=True,
-                output_hidden_states=True,
-                output_attentions=output_attentions,
             )
             # Remove past K-V from output since we don't need to stack later
             outputs["past_key_values"] = None
@@ -316,12 +315,11 @@ def _contrastive_search(
         # compute the candidate tokens by the language model and collect their hidden_states
         # assembles top_k_ids into batch of size k
         next_model_inputs = model.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
+        next_model_inputs['output_hidden_states'] = True

         outputs = model(
             **next_model_inputs,
             return_dict=True,
-            output_hidden_states=True,
-            output_attentions=output_attentions,
         )

         # This is essential to avoid having a last reference to the big past K-V and double the necessary memory
@@ -385,8 +383,6 @@ def _contrastive_search(
     selected_outputs = model(
         **next_model_input,
         return_dict=True,
-        output_hidden_states=False,
-        output_attentions=False,
     )
     next_past_key_values = selected_outputs["past_key_values"]

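For context, a minimal sketch of the call pattern this commit moves to, assuming a Hugging Face transformers-style model (the gpt2 checkpoint, prompt, and model_kwargs below are illustrative, not part of the diff). Setting output_hidden_states on the dict returned by prepare_inputs_for_generation, instead of passing it as a second keyword to model(...), keeps the forward call free of duplicate keyword arguments when the prepared inputs already contain that key, which is a plausible reading of what "now fixed" refers to.

# Minimal sketch of the new call pattern; checkpoint and prompt are
# illustrative assumptions, not taken from the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Contrastive search picks", return_tensors="pt")
model_kwargs = {"attention_mask": inputs["attention_mask"], "use_cache": True}

# Set the flag on the prepared-inputs dict rather than passing it as an
# extra keyword to model(...); if prepare_inputs_for_generation already
# returns that key, a second keyword would raise
# "got multiple values for keyword argument 'output_hidden_states'".
next_model_inputs = model.prepare_inputs_for_generation(
    inputs["input_ids"], **model_kwargs
)
next_model_inputs["output_hidden_states"] = True

with torch.no_grad():
    outputs = model(**next_model_inputs, return_dict=True)

print(outputs.hidden_states[-1].shape)  # hidden states populated via the flag

The same pattern applies to both changed call sites: the flag travels inside next_model_inputs, so the model(...) invocation itself only needs return_dict=True.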