Commit
·
bdf9fdc
1
Parent(s):
bf53a99
Safe cache check
Browse files — modeling_chatglm.py (+3, -2)
modeling_chatglm.py
CHANGED
|
@@ -717,8 +717,9 @@ class SelfAttention(torch.nn.Module):
|
|
| 717 |
# adjust key and value for inference
|
| 718 |
if kv_cache is not None:
|
| 719 |
cache_k, cache_v = kv_cache
|
| 720 |
-
|
| 721 |
-
|
|
|
|
| 722 |
if use_cache:
|
| 723 |
if kv_cache is None:
|
| 724 |
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
|
|
|
|
| 717 |
# adjust key and value for inference
|
| 718 |
if kv_cache is not None:
|
| 719 |
cache_k, cache_v = kv_cache
|
| 720 |
+
if cache_k is not None and cache_v is not None:
|
| 721 |
+
key_layer = torch.cat((cache_k, key_layer), dim=2)
|
| 722 |
+
value_layer = torch.cat((cache_v, value_layer), dim=2)
|
| 723 |
if use_cache:
|
| 724 |
if kv_cache is None:
|
| 725 |
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
|