Update yuan_hf_model.py
yuan_hf_model.py (+4 -4)
```diff
@@ -32,8 +32,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_yuan import YuanConfig
 from einops import rearrange
-
-
+from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
+from flash_attn import flash_attn_func
 
 import copy
 
```
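The import hunk tracks the flash-attn 2.x rename: the variable-length kernel `flash_attn_unpadded_func` became `flash_attn_varlen_func`, and aliasing the new name back to the old one leaves the rest of the file untouched. A minimal sketch of a version-tolerant import, assuming the 1.x name lived under `flash_attn.flash_attn_interface` (this try/except is not part of the commit):

```python
# Sketch only: tolerate both flash-attn 1.x and 2.x namings.
try:
    # flash-attn >= 2.0 exposes the variable-length kernel as flash_attn_varlen_func
    from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
    from flash_attn import flash_attn_func
except ImportError:
    try:
        # flash-attn 1.x (assumption: old module path) kept the unpadded name
        from flash_attn.flash_attn_interface import flash_attn_unpadded_func
        flash_attn_func = None
    except ImportError:
        # no flash-attn installed; callers must fall back to eager attention
        flash_attn_unpadded_func = None
        flash_attn_func = None
```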
```diff
@@ -268,8 +268,8 @@ class YuanAttention(nn.Module):
         is_first_step = False
         if use_cache:
             if past_key_value is None:
-
-                inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
+                inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype ,device=torch.cuda.current_device())
+                #inference_hidden_states_memory = torch.empty(bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype)
                 is_first_step = True
             else:
                 before_hidden_states = past_key_value[2]
```
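The second hunk allocates `inference_hidden_states_memory`, the `(batch, 2, hidden_size)` buffer created on the first decoding step when the cache is empty, directly on the current CUDA device. With the old CPU-default `torch.empty`, the buffer sits on the host, so writing GPU activations into it would force a copy or fail with a device mismatch. A sketch of an equivalent, device-agnostic allocation (`alloc_inference_memory` is a hypothetical helper, not in the file; deriving the device from `hidden_states` also covers CPU-only runs):

```python
import torch

def alloc_inference_memory(hidden_states: torch.Tensor) -> torch.Tensor:
    # Sketch: same shape/dtype as the committed line, but placed on
    # hidden_states.device instead of torch.cuda.current_device(), so it
    # works on CPU and on whichever GPU actually holds the activations.
    bsz = hidden_states.shape[0]
    return torch.empty(
        bsz, 2, hidden_states.shape[2],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

# Usage with hypothetical shapes: (batch, seq_len, hidden_size)
h = torch.randn(1, 16, 2048)
buf = alloc_inference_memory(h)  # -> (1, 2, 2048), same device/dtype as h
```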