Update custom_generate/generate.py
Browse files
custom_generate/generate.py
CHANGED
|
@@ -210,12 +210,6 @@ def generate(model, *args, **kwargs):
|
|
| 210 |
xtc_threshold (float): The threshold for defining a "top choice". Default 0.1.
|
| 211 |
xtc_protected_tokens (List[int]): Optional list of specific token IDs to prevent XTC from removing (e.g., newlines).
|
| 212 |
"""
|
| 213 |
-
# XTC is effectively a sampler, so we should ensure do_sample is True in the config
|
| 214 |
-
if "generation_config" in kwargs:
|
| 215 |
-
kwargs["generation_config"].do_sample = True
|
| 216 |
-
elif "do_sample" not in kwargs:
|
| 217 |
-
kwargs["do_sample"] = True
|
| 218 |
-
|
| 219 |
# Delegate to the standard GenerationMixin, injecting our custom decoding loop
|
| 220 |
generation_outputs = GenerationMixin.generate(
|
| 221 |
model, *args, custom_generate=_xtc_decoding, **kwargs
|
|
|
|
| 210 |
xtc_threshold (float): The threshold for defining a "top choice". Default 0.1.
|
| 211 |
xtc_protected_tokens (List[int]): Optional list of specific token IDs to prevent XTC from removing (e.g., newlines).
|
| 212 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
# Delegate to the standard GenerationMixin, injecting our custom decoding loop
|
| 214 |
generation_outputs = GenerationMixin.generate(
|
| 215 |
model, *args, custom_generate=_xtc_decoding, **kwargs
|