Use nn.Linear instead of a custom linear function

#27
by Disty0 - opened
Files changed (1)
  1. layers.py +2 -6
layers.py CHANGED
@@ -31,10 +31,6 @@ class LinearWeights:
     bias: torch.Tensor
 
 
-def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
-    return F.linear(x, w.weight, w.bias)
-
-
 def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
     _step = W_q.shape[0]
     W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
@@ -226,9 +222,9 @@ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
 
     q, k, v = [
         t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-        for t in linear(x, w.qkv).chunk(3, dim=-1)
+        for t in w.qkv(x).chunk(3, dim=-1)
     ]
     out = F.scaled_dot_product_attention(q, k, v)
     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-    out = linear(out, w.proj)
+    out = w.proj(out)
     return out
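
For context, a minimal sketch (not part of this PR) of the equivalence the change relies on, assuming `w.qkv` and `w.proj` are now `nn.Linear` modules; the names `qkv` and `d_model` mirror the diff, but the setup here is purely illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model = 8

# An nn.Linear module now plays the role of the removed
# LinearWeights dataclass + linear() helper pair.
qkv = nn.Linear(d_model, 3 * d_model)

x = torch.randn(2, 4, d_model)

# Calling the module computes exactly what the deleted helper did,
# i.e. F.linear(x, w.weight, w.bias).
assert torch.equal(qkv(x), F.linear(x, qkv.weight, qkv.bias))

# Splitting the fused projection into q, k, v, as attn() does:
q, k, v = qkv(x).chunk(3, dim=-1)
```

Since `nn.Linear.forward` dispatches to the same `F.linear` call, the change is behavior-preserving; it just drops the extra indirection.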