Update src/pipeline.py
Browse files- src/pipeline.py +1 -0
src/pipeline.py
CHANGED
|
@@ -52,6 +52,7 @@ class W8A16LinearLayer(nn.Module):
|
|
| 52 |
def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
|
| 53 |
super().__init__()
|
| 54 |
self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
|
|
|
|
| 55 |
if bias:
|
| 56 |
self.bias = nn.Parameter(torch.randn(1, out_features, dtype=dtype))
|
| 57 |
self.scales = nn.Parameter(torch.randn(out_features, dtype=dtype))
|
|
|
|
| 52 |
def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
|
| 53 |
super().__init__()
|
| 54 |
self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
|
| 55 |
+
self.weight.requires_grad = False
|
| 56 |
if bias:
|
| 57 |
self.bias = nn.Parameter(torch.randn(1, out_features, dtype=dtype))
|
| 58 |
self.scales = nn.Parameter(torch.randn(out_features, dtype=dtype))
|