Update custom_generate/generate.py
Browse files- custom_generate/generate.py +11 -55
custom_generate/generate.py
CHANGED
|
@@ -282,11 +282,11 @@ def _contrastive_search(
|
|
| 282 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
| 283 |
"for contrastive search."
|
| 284 |
)
|
| 285 |
-
#
|
| 286 |
-
elif not
|
| 287 |
raise ValueError(
|
| 288 |
-
f"
|
| 289 |
-
"
|
| 290 |
)
|
| 291 |
|
| 292 |
# contrastive_search main logic start:
|
|
@@ -324,19 +324,7 @@ def _contrastive_search(
|
|
| 324 |
del outputs
|
| 325 |
|
| 326 |
if not sequential:
|
| 327 |
-
|
| 328 |
-
if isinstance(outputs["past_key_values"], DynamicCache) or (
|
| 329 |
-
isinstance(outputs["past_key_values"], EncoderDecoderCache)
|
| 330 |
-
and isinstance(
|
| 331 |
-
outputs["past_key_values"].self_attention_cache, DynamicCache
|
| 332 |
-
)
|
| 333 |
-
):
|
| 334 |
-
model_kwargs["past_key_values"] = model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
|
| 335 |
-
else:
|
| 336 |
-
raise ValueError(
|
| 337 |
-
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
| 338 |
-
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
| 339 |
-
)
|
| 340 |
|
| 341 |
if sequential:
|
| 342 |
all_outputs = []
|
|
@@ -352,21 +340,10 @@ def _contrastive_search(
|
|
| 352 |
output_hidden_states=True,
|
| 353 |
output_attentions=output_attentions,
|
| 354 |
)
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
)
|
| 360 |
-
):
|
| 361 |
-
# Remove past K-V from output since we don't need to stack later
|
| 362 |
-
outputs["past_key_values"] = None
|
| 363 |
-
# Remove last token from past K-V since we don't want to append it at this point
|
| 364 |
-
model_kwargs["past_key_values"].crop(-1)
|
| 365 |
-
else:
|
| 366 |
-
raise ValueError(
|
| 367 |
-
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
| 368 |
-
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
| 369 |
-
)
|
| 370 |
|
| 371 |
all_outputs.append(outputs)
|
| 372 |
outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
|
|
@@ -463,17 +440,7 @@ def _contrastive_search(
|
|
| 463 |
next_past_key_values = next_past_key_values or getattr(
|
| 464 |
outputs, possible_cache_name, None
|
| 465 |
)
|
| 466 |
-
|
| 467 |
-
if isinstance(next_past_key_values, DynamicCache) or (
|
| 468 |
-
isinstance(next_past_key_values, EncoderDecoderCache)
|
| 469 |
-
and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
|
| 470 |
-
):
|
| 471 |
-
next_past_key_values.batch_select_indices(augmented_idx)
|
| 472 |
-
else:
|
| 473 |
-
raise ValueError(
|
| 474 |
-
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
| 475 |
-
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
| 476 |
-
)
|
| 477 |
|
| 478 |
logit_for_next_step = torch.stack(torch.split(logits, top_k))[
|
| 479 |
range(batch_size), selected_idx, :
|
|
@@ -549,18 +516,7 @@ def _contrastive_search(
|
|
| 549 |
# Contrastive search works by forward looking at the next token, so we need to exclude it from
|
| 550 |
# `past_key_values` to be consistent with the other decoding methods
|
| 551 |
if model_kwargs.get("past_key_values") is not None:
|
| 552 |
-
|
| 553 |
-
isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
|
| 554 |
-
and isinstance(
|
| 555 |
-
model_kwargs["past_key_values"].self_attention_cache, DynamicCache
|
| 556 |
-
)
|
| 557 |
-
):
|
| 558 |
-
model_kwargs["past_key_values"].crop(-1)
|
| 559 |
-
else:
|
| 560 |
-
raise ValueError(
|
| 561 |
-
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
| 562 |
-
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
| 563 |
-
)
|
| 564 |
|
| 565 |
if model.config.is_encoder_decoder:
|
| 566 |
return GenerateEncoderDecoderOutput(
|
|
|
|
| 282 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
| 283 |
"for contrastive search."
|
| 284 |
)
|
| 285 |
+
# Only those caches have the necessary methods
|
| 286 |
+
elif not (isinstance(past_key_values, DynamicCache) or (isinstance(past_key_values, EncoderDecoderCache) and isinstance(past_key_values.self_attention_cache, DynamicCache))):
|
| 287 |
raise ValueError(
|
| 288 |
+
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
| 289 |
+
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
| 290 |
)
|
| 291 |
|
| 292 |
# contrastive_search main logic start:
|
|
|
|
| 324 |
del outputs
|
| 325 |
|
| 326 |
if not sequential:
|
| 327 |
+
model_kwargs["past_key_values"] = model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
if sequential:
|
| 330 |
all_outputs = []
|
|
|
|
| 340 |
output_hidden_states=True,
|
| 341 |
output_attentions=output_attentions,
|
| 342 |
)
|
| 343 |
+
# Remove past K-V from output since we don't need to stack later
|
| 344 |
+
outputs["past_key_values"] = None
|
| 345 |
+
# Remove last token from past K-V since we don't want to append it at this point
|
| 346 |
+
model_kwargs["past_key_values"].crop(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
all_outputs.append(outputs)
|
| 349 |
outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
|
|
|
|
| 440 |
next_past_key_values = next_past_key_values or getattr(
|
| 441 |
outputs, possible_cache_name, None
|
| 442 |
)
|
| 443 |
+
next_past_key_values.batch_select_indices(augmented_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
logit_for_next_step = torch.stack(torch.split(logits, top_k))[
|
| 446 |
range(batch_size), selected_idx, :
|
|
|
|
| 516 |
# Contrastive search works by forward looking at the next token, so we need to exclude it from
|
| 517 |
# `past_key_values` to be consistent with the other decoding methods
|
| 518 |
if model_kwargs.get("past_key_values") is not None:
|
| 519 |
+
model_kwargs["past_key_values"].crop(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
|
| 521 |
if model.config.is_encoder_decoder:
|
| 522 |
return GenerateEncoderDecoderOutput(
|