Fix typo
Browse files
custom_generate/generate.py
CHANGED
|
@@ -183,7 +183,7 @@ def _contrastive_search(
|
|
| 183 |
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
| 184 |
# Does not exist anymore in recent versions!
|
| 185 |
if hasattr(model, "_get_initial_cache_position"):
|
| 186 |
-
model_kwargs = model._get_initial_cache_position(
|
| 187 |
|
| 188 |
# Create cosine_matrix_mask based on the attention_mask
|
| 189 |
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
|
|
|
|
| 183 |
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
|
| 184 |
# Does not exist anymore in recent versions!
|
| 185 |
if hasattr(model, "_get_initial_cache_position"):
|
| 186 |
+
model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
|
| 187 |
|
| 188 |
# Create cosine_matrix_mask based on the attention_mask
|
| 189 |
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
|