manbeast3b committed on
Commit
6005ab8
·
verified ·
1 Parent(s): afb21bf

Update src/pipeline.py

Browse files
Files changed (1) hide show
  1. src/pipeline.py +3 -2
src/pipeline.py CHANGED
@@ -50,7 +50,7 @@ class W8A16LinearLayer(nn.Module):
50
  def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
51
  # with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
52
  for name, child in module.named_children():
53
- if isinstance(child, nn.Linear) and ( 'to_q' in name or 'to_k' in name or 'to_v' in name ): #and not any([x == name for x in module_name_to_exclude]): 'linear' in name or
54
  old_bias = child.bias
55
  old_weight = child.weight
56
  new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
@@ -98,7 +98,8 @@ def load_pipeline() -> Pipeline:
98
  pipeline.text_encoder.to(memory_format=torch.channels_last)
99
  pipeline.transformer.to(memory_format=torch.channels_last)
100
  replace_linear_with_target_and_quantize(pipeline.transformer, W8A16LinearLayer, [])
101
- # exit()
 
102
 
103
  pipeline.vae.to(memory_format=torch.channels_last)
104
  pipeline.vae = torch.compile(pipeline.vae)
 
50
  def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
51
  # with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
52
  for name, child in module.named_children():
53
+ if isinstance(child, nn.Linear) and ( 'add_k_proj' in name or 'add_v_proj' in name or 'add_q_proj' in name ): #and not any([x == name for x in module_name_to_exclude]): 'linear' in name or
54
  old_bias = child.bias
55
  old_weight = child.weight
56
  new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
 
98
  pipeline.text_encoder.to(memory_format=torch.channels_last)
99
  pipeline.transformer.to(memory_format=torch.channels_last)
100
  replace_linear_with_target_and_quantize(pipeline.transformer, W8A16LinearLayer, [])
101
+ pipeline.transformer.save_pretrained("/root/.cache/huggingface/hub/transformer-flux")
102
+ exit()
103
 
104
  pipeline.vae.to(memory_format=torch.channels_last)
105
  pipeline.vae = torch.compile(pipeline.vae)