Update src/pipeline.py
Browse files- src/pipeline.py +3 -2
src/pipeline.py
CHANGED
|
@@ -50,7 +50,7 @@ class W8A16LinearLayer(nn.Module):
|
|
| 50 |
def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
|
| 51 |
# with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
|
| 52 |
for name, child in module.named_children():
|
| 53 |
-
if isinstance(child, nn.Linear) and (
|
| 54 |
old_bias = child.bias
|
| 55 |
old_weight = child.weight
|
| 56 |
new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
|
|
@@ -98,7 +98,8 @@ def load_pipeline() -> Pipeline:
|
|
| 98 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 99 |
pipeline.transformer.to(memory_format=torch.channels_last)
|
| 100 |
replace_linear_with_target_and_quantize(pipeline.transformer, W8A16LinearLayer, [])
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
pipeline.vae.to(memory_format=torch.channels_last)
|
| 104 |
pipeline.vae = torch.compile(pipeline.vae)
|
|
|
|
| 50 |
def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
|
| 51 |
# with open("/root/.cache/huggingface/hub/output_layers.txt", "a") as f:
|
| 52 |
for name, child in module.named_children():
|
| 53 |
+
if isinstance(child, nn.Linear) and ( 'add_k_proj' in name or 'add_v_proj' in name or 'add_q_proj' in name ): #and not any([x == name for x in module_name_to_exclude]): 'linear' in name or
|
| 54 |
old_bias = child.bias
|
| 55 |
old_weight = child.weight
|
| 56 |
new_module = target_class(child.in_features, child.out_features, old_bias is not None, child.weight.dtype)
|
|
|
|
| 98 |
pipeline.text_encoder.to(memory_format=torch.channels_last)
|
| 99 |
pipeline.transformer.to(memory_format=torch.channels_last)
|
| 100 |
replace_linear_with_target_and_quantize(pipeline.transformer, W8A16LinearLayer, [])
|
| 101 |
+
pipeline.transformer.save_pretrained("/root/.cache/huggingface/hub/transformer-flux")
|
| 102 |
+
exit()
|
| 103 |
|
| 104 |
pipeline.vae.to(memory_format=torch.channels_last)
|
| 105 |
pipeline.vae = torch.compile(pipeline.vae)
|