Spaces:
Running on Zero
Running on Zero
| # functions to load models from config | |
| import numpy as np | |
| import torch | |
| import re | |
| import os | |
| import math | |
| from diffusers import ( | |
| BitsAndBytesConfig | |
| ) | |
| from diffusers import AutoencoderTiny | |
| from src.flair.pipelines import sd3 | |
| import os | |
| from huggingface_hub import login | |
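| # Explicitly authenticate with the Hugging Face Hub at import time; the `hf_token` environment variable must be set, otherwise this raises KeyError. | |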
| login(token=os.environ["hf_token"]) | |
def load_sd3(config, device):
    """Load an SD3.5 pipeline and pre-compute prompt embeddings from ``config``.

    Args:
        config: Mapping read for the keys ``"quantize"``, ``"model"``,
            ``"use_tiny_ae"``, ``"prompt"``, ``"negative_prompt"`` and
            ``"guidance"``. ``"prompt"`` is an iterable of prompts; a single
            string is also accepted and treated as one prompt.
        device: Target device, or a list of devices of which only the first
            entry is used.

    Returns:
        A tuple ``(pipe, inp_kwargs_list)``. ``inp_kwargs_list`` holds one
        dict per prompt with the keys ``"prompt_embeds"``,
        ``"pooled_prompt_embeds"`` and ``"guidance"``.
    """
    if isinstance(device, list):
        device = device[0]

    # Optional 4-bit NF4 quantization with bfloat16 compute.
    quant_config = None
    if config["quantize"]:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    # The two large checkpoints are loaded in bfloat16; every other model
    # name falls back to SD3.5-medium, which is loaded in float16.
    large_repos = {
        "SD3.5-large": "stabilityai/stable-diffusion-3.5-large",
        "SD3.5-large-turbo": "stabilityai/stable-diffusion-3.5-large-turbo",
    }
    repo_id = large_repos.get(config["model"])
    if repo_id is not None:
        load_dtype = torch.bfloat16
    else:
        repo_id = "stabilityai/stable-diffusion-3.5-medium"
        load_dtype = torch.float16
    pipe = sd3.SD3Wrapper.from_pretrained(
        repo_id, torch_dtype=load_dtype, quantization_config=quant_config
    )

    # maybe use tiny autoencoder
    if config["use_tiny_ae"]:
        pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)

    # A bare string would otherwise be iterated character by character below.
    prompts = config["prompt"]
    if isinstance(prompts, str):
        prompts = [prompts]

    # Set before encoding because ``pipe.do_classifier_free_guidance`` is
    # queried below (presumably derived from ``_guidance_scale`` -- confirm
    # against SD3Wrapper).
    pipe._guidance_scale = config["guidance"]
    pipe._joint_attention_kwargs = {"ip_adapter_image_embeds": None}

    # Moving the text encoders is loop-invariant, so do it once up front
    # instead of once per prompt.
    pipe.text_encoder.to(device).to(torch.bfloat16)
    pipe.text_encoder_2.to(device).to(torch.bfloat16)
    pipe.text_encoder_3.to(device).to(torch.bfloat16)

    # encode prompts
    inp_kwargs_list = []
    for prompt in prompts:
        print(f"Generating prompt embeddings for: {prompt}")
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            negative_prompt=config["negative_prompt"],
            negative_prompt_2=config["negative_prompt"],
            negative_prompt_3=config["negative_prompt"],
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            device=device,
            clip_skip=None,
            num_images_per_prompt=1,
            max_sequence_length=256,
            lora_scale=None,
        )
        if pipe.do_classifier_free_guidance:
            # Negative embeddings go first, positive ones second, along dim 0.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat(
                [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
            )
        inp_kwargs_list.append(
            {
                "prompt_embeds": prompt_embeds,
                "pooled_prompt_embeds": pooled_prompt_embeds,
                "guidance": config["guidance"],
            }
        )

    # The VAE and transformer are moved to the device and cast to bfloat16
    # regardless of the dtype they were loaded in.
    pipe.vae.to(device).to(torch.bfloat16)
    pipe.transformer.to(device).to(torch.bfloat16)
    return pipe, inp_kwargs_list
def load_model(config, device=None):
    """Dispatch to the loader matching ``config["model"]``.

    Args:
        config: Mapping with at least a ``"model"`` key; it is passed through
            unchanged to the selected loader.
        device: Device or list of devices forwarded to the loader. Defaults
            to ``["cuda"]``.

    Returns:
        Whatever ``load_sd3`` returns for model names starting with ``"SD3"``.

    Raises:
        ValueError: If the model name does not start with ``"SD3"``.
    """
    # Build the default per call rather than sharing one mutable list object.
    if device is None:
        device = ["cuda"]

    # The previous pattern r"SD3*" meant "SD" followed by zero or more "3",
    # so names such as "SDXL" were wrongly accepted as SD3 models.
    if config["model"].startswith("SD3"):
        return load_sd3(config, device)
    raise ValueError(f"Unknown model type {config['model']}")