davda54 commited on
Commit
1a516b1
·
verified ·
1 Parent(s): 7a782b6

Update modeling_gptbert.py

Browse files
Files changed (1) hide show
  1. modeling_gptbert.py +2 -0
modeling_gptbert.py CHANGED
@@ -349,6 +349,8 @@ class SelfAttention(nn.Module):
349
  self.k_out_dim = self.d_qk * self.num_kv_heads
350
  self.v_out_dim = self.d_v * self.num_kv_heads
351
 
 
 
352
  self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
353
  self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
354
  self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)
 
349
  self.k_out_dim = self.d_qk * self.num_kv_heads
350
  self.v_out_dim = self.d_v * self.num_kv_heads
351
 
352
+ self.is_causal = is_decoder
353
+
354
  self.qk_proj = MultiCastedLinearOrthoIn(self.hidden_size, [self.q_out_dim, self.k_out_dim], bias=False)
355
  self.v_proj = CastedLinearIn(self.hidden_size, self.v_out_dim, bias=False)
356
  self.out_proj = CastedLinearIn(self.d_v*self.num_attention_heads, self.hidden_size, bias=False)