Maxtimer97 commited on
Commit
c8b397a
·
1 Parent(s): 251c74f

Corrected device assignment error

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +1 -1
modeling_chatglm.py CHANGED
@@ -506,7 +506,7 @@ class NativeSparseAttention(CoreAttention):
506
  cmp_seqlens = torch.cat(
507
  [
508
  torch.zeros(1, dtype=torch.int32, device=query_states.device),
509
- torch.cumsum(y_seqlens, dim=0, device=query_states.device),
510
  ],
511
  dim=0,
512
  ).to(torch.int32)
 
506
  cmp_seqlens = torch.cat(
507
  [
508
  torch.zeros(1, dtype=torch.int32, device=query_states.device),
509
+ torch.cumsum(y_seqlens, dim=0).to(query_states.device),
510
  ],
511
  dim=0,
512
  ).to(torch.int32)