Maxtimer97 committed
Commit bdf9fdc · Parent(s): bf53a99

Safe cache check

Files changed (1)
  1. modeling_chatglm.py +3 -2
modeling_chatglm.py CHANGED
@@ -717,8 +717,9 @@ class SelfAttention(torch.nn.Module):
         # adjust key and value for inference
         if kv_cache is not None:
             cache_k, cache_v = kv_cache
-            key_layer = torch.cat((cache_k, key_layer), dim=2)
-            value_layer = torch.cat((cache_v, value_layer), dim=2)
+            if cache_k is not None and cache_v is not None:
+                key_layer = torch.cat((cache_k, key_layer), dim=2)
+                value_layer = torch.cat((cache_v, value_layer), dim=2)
         if use_cache:
             if kv_cache is None:
                 kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
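For context, a minimal self-contained sketch of the failure mode this "safe cache check" guards against. The helper name append_kv_cache and the [batch, heads, seq, head_dim] tensor layout are illustrative assumptions, not taken from modeling_chatglm.py: without the added guard, a kv_cache tuple holding (None, None) (e.g. a pre-allocated placeholder on the first decode step) would make torch.cat raise a TypeError.

import torch

def append_kv_cache(key_layer, value_layer, kv_cache=None):
    # Hypothetical standalone version of the patched SelfAttention logic.
    # A kv_cache of (None, None) no longer reaches torch.cat; concatenation
    # along the sequence dimension only runs once real cache tensors exist.
    if kv_cache is not None:
        cache_k, cache_v = kv_cache
        if cache_k is not None and cache_v is not None:
            key_layer = torch.cat((cache_k, key_layer), dim=2)
            value_layer = torch.cat((cache_v, value_layer), dim=2)
    return key_layer, value_layer

# First step: placeholder cache of Nones -- previously torch.cat would raise.
k = torch.randn(1, 2, 4, 8)   # assumed [batch, heads, seq, head_dim] layout
v = torch.randn(1, 2, 4, 8)
k, v = append_kv_cache(k, v, kv_cache=(None, None))

# Later step: real cached tensors are concatenated along the sequence dim (dim=2).
k_new = torch.randn(1, 2, 1, 8)
v_new = torch.randn(1, 2, 1, 8)
k_full, v_full = append_kv_cache(k_new, v_new, kv_cache=(k, v))
print(k_full.shape)  # torch.Size([1, 2, 5, 8])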