manueldeprada HF Staff commited on
Commit
005232e
·
verified ·
1 Parent(s): 5f8a24f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. custom_generate/generate.py +4 -2
custom_generate/generate.py CHANGED
@@ -1,7 +1,7 @@
1
  from typing import Union
2
  import torch
3
  from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
4
- from transformers.generation.utils import GenerateNonBeamOutput, GenerateDecoderOnlyOutput
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import numpy as np
@@ -332,5 +332,7 @@ def generate(model, *args, **kwargs):
332
  `'high'` to improve short-answer tasks. Check the [documentation](https://huggingface.co/transformers-community/dola)
333
  or [the paper](https://huggingface.co/papers/2309.03883) for more details.
334
  """
335
- generation_outputs = model.generate(*args, custom_generate=_dola_decoding, **kwargs)
 
 
336
  return generation_outputs
 
1
  from typing import Union
2
  import torch
3
  from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
4
+ from transformers.generation.utils import GenerationMixin, GenerateNonBeamOutput, GenerateDecoderOnlyOutput
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import numpy as np
 
332
  `'high'` to improve short-answer tasks. Check the [documentation](https://huggingface.co/transformers-community/dola)
333
  or [the paper](https://huggingface.co/papers/2309.03883) for more details.
334
  """
335
+ generation_outputs = GenerationMixin.generate(
336
+ model, *args, custom_generate=_dola_decoding, **kwargs
337
+ )
338
  return generation_outputs