manbeast3b committed on
Commit
0b842dd
·
verified ·
1 Parent(s): 3fc2f94

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +24 -1
src/pipeline.py CHANGED
@@ -71,9 +71,32 @@ class W8A16LinearLayer(nn.Module):
71
  output = output + self.bias
72
  return output
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
75
  # with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
76
- for name, child in module.named_children():
 
 
77
  if isinstance(child, nn.Linear) and ( 'add_k_proj' in name or 'add_v_proj' in name or 'add_q_proj' in name ): #and not any([x == name for x in module_name_to_exclude]): 'linear' in name or
78
  old_bias = child.bias
79
  old_weight = child.weight
 
71
  output = output + self.bias
72
  return output
73
 
74
+ # def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
75
+ # # with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
76
+ # for name, child in module.named_children():
77
+ # if isinstance(child, nn.Linear) and ( 'add_k_proj' in name or 'add_v_proj' in name or 'add_q_proj' in name ): #and not any([x == name for x in module_name_to_exclude]): 'linear' in name or
78
+ # old_bias = child.bias
79
+ # old_weight = child.weight
80
+ # new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
81
+ # new_module.quantize(old_weight)
82
+ # delattr(module, name)
83
+ # setattr(module, name, new_module)
84
+ # if old_bias is not None:
85
+ # getattr(module, name).bias = old_bias
86
+
87
+ # # # Print the replaced layer name and calculate the change in size
88
+ # # old_size = old_weight.numel() * old_weight.element_size()
89
+ # # new_size = new_module.int8_weights.numel() * new_module.int8_weights.element_size()
90
+ # # f.write(f"Replaced layer: {name}" + f" Size reduction: {old_size} bytes -> {new_size} bytes ({(old_size - new_size) / old_size * 100:.2f}% reduction)")
91
+ # else:
92
+ # # Recursively call the function for nested modules
93
+ # replace_linear_with_target_and_quantize(child, target_class, module_name_to_exclude)
94
+
95
  def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
96
  # with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
97
+
98
+ for name in list(module._modules.keys()):
99
+ child = module._modules[name]
100
  if isinstance(child, nn.Linear) and ( 'add_k_proj' in name or 'add_v_proj' in name or 'add_q_proj' in name ): #and not any([x == name for x in module_name_to_exclude]): 'linear' in name or
101
  old_bias = child.bias
102
  old_weight = child.weight