jood-canva committed
Commit ca11c72 · verified · 1 parent: bbf36bb

allow Output subclasses in contrastive search


Hi,

I am trying to use contrastive search with a custom multimodal model, and I found that it overrides the model outputs with a default type (e.g. CausalLMOutputWithPast). This can be an issue for models that rely on extra attributes in their output class. I think we could simply replace the attributes we want on the existing output rather than recreating the object. We achieve the same thing, and it's a bit cleaner in my opinion. Happy to discuss if I'm missing something!

Also happy to add assertions to check that outputs inherit from the correct class if you want.
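For context, here is a minimal sketch of why `dataclasses.replace` preserves the subclass, using plain dataclasses as hypothetical stand-ins for transformers' output classes (which are themselves dataclasses):

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple

# Hypothetical stand-in for a library output class like CausalLMOutputWithPast.
@dataclass
class BaseOutput:
    logits: Optional[list] = None
    past_key_values: Optional[Tuple] = None

# Hypothetical custom model output carrying an extra attribute.
@dataclass
class CustomMultimodalOutput(BaseOutput):
    image_features: Optional[list] = None

out = CustomMultimodalOutput(logits=[0.1], image_features=[1, 2, 3])

# Recreating the object with the base class silently drops the subclass type
# and any extra attributes that came with it:
rebuilt = BaseOutput(past_key_values=("kv",))

# dataclasses.replace updates only the named fields and keeps the original type,
# so image_features survives:
updated = replace(out, past_key_values=("kv",))
```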

Files changed (1)
  1. custom_generate/generate.py +6 -3
custom_generate/generate.py

@@ -1,3 +1,4 @@
+from dataclasses import replace
 import logging
 from typing import TYPE_CHECKING, Optional, Union
 
@@ -14,7 +15,6 @@ from transformers.generation.utils import (
     GenerateNonBeamOutput,
     GenerationMixin,
 )
-from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from transformers.utils import ModelOutput
 
 
@@ -414,7 +414,8 @@ def _contrastive_search(
             for layer in outputs.decoder_attentions:
                 layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                 next_step_decoder_attentions += (layer,)
-        outputs = Seq2SeqLMOutput(
+        outputs = replace(
+            outputs,
             past_key_values=next_past_key_values,
             decoder_hidden_states=next_decoder_hidden_states,
             decoder_attentions=next_step_decoder_attentions or None,
@@ -426,11 +427,13 @@ def _contrastive_search(
             for layer in outputs.attentions:
                 layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                 next_step_attentions += (layer,)
-        outputs = CausalLMOutputWithPast(
+        outputs = replace(
+            outputs,
             past_key_values=next_past_key_values,
             hidden_states=next_decoder_hidden_states,
             attentions=next_step_attentions or None,
         )
+
     # contrastive_search main logic end
 
     # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
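The inheritance check suggested in the description could look like the sketch below. The helper name `checked_replace` is hypothetical, and plain dataclasses again stand in for transformers' output classes:

```python
from dataclasses import dataclass, is_dataclass, replace
from typing import Optional, Tuple

# Hypothetical stand-in for transformers' Seq2SeqLMOutput.
@dataclass
class Seq2SeqLikeOutput:
    past_key_values: Optional[Tuple] = None
    decoder_hidden_states: Optional[Tuple] = None

# Hypothetical custom subclass with an extra attribute.
@dataclass
class CustomOutput(Seq2SeqLikeOutput):
    vision_states: Optional[Tuple] = None

def checked_replace(outputs, expected_cls, **updates):
    # Fail loudly if the model returned an incompatible output class,
    # instead of silently producing wrong results downstream.
    if not (is_dataclass(outputs) and isinstance(outputs, expected_cls)):
        raise TypeError(
            f"expected a {expected_cls.__name__} subclass, got {type(outputs).__name__}"
        )
    return replace(outputs, **updates)

updated = checked_replace(
    CustomOutput(vision_states=("img",)),
    Seq2SeqLikeOutput,
    past_key_values=("kv",),
)
```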