manbeast3b commited on
Commit
3fc2f94
·
verified ·
1 Parent(s): 5829a8d

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +1 -0
src/pipeline.py CHANGED
@@ -52,6 +52,7 @@ class W8A16LinearLayer(nn.Module):
52
  def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
53
  super().__init__()
54
  self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
 
55
  if bias:
56
  self.bias = nn.Parameter(torch.randn(1, out_features, dtype=dtype))
57
  self.scales = nn.Parameter(torch.randn(out_features, dtype=dtype))
 
52
  def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
53
  super().__init__()
54
  self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
55
+ self.weight.requires_grad = False
56
  if bias:
57
  self.bias = nn.Parameter(torch.randn(1, out_features, dtype=dtype))
58
  self.scales = nn.Parameter(torch.randn(out_features, dtype=dtype))