Allow Output subclasses in contrastive search
Hi,
I'm trying to use contrastive search with a custom multimodal model, and I found that it overrides the model outputs with the default type (e.g. CausalLMOutputWithPast). This is a problem for models that rely on extra attributes in their output class. I think we can simply replace the attributes we want on the existing output object rather than recreating it: we achieve the same thing, and it's a bit cleaner in my opinion. Happy to discuss if I'm missing something! A minimal sketch of the difference is below.
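To illustrate, here's a minimal sketch with a hypothetical custom output class (the class name and its extra field are stand-ins I made up, not from this PR):

```python
import torch
from dataclasses import dataclass, replace

from transformers.modeling_outputs import CausalLMOutputWithPast


# Hypothetical custom output with one extra attribute, as a multimodal model might define.
@dataclass
class MyMultimodalOutput(CausalLMOutputWithPast):
    image_hidden_states: torch.FloatTensor = None


out = MyMultimodalOutput(
    logits=torch.zeros(1, 1, 10),
    image_hidden_states=torch.zeros(1, 4, 8),
)

# Current behaviour: rebuilding the output with the default class drops the
# subclass, and every extra attribute with it.
rebuilt = CausalLMOutputWithPast(logits=out.logits)
assert not hasattr(rebuilt, "image_hidden_states")

# With dataclasses.replace: the fields we pass are swapped in, while the class
# and all other fields are kept.
updated = replace(out, logits=out.logits * 2)
assert isinstance(updated, MyMultimodalOutput)
assert updated.image_hidden_states is not None
```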
I'm also happy to add assertions to check that the outputs inherit from the expected class, if you want.
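For example, something along these lines (the helper name and its placement before the replace() calls are my assumption, just to show the shape of the check):

```python
from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from transformers.utils import ModelOutput


def _check_output_type(outputs: ModelOutput, is_encoder_decoder: bool) -> None:
    # Subclasses pass the check; only unrelated output types are rejected.
    expected_cls = Seq2SeqLMOutput if is_encoder_decoder else CausalLMOutputWithPast
    if not isinstance(outputs, expected_cls):
        raise TypeError(
            f"Contrastive search expects model outputs inheriting from "
            f"{expected_cls.__name__}, got {type(outputs).__name__}."
        )
```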
custom_generate/generate.py (CHANGED)

@@ -1,3 +1,4 @@
+from dataclasses import replace
 import logging
 from typing import TYPE_CHECKING, Optional, Union
 
@@ -14,7 +15,6 @@ from transformers.generation.utils import (
     GenerateNonBeamOutput,
     GenerationMixin,
 )
-from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from transformers.utils import ModelOutput
 
 
@@ -414,7 +414,8 @@ def _contrastive_search(
                 for layer in outputs.decoder_attentions:
                     layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_decoder_attentions += (layer,)
-            outputs = Seq2SeqLMOutput(
+            outputs = replace(
+                outputs,
                 past_key_values=next_past_key_values,
                 decoder_hidden_states=next_decoder_hidden_states,
                 decoder_attentions=next_step_decoder_attentions or None,
@@ -426,11 +427,13 @@
                 for layer in outputs.attentions:
                     layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_attentions += (layer,)
-            outputs = CausalLMOutputWithPast(
+            outputs = replace(
+                outputs,
                 past_key_values=next_past_key_values,
                 hidden_states=next_decoder_hidden_states,
                 attentions=next_step_attentions or None,
             )
+
         # contrastive_search main logic end
 
         # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping