Update custom_generate/generate.py

#1
by RaushanTurganbay (HF Staff) - opened
Files changed (1)
  1. custom_generate/generate.py +23 -18
custom_generate/generate.py CHANGED
@@ -17,12 +17,12 @@ from transformers.generation.utils import (
 )
 
 
-def generate(
+def _deepconf_generate(
     model: Any,
     input_ids: torch.LongTensor,
-    logits_processor: Optional[LogitsProcessorList] = None,
-    stopping_criteria: Optional[StoppingCriteriaList] = None,
-    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList],
+    stopping_criteria: Optional[StoppingCriteriaList],
+    generation_config: Optional[GenerationConfig],
     synced_gpus: bool = False,
     streamer: Optional[Any] = None,
     **model_kwargs,
@@ -44,12 +44,6 @@ def generate(
         depending on `return_dict_in_generate` and model type.
     """
 
-    # Ensure processors/criteria are defined
-    if logits_processor is None:
-        logits_processor = LogitsProcessorList()
-    if stopping_criteria is None:
-        stopping_criteria = StoppingCriteriaList()
-
     # Get DeepCONF parameters from generation_config or set defaults
     enable_conf = getattr(generation_config, "enable_conf", False)
     enable_early_stopping = getattr(generation_config, "enable_early_stopping", True)  # NEW: Allow disabling early stopping
@@ -75,14 +69,7 @@ def generate(
 
     # Initialize values
    # Handle pad token properly (following HF best practices)
-    pad_token_id = generation_config.pad_token_id
-    if pad_token_id is None and hasattr(generation_config, "_pad_token_tensor"):
-        pad_token_id = generation_config._pad_token_tensor
-    if pad_token_id is None and hasattr(model.config, "pad_token_id"):
-        pad_token_id = model.config.pad_token_id
-    if pad_token_id is None and generation_config.eos_token_id is not None:
-        # Use eos token as pad token if not set
-        pad_token_id = generation_config.eos_token_id
+    pad_token_id = generation_config._pad_token_tensor
 
     output_attentions = generation_config.output_attentions
     output_hidden_states = generation_config.output_hidden_states
@@ -383,3 +370,21 @@ def generate(
         return output
     else:
         return input_ids
+
+
+def generate(model, *args, **kwargs):
+    """Custom generate function for group beam search decoding.
+    Args:
+        model (`PreTrainedModel`):
+            The model to generate from.
+        num_beams (`int`): The number of beams to use for beam search.
+        num_beam_groups (`int`): The number of beam groups to use for beam search.
+        length_penalty (`float`): The length penalty to use for beam search.
+        early_stopping (`bool`): Whether to stop beam search when sufficient beams have finished.
+        num_return_sequences (`int`): The number of sequences to return.
+        max_length (`int`): The maximum length of the generated sequence.
+    """
+    generation_outputs = GenerationMixin.generate(
+        model, *args, custom_generate=_deepconf_generate, **kwargs
+    )
+    return generation_outputs
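
Taken together, the change turns the module's entry point into a thin wrapper: `generate()` now routes through `GenerationMixin.generate` with `custom_generate=_deepconf_generate`, so the DeepCONF loop receives an already-prepared `generation_config`, `logits_processor`, and `stopping_criteria` instead of building its own defaults, and the DeepCONF switches (`enable_conf`, `enable_early_stopping`) are read off the `generation_config` with `getattr`. Below is a minimal usage sketch, assuming a recent transformers release with Hub `custom_generate` support; the repo id and the base model are placeholders, and it relies on extra `GenerationConfig` keyword arguments being stored as plain attributes so the `getattr` calls above can see them.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "Qwen/Qwen2.5-0.5B-Instruct"       # any causal LM; placeholder choice
deepconf_repo = "<user>/<deepconf-repo>"       # placeholder: Hub repo shipping custom_generate/generate.py

tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Extra keyword arguments passed to GenerationConfig are kept as attributes,
# which is how _deepconf_generate later reads them via getattr(...).
gen_config = GenerationConfig(
    max_new_tokens=256,
    enable_conf=True,            # enable DeepCONF confidence tracking
    enable_early_stopping=True,  # allow low-confidence traces to stop early
)

inputs = tok("What is 12 * 17?", return_tensors="pt")
out = model.generate(
    **inputs,
    generation_config=gen_config,
    custom_generate=deepconf_repo,  # load the custom decoding loop from the Hub repo
    trust_remote_code=True,
)
print(tok.decode(out[0], skip_special_tokens=True))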
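
On the pad-token hunk: the long fallback chain can be dropped because `GenerationMixin.generate` resolves the special tokens before dispatching to a custom decoding loop, so by the time `_deepconf_generate` runs, `generation_config._pad_token_tensor` is already populated (with EOS used as the fallback when no pad token is configured). A rough sketch of that order of operations follows; the helper names are private transformers internals quoted as an assumption and may shift between versions.

# Simplified sequence inside GenerationMixin.generate before the custom loop runs
# (private helpers; names/signatures may change between transformers versions):
generation_config, model_kwargs = model._prepare_generation_config(generation_config, **kwargs)
model._prepare_special_tokens(generation_config, kwargs_has_attention_mask=True, device=input_ids.device)

# ...which is why the custom loop only needs a single read:
pad_token_id = generation_config._pad_token_tensor  # tensor, or None if the model has neither pad nor eos token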