Maxtimer97 committed on
Commit
251c74f
·
1 Parent(s): 3d4c21e

Corrected device assignment error

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +2 -2
modeling_chatglm.py CHANGED
@@ -505,8 +505,8 @@ class NativeSparseAttention(CoreAttention):
505
  y_seqlens[seqlens < self.kernel_size] = 0
506
  cmp_seqlens = torch.cat(
507
  [
508
- torch.zeros(1, dtype=torch.int32, device=y_seqlens.device),
509
- torch.cumsum(y_seqlens, dim=0),
510
  ],
511
  dim=0,
512
  ).to(torch.int32)
 
505
  y_seqlens[seqlens < self.kernel_size] = 0
506
  cmp_seqlens = torch.cat(
507
  [
508
+ torch.zeros(1, dtype=torch.int32, device=query_states.device),
509
+ torch.cumsum(y_seqlens, dim=0, device=query_states.device),
510
  ],
511
  dim=0,
512
  ).to(torch.int32)