leejunhyeok committed
Commit 72cc86d · verified · 1 parent: 20b97f1

Update modeling_motif.py

Files changed (1): modeling_motif.py (+4 −7)
modeling_motif.py CHANGED

@@ -545,9 +545,7 @@ class MotifFlashAttention2(MotifAttention):
 
         bsz = query_states.shape[0]
 
-        return map(
-            lambda x: x.float32(),
-            _flash_attention_forward(query_states.bfloat16(),
+        return _flash_attention_forward(query_states.bfloat16(),
                                         key_states.bfloat16(),
                                         value_states.bfloat16(),
                                         attention_mask,
@@ -557,7 +555,6 @@ class MotifFlashAttention2(MotifAttention):
                                         sliding_window=sliding_window,
                                         is_causal=self.is_causal,
                                         use_top_left_mask=self._flash_attn_uses_top_left_mask)
-        )
 
     def forward(
         self,
@@ -642,7 +639,7 @@ class MotifFlashAttention2(MotifAttention):
             self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
 
 
-        attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
+        attn1, attn2 = torch.cat([attn11, attn12], dim=-1).float(), torch.cat([attn21, attn22], dim=-1).float()
 
         lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]])  # bsz, num_head
         lambda_q2 = self.lambda_q2.unsqueeze(0).expand([bsz, self.lambda_q2.shape[0]])  # bsz, num_head
@@ -661,10 +658,10 @@ class MotifFlashAttention2(MotifAttention):
             raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                              f" {attn_output.size()}")
 
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).bfloat16()
         attn_output = self.o_proj(attn_output) * self.o_proj_alpha
 
-        return attn_output, None, past_key_value
+        return attn_output.float(), None, past_key_value
 
 
     # @log_timing
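
Note on the change: the removed wrapper was broken, since `torch.Tensor` has no `float32()` method (the fp32 cast is `Tensor.float()`), so the lazy `map` object would raise an AttributeError as soon as its result was consumed. The commit instead keeps the flash-attention call itself in bfloat16, upcasts to fp32 only where the two differential-attention branch outputs are concatenated and mixed, casts back to bfloat16 for `o_proj`, and hands the caller an fp32 tensor. Below is a minimal sketch of that dtype round-trip, using `torch.nn.functional.scaled_dot_product_attention` as a stand-in for `_flash_attention_forward` and made-up shapes; it illustrates the pattern, not the actual Motif code.

import torch
import torch.nn.functional as F

# Illustrative shapes only (not taken from the model config).
bsz, num_heads, q_len, head_dim = 2, 4, 16, 32

q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# 1) Run the attention kernel in bfloat16, as the patched helper now does
#    (flash kernels require fp16/bf16 inputs). scaled_dot_product_attention
#    is a stand-in for _flash_attention_forward here.
attn = F.scaled_dot_product_attention(
    q.bfloat16(), k.bfloat16(), v.bfloat16(), is_causal=True)

# 2) Upcast to fp32 only at the numerically sensitive mixing step
#    (in the model: the lambda-weighted combination of the two branches),
#    mirroring the new `.float()` on the torch.cat outputs.
attn = attn.float()

# 3) Cast back to bfloat16 for the output projection, then return fp32,
#    mirroring `reshape(...).bfloat16()` followed by `attn_output.float()`.
o_proj = torch.nn.Linear(num_heads * head_dim, num_heads * head_dim,
                         bias=False, dtype=torch.bfloat16)
out = o_proj(attn.transpose(1, 2).reshape(bsz, q_len, -1).bfloat16())
print(out.float().dtype)  # torch.float32

The design point is simply that bfloat16 is kept for the matmul-heavy kernel and the output projection, while the small, precision-sensitive mixing arithmetic between them runs in fp32.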