Upload folder using huggingface_hub
custom_generate/generate.py CHANGED (+17 -17)
@@ -83,7 +83,7 @@ def _dola_select_contrast(
     return logits


 def _dola_decoding(
-    self,
+    model,
     input_ids: torch.LongTensor,
     logits_processor: LogitsProcessorList,
     stopping_criteria: StoppingCriteriaList,
@@ -141,7 +141,7 @@ def _dola_decoding(
     if getattr(generation_config, "num_beams", 1) != 1:
         raise ValueError("DoLa generation needs num_beams == 1")

-    if self.config.is_encoder_decoder:
+    if model.config.is_encoder_decoder:
         raise ValueError("DoLa decoding is only available for decoder-only models.")

     if generation_config.repetition_penalty < 1.2:
@@ -150,13 +150,13 @@
             "The recommended value for DoLa decoding is `repetition_penalty>=1.2`.",
         )

-    if getattr(self, "_is_stateful", False):
+    if getattr(model, "_is_stateful", False):
         # DoLa decoding was not designed for stateful models, and would require some changes
         raise ValueError(
-            f"DoLa decoding is not supported with stateful models, such as {self.__class__.__name__}"
+            f"DoLa decoding is not supported with stateful models, such as {model.__class__.__name__}"
         )

-    if self.config.is_encoder_decoder:
+    if model.config.is_encoder_decoder:
         raise ValueError("DoLa decoding is only available for decoder-only models.")

     # init values
@@ -179,17 +179,17 @@
     # keep track of which sequences are already finished
     batch_size, cur_length = input_ids.shape[:2]
     unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
-    model_kwargs = self._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
+    model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)

     this_peer_finished = False

     # prepare layers for DoLa decoding
-    final_layer = self.config.get_text_config().num_hidden_layers
+    final_layer = model.config.get_text_config().num_hidden_layers
     # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
     # as the early exit from word embeddings will become identity function
     # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
     # layer otherwise. Notice that DoLa does not help shallow models much.
-    if not self.config.tie_word_embeddings:
+    if not model.config.tie_word_embeddings:
         start_layer = 0
     elif final_layer > 2:
         start_layer = 2
@@ -223,16 +223,16 @@
     else:
         raise ValueError("dola_layers must be either 'low', 'high' or a list of integers.")

-    lm_head = self.get_output_embeddings()
+    lm_head = model.get_output_embeddings()
     if lm_head is None:
         raise ValueError("DoLa is not supported for models that don't have output embeddings.")

-    while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
         # prepare model inputs
-        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

         # forward pass to get next token
-        outputs = self(
+        outputs = model(
             **model_inputs,
             return_dict=True,
             output_attentions=output_attentions,
@@ -249,10 +249,10 @@
             ).to(final_logits.device)

         # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
-        model_kwargs = self._update_model_kwargs_for_generation(
+        model_kwargs = model._update_model_kwargs_for_generation(
             outputs,
             model_kwargs,
-            is_encoder_decoder=self.config.is_encoder_decoder,
+            is_encoder_decoder=model.config.is_encoder_decoder,
         )
         if synced_gpus and this_peer_finished:
             continue
@@ -272,15 +272,15 @@
                 raw_logits += (final_layer_next_token_logits,)
             if output_attentions:
                 decoder_attentions += (
-                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                 )
-                if self.config.is_encoder_decoder:
+                if model.config.is_encoder_decoder:
                     cross_attentions += (outputs.cross_attentions,)

             if output_hidden_states:
                 decoder_hidden_states += (
                     (outputs.decoder_hidden_states,)
-                    if self.config.is_encoder_decoder
+                    if model.config.is_encoder_decoder
                     else (outputs.hidden_states,)
                 )
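For context, this file appears to follow the Hub custom-generate convention, where `custom_generate/generate.py` receives the model as an explicit `model` argument rather than as a bound `self`, which is why every `self.` reference above becomes `model.`. Below is a minimal usage sketch, assuming a `transformers` version with Hub custom generation support and assuming this repo's `generate()` accepts DoLa arguments such as `dola_layers`; the checkpoint and the repo id `<user>/<this-repo>` are placeholders:

```python
# Minimal sketch: delegate decoding to this repo's custom_generate/generate.py.
# Assumptions: a transformers release with `custom_generate` support;
# "<user>/<this-repo>" is a placeholder for this repository's Hub id.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    custom_generate="<user>/<this-repo>",  # placeholder repo id
    trust_remote_code=True,
    dola_layers="high",       # early layers to contrast against the final layer
    repetition_penalty=1.2,   # the code above recommends >= 1.2 for DoLa
    max_new_tokens=32,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

The rename matters because the Hub loader imports `generate.py` as a free function and passes the model in explicitly, so there is no method binding and no `self` available.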