Upload convert_url_to_diffusers_flux_gr.py
Browse files
convert_url_to_diffusers_flux_gr.py
CHANGED
|
@@ -10,7 +10,14 @@ import os
|
|
| 10 |
import argparse
|
| 11 |
import gradio as gr
|
| 12 |
# also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
import spaces
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
flux_dev_repo = "ChuckMcSneed/FLUX.1-dev"
|
| 16 |
flux_schnell_repo = "black-forest-labs/FLUX.1-schnell"
|
|
@@ -38,6 +45,24 @@ def is_repo_name(s):
|
|
| 38 |
import re
|
| 39 |
return re.fullmatch(r'^[^/,\s]+?/[^/,\s]+?$', s)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
def print_resource_usage():
|
| 42 |
import psutil
|
| 43 |
cpu_usage = psutil.cpu_percent()
|
|
@@ -363,7 +388,7 @@ def read_safetensors_metadata(path):
|
|
| 363 |
|
| 364 |
def normalize_key(k: str):
|
| 365 |
return k.replace("vae.", "").replace("model.diffusion_model.", "")\
|
| 366 |
-
.replace("text_encoders.clip_l.transformer.
|
| 367 |
.replace("text_encoders.t5xxl.transformer.", "")
|
| 368 |
|
| 369 |
def load_json_list(path: str):
|
|
@@ -465,9 +490,7 @@ with torch.no_grad():
|
|
| 465 |
print(e)
|
| 466 |
return
|
| 467 |
finally:
|
| 468 |
-
|
| 469 |
-
torch.cuda.empty_cache()
|
| 470 |
-
gc.collect()
|
| 471 |
new_path = str(Path(savepath, Path(path).stem + "_fixed" + Path(path).suffix))
|
| 472 |
metadata = read_safetensors_metadata(path)
|
| 473 |
progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}")
|
|
@@ -476,9 +499,7 @@ with torch.no_grad():
|
|
| 476 |
save_file(new_sd, new_path, metadata={"format": "pt", **metadata})
|
| 477 |
progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}")
|
| 478 |
print(f"Saved FLUX.1 safetensors: {new_path}")
|
| 479 |
-
|
| 480 |
-
torch.cuda.empty_cache()
|
| 481 |
-
gc.collect()
|
| 482 |
|
| 483 |
with torch.no_grad():
|
| 484 |
def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16,
|
|
@@ -506,9 +527,7 @@ with torch.no_grad():
|
|
| 506 |
finally:
|
| 507 |
progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}")
|
| 508 |
print(f"Normalized FLUX.1 {name} safetensors: {path}")
|
| 509 |
-
|
| 510 |
-
torch.cuda.empty_cache()
|
| 511 |
-
gc.collect()
|
| 512 |
return new_sd
|
| 513 |
|
| 514 |
with torch.no_grad():
|
|
@@ -541,9 +560,7 @@ with torch.no_grad():
|
|
| 541 |
for k, v in sharded_sd.items():
|
| 542 |
sharded_sd[k] = v.to(device="cpu")
|
| 543 |
sd = sd | sharded_sd.copy()
|
| 544 |
-
|
| 545 |
-
torch.cuda.empty_cache()
|
| 546 |
-
gc.collect()
|
| 547 |
except Exception as e:
|
| 548 |
print(e)
|
| 549 |
return sd
|
|
@@ -561,9 +578,7 @@ with torch.no_grad():
|
|
| 561 |
for k, v in sd.items():
|
| 562 |
if k in set(keys_flux_transformer): sd[k] = v.to(device="cpu")
|
| 563 |
save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
|
| 564 |
-
|
| 565 |
-
torch.cuda.empty_cache()
|
| 566 |
-
gc.collect()
|
| 567 |
progress(0.25, desc=f"Saved temporary files to disk: {path}")
|
| 568 |
print(f"Saved temporary files to disk: {path}")
|
| 569 |
for filepath in glob.glob(f"{path}/*.safetensors"):
|
|
@@ -574,9 +589,7 @@ with torch.no_grad():
|
|
| 574 |
for k, v in sharded_sd.items():
|
| 575 |
sharded_sd[k] = v.to(device="cpu")
|
| 576 |
save_file(sharded_sd, str(filepath))
|
| 577 |
-
|
| 578 |
-
torch.cuda.empty_cache()
|
| 579 |
-
gc.collect()
|
| 580 |
print(f"Loading temporary files from disk: {path}")
|
| 581 |
sd = load_sharded_safetensors(path)
|
| 582 |
print(f"Loaded temporary files from disk: {path}")
|
|
@@ -599,9 +612,7 @@ with torch.no_grad():
|
|
| 599 |
for k, v in sd.items():
|
| 600 |
sd[k] = v.to(device="cpu")
|
| 601 |
save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
|
| 602 |
-
|
| 603 |
-
torch.cuda.empty_cache()
|
| 604 |
-
gc.collect()
|
| 605 |
progress(0.25, desc=f"Saved temporary files to disk: {path}")
|
| 606 |
print(f"Saved temporary files to disk: {path}")
|
| 607 |
for filepath in glob.glob(f"{path}/*.safetensors"):
|
|
@@ -612,9 +623,7 @@ with torch.no_grad():
|
|
| 612 |
for k, v in sharded_sd.items():
|
| 613 |
sharded_sd[k] = v.to(device="cpu")
|
| 614 |
save_file(sharded_sd, str(filepath))
|
| 615 |
-
|
| 616 |
-
torch.cuda.empty_cache()
|
| 617 |
-
gc.collect()
|
| 618 |
print(f"Processed temporary files: {str(filepath)}")
|
| 619 |
print(f"Loading temporary files from disk: {path}")
|
| 620 |
sd = load_sharded_safetensors(path)
|
|
@@ -678,8 +687,7 @@ with torch.no_grad():
|
|
| 678 |
quantization: bool = False, model_type: str = "dev", dequant: bool = False):
|
| 679 |
save_flux_other_diffusers(savepath, model_type)
|
| 680 |
normalize_flux_state_dict(loadpath, savepath, dtype, dequant)
|
| 681 |
-
|
| 682 |
-
gc.collect()
|
| 683 |
|
| 684 |
with torch.no_grad(): # Much lower memory consumption, but higher disk load
|
| 685 |
def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
|
|
@@ -698,40 +706,46 @@ with torch.no_grad(): # Much lower memory consumption, but higher disk load
|
|
| 698 |
vae_sd_path = savepath.removesuffix("/") + "/vae"
|
| 699 |
vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
|
| 700 |
vae_sd_size = "10GB"
|
|
|
|
| 701 |
metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
|
|
|
|
|
|
|
| 702 |
if "vae" not in use_original:
|
| 703 |
vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
|
| 704 |
keys_flux_vae)
|
| 705 |
to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
|
| 706 |
quantization, "VAE", None)
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
gc.collect()
|
| 710 |
if "text_encoder" not in use_original:
|
| 711 |
clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
|
| 712 |
keys_flux_clip)
|
| 713 |
to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
|
| 714 |
quantization, "Text Encoder", None)
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
gc.collect()
|
| 718 |
if "text_encoder_2" not in use_original:
|
| 719 |
te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
|
| 720 |
keys_flux_t5xxl)
|
| 721 |
to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
|
| 722 |
quantization, "Text Encoder 2", None)
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
gc.collect()
|
| 726 |
unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer",
|
| 727 |
keys_flux_transformer)
|
| 728 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
|
| 730 |
quantization, "Transformer", metadata)
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
gc.collect()
|
| 734 |
save_flux_other_diffusers(savepath, model_type, use_original)
|
|
|
|
| 735 |
|
| 736 |
with torch.no_grad(): # lowest memory consumption, but higheest disk load
|
| 737 |
def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
|
|
@@ -752,47 +766,52 @@ with torch.no_grad(): # lowest memory consumption, but higheest disk load
|
|
| 752 |
vae_sd_path = savepath.removesuffix("/") + "/vae"
|
| 753 |
vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
|
| 754 |
vae_sd_size = "10GB"
|
|
|
|
| 755 |
metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
|
|
|
|
|
|
|
| 756 |
if "vae" not in use_original:
|
| 757 |
vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
|
| 758 |
keys_flux_vae)
|
| 759 |
to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
|
| 760 |
quantization, "VAE", None)
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
gc.collect()
|
| 764 |
if "text_encoder" not in use_original:
|
| 765 |
clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
|
| 766 |
keys_flux_clip)
|
| 767 |
to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
|
| 768 |
quantization, "Text Encoder", None)
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
gc.collect()
|
| 772 |
if "text_encoder_2" not in use_original:
|
| 773 |
te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
|
| 774 |
keys_flux_t5xxl)
|
| 775 |
to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
|
| 776 |
quantization, "Text Encoder 2", None)
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
gc.collect()
|
| 780 |
unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant,
|
| 781 |
unet_temp_path, unet_sd_pattern, unet_temp_size)
|
|
|
|
|
|
|
| 782 |
unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path,
|
| 783 |
unet_sd_pattern, unet_temp_size)
|
|
|
|
|
|
|
| 784 |
to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
|
| 785 |
quantization, "Transformer", metadata)
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
gc.collect()
|
| 789 |
save_flux_other_diffusers(savepath, model_type, use_original)
|
|
|
|
| 790 |
|
| 791 |
def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16",
|
| 792 |
model_type="dev", dequant=False, use_original=["vae", "text_encoder"],
|
| 793 |
hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)):
|
| 794 |
progress(0, desc="Start converting...")
|
| 795 |
temp_dir = "."
|
|
|
|
| 796 |
new_file = get_download_file(temp_dir, url, civitai_key)
|
| 797 |
if not new_file:
|
| 798 |
print(f"Not found: {url}")
|
|
@@ -825,6 +844,7 @@ def convert_url_to_fixed_flux_safetensors(url, civitai_key="", is_upload_sf=Fals
|
|
| 825 |
model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)):
|
| 826 |
progress(0, desc="Start converting...")
|
| 827 |
temp_dir = "."
|
|
|
|
| 828 |
new_file = get_download_file(temp_dir, url, civitai_key)
|
| 829 |
if not new_file:
|
| 830 |
print(f"Not found: {url}")
|
|
|
|
| 10 |
import argparse
|
| 11 |
import gradio as gr
|
| 12 |
# also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
|
| 13 |
+
|
| 14 |
+
import subprocess
|
| 15 |
+
subprocess.run('pip cache purge', shell=True)
|
| 16 |
+
|
| 17 |
import spaces
|
| 18 |
+
@spaces.GPU()
|
| 19 |
+
def spaces_dummy():
|
| 20 |
+
pass
|
| 21 |
|
| 22 |
flux_dev_repo = "ChuckMcSneed/FLUX.1-dev"
|
| 23 |
flux_schnell_repo = "black-forest-labs/FLUX.1-schnell"
|
|
|
|
| 45 |
import re
|
| 46 |
return re.fullmatch(r'^[^/,\s]+?/[^/,\s]+?$', s)
|
| 47 |
|
| 48 |
+
def clear_cache():
|
| 49 |
+
torch.cuda.empty_cache()
|
| 50 |
+
gc.collect()
|
| 51 |
+
|
| 52 |
+
def clear_sd(sd: dict):
|
| 53 |
+
for k in list(sd.keys()):
|
| 54 |
+
sd.pop(k)
|
| 55 |
+
del sd
|
| 56 |
+
torch.cuda.empty_cache()
|
| 57 |
+
gc.collect()
|
| 58 |
+
|
| 59 |
+
def clone_sd(sd: dict):
|
| 60 |
+
print("Cloning state dict.")
|
| 61 |
+
for k in list(sd.keys()):
|
| 62 |
+
sd[k] = sd.pop(k).detach().clone()
|
| 63 |
+
torch.cuda.empty_cache()
|
| 64 |
+
gc.collect()
|
| 65 |
+
|
| 66 |
def print_resource_usage():
|
| 67 |
import psutil
|
| 68 |
cpu_usage = psutil.cpu_percent()
|
|
|
|
| 388 |
|
| 389 |
def normalize_key(k: str):
|
| 390 |
return k.replace("vae.", "").replace("model.diffusion_model.", "")\
|
| 391 |
+
.replace("text_encoders.clip_l.transformer.", "")\
|
| 392 |
.replace("text_encoders.t5xxl.transformer.", "")
|
| 393 |
|
| 394 |
def load_json_list(path: str):
|
|
|
|
| 490 |
print(e)
|
| 491 |
return
|
| 492 |
finally:
|
| 493 |
+
clear_sd(state_dict)
|
|
|
|
|
|
|
| 494 |
new_path = str(Path(savepath, Path(path).stem + "_fixed" + Path(path).suffix))
|
| 495 |
metadata = read_safetensors_metadata(path)
|
| 496 |
progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}")
|
|
|
|
| 499 |
save_file(new_sd, new_path, metadata={"format": "pt", **metadata})
|
| 500 |
progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}")
|
| 501 |
print(f"Saved FLUX.1 safetensors: {new_path}")
|
| 502 |
+
clear_sd(new_sd)
|
|
|
|
|
|
|
| 503 |
|
| 504 |
with torch.no_grad():
|
| 505 |
def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16,
|
|
|
|
| 527 |
finally:
|
| 528 |
progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}")
|
| 529 |
print(f"Normalized FLUX.1 {name} safetensors: {path}")
|
| 530 |
+
clear_sd(state_dict)
|
|
|
|
|
|
|
| 531 |
return new_sd
|
| 532 |
|
| 533 |
with torch.no_grad():
|
|
|
|
| 560 |
for k, v in sharded_sd.items():
|
| 561 |
sharded_sd[k] = v.to(device="cpu")
|
| 562 |
sd = sd | sharded_sd.copy()
|
| 563 |
+
clear_sd(sharded_sd)
|
|
|
|
|
|
|
| 564 |
except Exception as e:
|
| 565 |
print(e)
|
| 566 |
return sd
|
|
|
|
| 578 |
for k, v in sd.items():
|
| 579 |
if k in set(keys_flux_transformer): sd[k] = v.to(device="cpu")
|
| 580 |
save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
|
| 581 |
+
clear_sd(sd)
|
|
|
|
|
|
|
| 582 |
progress(0.25, desc=f"Saved temporary files to disk: {path}")
|
| 583 |
print(f"Saved temporary files to disk: {path}")
|
| 584 |
for filepath in glob.glob(f"{path}/*.safetensors"):
|
|
|
|
| 589 |
for k, v in sharded_sd.items():
|
| 590 |
sharded_sd[k] = v.to(device="cpu")
|
| 591 |
save_file(sharded_sd, str(filepath))
|
| 592 |
+
clear_sd(sharded_sd)
|
|
|
|
|
|
|
| 593 |
print(f"Loading temporary files from disk: {path}")
|
| 594 |
sd = load_sharded_safetensors(path)
|
| 595 |
print(f"Loaded temporary files from disk: {path}")
|
|
|
|
| 612 |
for k, v in sd.items():
|
| 613 |
sd[k] = v.to(device="cpu")
|
| 614 |
save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
|
| 615 |
+
clear_sd(sd)
|
|
|
|
|
|
|
| 616 |
progress(0.25, desc=f"Saved temporary files to disk: {path}")
|
| 617 |
print(f"Saved temporary files to disk: {path}")
|
| 618 |
for filepath in glob.glob(f"{path}/*.safetensors"):
|
|
|
|
| 623 |
for k, v in sharded_sd.items():
|
| 624 |
sharded_sd[k] = v.to(device="cpu")
|
| 625 |
save_file(sharded_sd, str(filepath))
|
| 626 |
+
clear_sd(sharded_sd)
|
|
|
|
|
|
|
| 627 |
print(f"Processed temporary files: {str(filepath)}")
|
| 628 |
print(f"Loading temporary files from disk: {path}")
|
| 629 |
sd = load_sharded_safetensors(path)
|
|
|
|
| 687 |
quantization: bool = False, model_type: str = "dev", dequant: bool = False):
|
| 688 |
save_flux_other_diffusers(savepath, model_type)
|
| 689 |
normalize_flux_state_dict(loadpath, savepath, dtype, dequant)
|
| 690 |
+
clear_cache()
|
|
|
|
| 691 |
|
| 692 |
with torch.no_grad(): # Much lower memory consumption, but higher disk load
|
| 693 |
def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
|
|
|
|
| 706 |
vae_sd_path = savepath.removesuffix("/") + "/vae"
|
| 707 |
vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
|
| 708 |
vae_sd_size = "10GB"
|
| 709 |
+
print_resource_usage() #
|
| 710 |
metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
|
| 711 |
+
clear_cache()
|
| 712 |
+
print_resource_usage() #
|
| 713 |
if "vae" not in use_original:
|
| 714 |
vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
|
| 715 |
keys_flux_vae)
|
| 716 |
to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
|
| 717 |
quantization, "VAE", None)
|
| 718 |
+
clear_sd(vae_sd)
|
| 719 |
+
print_resource_usage() #
|
|
|
|
| 720 |
if "text_encoder" not in use_original:
|
| 721 |
clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
|
| 722 |
keys_flux_clip)
|
| 723 |
to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
|
| 724 |
quantization, "Text Encoder", None)
|
| 725 |
+
clear_sd(clip_sd)
|
| 726 |
+
print_resource_usage() #
|
|
|
|
| 727 |
if "text_encoder_2" not in use_original:
|
| 728 |
te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
|
| 729 |
keys_flux_t5xxl)
|
| 730 |
to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
|
| 731 |
quantization, "Text Encoder 2", None)
|
| 732 |
+
clear_sd(te_sd)
|
| 733 |
+
print_resource_usage() #
|
|
|
|
| 734 |
unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer",
|
| 735 |
keys_flux_transformer)
|
| 736 |
+
clear_cache()
|
| 737 |
+
print_resource_usage() #
|
| 738 |
+
if not local:
|
| 739 |
+
os.remove(loadpath)
|
| 740 |
+
print("Deleted downloaded file.")
|
| 741 |
+
clear_cache()
|
| 742 |
+
print_resource_usage() #
|
| 743 |
to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
|
| 744 |
quantization, "Transformer", metadata)
|
| 745 |
+
clear_sd(unet_sd)
|
| 746 |
+
print_resource_usage() #
|
|
|
|
| 747 |
save_flux_other_diffusers(savepath, model_type, use_original)
|
| 748 |
+
print_resource_usage() #
|
| 749 |
|
| 750 |
with torch.no_grad(): # lowest memory consumption, but higheest disk load
|
| 751 |
def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
|
|
|
|
| 766 |
vae_sd_path = savepath.removesuffix("/") + "/vae"
|
| 767 |
vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
|
| 768 |
vae_sd_size = "10GB"
|
| 769 |
+
print_resource_usage() #
|
| 770 |
metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
|
| 771 |
+
clear_cache()
|
| 772 |
+
print_resource_usage() #
|
| 773 |
if "vae" not in use_original:
|
| 774 |
vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
|
| 775 |
keys_flux_vae)
|
| 776 |
to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
|
| 777 |
quantization, "VAE", None)
|
| 778 |
+
clear_sd(vae_sd)
|
| 779 |
+
print_resource_usage() #
|
|
|
|
| 780 |
if "text_encoder" not in use_original:
|
| 781 |
clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
|
| 782 |
keys_flux_clip)
|
| 783 |
to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
|
| 784 |
quantization, "Text Encoder", None)
|
| 785 |
+
clear_sd(clip_sd)
|
| 786 |
+
print_resource_usage() #
|
|
|
|
| 787 |
if "text_encoder_2" not in use_original:
|
| 788 |
te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
|
| 789 |
keys_flux_t5xxl)
|
| 790 |
to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
|
| 791 |
quantization, "Text Encoder 2", None)
|
| 792 |
+
clear_sd(te_sd)
|
| 793 |
+
print_resource_usage() #
|
|
|
|
| 794 |
unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant,
|
| 795 |
unet_temp_path, unet_sd_pattern, unet_temp_size)
|
| 796 |
+
clear_cache()
|
| 797 |
+
print_resource_usage() #
|
| 798 |
unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path,
|
| 799 |
unet_sd_pattern, unet_temp_size)
|
| 800 |
+
clear_cache()
|
| 801 |
+
print_resource_usage() #
|
| 802 |
to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
|
| 803 |
quantization, "Transformer", metadata)
|
| 804 |
+
clear_sd(unet_sd)
|
| 805 |
+
print_resource_usage() #
|
|
|
|
| 806 |
save_flux_other_diffusers(savepath, model_type, use_original)
|
| 807 |
+
print_resource_usage() #
|
| 808 |
|
| 809 |
def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16",
|
| 810 |
model_type="dev", dequant=False, use_original=["vae", "text_encoder"],
|
| 811 |
hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)):
|
| 812 |
progress(0, desc="Start converting...")
|
| 813 |
temp_dir = "."
|
| 814 |
+
print_resource_usage() #
|
| 815 |
new_file = get_download_file(temp_dir, url, civitai_key)
|
| 816 |
if not new_file:
|
| 817 |
print(f"Not found: {url}")
|
|
|
|
| 844 |
model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)):
|
| 845 |
progress(0, desc="Start converting...")
|
| 846 |
temp_dir = "."
|
| 847 |
+
print_resource_usage() #
|
| 848 |
new_file = get_download_file(temp_dir, url, civitai_key)
|
| 849 |
if not new_file:
|
| 850 |
print(f"Not found: {url}")
|