cyrilvallez HF Staff commited on
Commit
00ab760
·
verified ·
1 Parent(s): cd66a74
Files changed (1) hide show
  1. custom_generate/generate.py +1 -1
custom_generate/generate.py CHANGED
@@ -183,7 +183,7 @@ def _contrastive_search(
183
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
184
  # Does not exist anymore in recent versions!
185
  if hasattr(model, "_get_initial_cache_position"):
186
- model_kwargs = model._get_initial_cache_position(cur_length, input_ids.device, model_kwargs)
187
 
188
  # Create cosine_matrix_mask based on the attention_mask
189
  cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
 
183
  unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
184
  # Does not exist anymore in recent versions!
185
  if hasattr(model, "_get_initial_cache_position"):
186
+ model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
187
 
188
  # Create cosine_matrix_mask based on the attention_mask
189
  cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)