manbeast3b committed on
Commit
8e641d3
·
verified ·
1 Parent(s): 2f5adf0

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +9 -9
src/pipeline.py CHANGED
@@ -17,14 +17,19 @@ import os
17
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
18
  Pipeline = None
19
 
 
 
 
 
 
 
 
20
  class W8A16LinearLayer(nn.Module):
21
  def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
22
  super().__init__()
23
-
24
  self.register_buffer(
25
  "int8_weights",
26
  torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8))
27
-
28
  self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
29
 
30
  if bias:
@@ -33,14 +38,12 @@ class W8A16LinearLayer(nn.Module):
33
 
34
  def quantize(self, weights):
35
  w_fp32 = weights.clone().to(torch.float32)
36
-
37
  scales = w_fp32.abs().max(dim=-1).values / 127
38
  scales = scales.to(weights.dtype)
39
-
40
  int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8)
41
-
42
  self.int8_weights = int8_weights
43
- self.scales = scales self.bias = None
 
44
 
45
  def forward(self, input):
46
  return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)
@@ -51,12 +54,9 @@ def replace_linear_with_target_and_quantize(module, target_class, module_name_to
51
  any([x == name for x in module_name_to_exclude]):
52
  old_bias = child.bias
53
  old_weight = child.weight
54
-
55
  new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
56
  setattr(module, name, new_module)
57
-
58
  getattr(module, name).quantize(old_weight)
59
-
60
  if old_bias is not None:
61
  getattr(module, name).bias = old_bias
62
  else:
 
17
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False,garbage_collection_threshold:0.01"
18
  Pipeline = None
19
 
20
+ def w8_a16_forward(weight, input, scales, bias=None):
21
+ casted_weights = weight.to(input.dtype)
22
+ output = F.linear(input, casted_weights) * scales # overhead
23
+ if bias is not None:
24
+ output = output + bias
25
+ return output
26
+
27
  class W8A16LinearLayer(nn.Module):
28
  def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
29
  super().__init__()
 
30
  self.register_buffer(
31
  "int8_weights",
32
  torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8))
 
33
  self.register_buffer("scales", torch.randn((out_features), dtype=dtype))
34
 
35
  if bias:
 
38
 
39
  def quantize(self, weights):
40
  w_fp32 = weights.clone().to(torch.float32)
 
41
  scales = w_fp32.abs().max(dim=-1).values / 127
42
  scales = scales.to(weights.dtype)
 
43
  int8_weights = torch.round(weights/scales.unsqueeze(1)).to(torch.int8)
 
44
  self.int8_weights = int8_weights
45
+ self.scales = scales
46
+ self.bias = None
47
 
48
  def forward(self, input):
49
  return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)
 
54
  any([x == name for x in module_name_to_exclude]):
55
  old_bias = child.bias
56
  old_weight = child.weight
 
57
  new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
58
  setattr(module, name, new_module)
 
59
  getattr(module, name).quantize(old_weight)
 
60
  if old_bias is not None:
61
  getattr(module, name).bias = old_bias
62
  else: