Upload folder using huggingface_hub
custom_generate/generate.py CHANGED
@@ -1,7 +1,7 @@
 from typing import Union
 import torch
 from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
-from transformers.generation.utils import GenerateNonBeamOutput, GenerateDecoderOnlyOutput
+from transformers.generation.utils import GenerationMixin, GenerateNonBeamOutput, GenerateDecoderOnlyOutput
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
@@ -332,5 +332,7 @@ def generate(model, *args, **kwargs):
     `'high'` to improve short-answer tasks. Check the [documentation](https://huggingface.co/transformers-community/dola)
     or [the paper](https://huggingface.co/papers/2309.03883) for more details.
     """
-    generation_outputs =
+    generation_outputs = GenerationMixin.generate(
+        model, *args, custom_generate=_dola_decoding, **kwargs
+    )
     return generation_outputs
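For reference, the `generate` entry point patched above is intended to be loaded through the Hub's `custom_generate` mechanism rather than imported directly. A minimal usage sketch, assuming a transformers version with Hub custom-generate support; the checkpoint name is illustrative, not prescribed by this repo:

```python
# Sketch: run DoLa decoding by pointing `custom_generate` at this repo, which
# makes transformers fetch and execute custom_generate/generate.py shown above.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative; any decoder-only LM should work
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("What is the capital of France?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    custom_generate="transformers-community/dola",  # this repository
    trust_remote_code=True,  # required to run the Hub-hosted generate()
    dola_layers="high",      # the `'high'` setting referenced in the docstring
    max_new_tokens=32,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

With this pattern, the diff's call to `GenerationMixin.generate(model, *args, custom_generate=_dola_decoding, **kwargs)` lets the standard `generate` machinery validate arguments and build the generation config before dispatching to the repo's `_dola_decoding` loop.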