Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -786,66 +786,38 @@ def load_illustrious_xl(
|
|
| 786 |
filename: str = "illustriousXL_v01.safetensors",
|
| 787 |
device: str = "cuda"
|
| 788 |
) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
|
| 789 |
-
"""Load Illustrious XL from single safetensors file."""
|
|
|
|
| 790 |
|
| 791 |
-
print(f"π₯
|
| 792 |
|
|
|
|
| 793 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
|
| 794 |
print(f"β Downloaded: {checkpoint_path}")
|
| 795 |
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
# Load UNet from SDXL base config, then load weights
|
| 803 |
-
print("ποΈ Initializing SDXL UNet...")
|
| 804 |
-
unet = UNet2DConditionModel.from_pretrained(
|
| 805 |
-
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 806 |
-
subfolder="unet",
|
| 807 |
-
torch_dtype=torch.float16
|
| 808 |
)
|
| 809 |
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
"stabilityai/stable-diffusion-xl-base-1.0",
|
| 818 |
-
subfolder="vae",
|
| 819 |
-
torch_dtype=torch.float16
|
| 820 |
-
)
|
| 821 |
-
|
| 822 |
-
if components["vae"]:
|
| 823 |
-
missing, unexpected = vae.load_state_dict(components["vae"], strict=False)
|
| 824 |
-
print(f" VAE: {len(missing)} missing, {len(unexpected)} unexpected keys")
|
| 825 |
-
|
| 826 |
-
# Load CLIP-L
|
| 827 |
-
print("ποΈ Loading CLIP-L...")
|
| 828 |
-
text_encoder = CLIPTextModel.from_pretrained(
|
| 829 |
-
"openai/clip-vit-large-patch14",
|
| 830 |
-
torch_dtype=torch.float16
|
| 831 |
-
)
|
| 832 |
-
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
| 833 |
-
|
| 834 |
-
# Load CLIP-G
|
| 835 |
-
print("ποΈ Loading CLIP-G...")
|
| 836 |
-
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
|
| 837 |
-
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
| 838 |
-
torch_dtype=torch.float16
|
| 839 |
-
)
|
| 840 |
-
tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
| 841 |
|
| 842 |
-
#
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
text_encoder = text_encoder.to(device)
|
| 846 |
-
text_encoder_2 = text_encoder_2.to(device)
|
| 847 |
|
| 848 |
print("β
Illustrious XL loaded!")
|
|
|
|
|
|
|
| 849 |
|
| 850 |
return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
|
| 851 |
|
|
|
|
| 786 |
filename: str = "illustriousXL_v01.safetensors",
|
| 787 |
device: str = "cuda"
|
| 788 |
) -> Tuple[UNet2DConditionModel, AutoencoderKL, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPTokenizer]:
|
| 789 |
+
"""Load Illustrious XL from single safetensors file using diffusers' single-file loader."""
|
| 790 |
+
from diffusers import StableDiffusionXLPipeline
|
| 791 |
|
| 792 |
+
print(f"π₯ Loading Illustrious XL: {repo_id}/{filename}")
|
| 793 |
|
| 794 |
+
# Download the checkpoint
|
| 795 |
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="model")
|
| 796 |
print(f"β Downloaded: {checkpoint_path}")
|
| 797 |
|
| 798 |
+
# Use diffusers' built-in single-file loader which handles key remapping
|
| 799 |
+
print("π¦ Loading with StableDiffusionXLPipeline.from_single_file()...")
|
| 800 |
+
pipe = StableDiffusionXLPipeline.from_single_file(
|
| 801 |
+
checkpoint_path,
|
| 802 |
+
torch_dtype=torch.float16,
|
| 803 |
+
use_safetensors=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
)
|
| 805 |
|
| 806 |
+
# Extract components
|
| 807 |
+
unet = pipe.unet.to(device)
|
| 808 |
+
vae = pipe.vae.to(device)
|
| 809 |
+
text_encoder = pipe.text_encoder.to(device)
|
| 810 |
+
text_encoder_2 = pipe.text_encoder_2.to(device)
|
| 811 |
+
tokenizer = pipe.tokenizer
|
| 812 |
+
tokenizer_2 = pipe.tokenizer_2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
|
| 814 |
+
# Clean up the pipeline to free memory
|
| 815 |
+
del pipe
|
| 816 |
+
torch.cuda.empty_cache()
|
|
|
|
|
|
|
| 817 |
|
| 818 |
print("β
Illustrious XL loaded!")
|
| 819 |
+
print(f" UNet params: {sum(p.numel() for p in unet.parameters()):,}")
|
| 820 |
+
print(f" VAE params: {sum(p.numel() for p in vae.parameters()):,}")
|
| 821 |
|
| 822 |
return unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2
|
| 823 |
|