|
|
import random |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from diffusers import ( |
|
|
AutoPipelineForText2Image, |
|
|
AutoPipelineForImage2Image, |
|
|
AutoPipelineForInpainting, |
|
|
DiffusionPipeline, |
|
|
AutoencoderKL, |
|
|
FluxControlNetModel, |
|
|
FluxMultiControlNetModel, |
|
|
) |
|
|
from huggingface_hub import hf_hub_download |
|
|
from diffusers.schedulers import * |
|
|
from huggingface_hub import hf_hub_download |
|
|
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1 |
|
|
|
|
|
from .common_helpers import * |
|
|
|
|
|
|
|
|
def load_sd():
    """Load every model used by this module and return them as one tuple.

    Loads the FLUX.1-dev text-to-image pipeline, the FLUX and SDXL VAEs, the
    SDXL refiner pipeline and the FLUX union ControlNet. Everything is placed
    on ``"cuda"`` when CUDA is available, otherwise on ``"cpu"``.

    Returns:
        tuple: ``(device, models, flux_vae, sdxl_vae, refiner, controlnet)``.
        ``models`` is a list of dicts; each dict holds its loaded pipeline
        under the ``"pipeline"`` key next to ``"repo_id"``, ``"loader"`` and
        ``"compute_type"``.
    """
    # Imported here because the module-level ``from diffusers import (...)``
    # list does not name AutoencoderTiny; without it the first load attempt
    # below can only fail with a NameError.
    from diffusers import AutoencoderTiny

    device = "cuda" if torch.cuda.is_available() else "cpu"

    models = [
        {
            "repo_id": "black-forest-labs/FLUX.1-dev",
            "loader": "flux",
            "compute_type": torch.bfloat16,
        }
    ]

    for model in models:
        try:
            # First attempt: tiny autoencoder plus the "fp16" weight variant.
            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"],
                vae=AutoencoderTiny.from_pretrained(
                    "madebyollin/taef1", torch_dtype=torch.bfloat16
                ).to(device),
                torch_dtype=model["compute_type"],
                safety_checker=None,
                variant="fp16",
            ).to(device)
        except Exception:
            # Any load failure falls back to the full FLUX VAE without a
            # weight variant. ``Exception`` (not a bare ``except``) so that
            # Ctrl-C / SystemExit still abort the load.
            model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
                model["repo_id"],
                vae=AutoencoderKL.from_pretrained(
                    "black-forest-labs/FLUX.1-dev",
                    subfolder="vae",
                    torch_dtype=torch.bfloat16,
                ).to(device),
                torch_dtype=model["compute_type"],
                safety_checker=None,
            ).to(device)

        model["pipeline"].enable_model_cpu_offload()

    # Stand-alone VAEs: the FLUX one is returned for callers to attach to a
    # pipeline, the SDXL one is handed to the refiner below.
    flux_vae = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to(device)
    sdxl_vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to(device)

    refiner = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0",
        vae=sdxl_vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to(device)
    refiner.enable_model_cpu_offload()

    # A single union ControlNet wrapped in the multi-ControlNet container.
    controlnet = FluxMultiControlNetModel(
        [
            FluxControlNetModel.from_pretrained(
                "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
                torch_dtype=torch.bfloat16,
            ).to(device)
        ]
    )

    return device, models, flux_vae, sdxl_vae, refiner, controlnet
|
|
|
|
|
|
|
|
# Load all models once, at import time, and bind them to module-level names.
(
    device,
    models,
    flux_vae,
    sdxl_vae,
    refiner,
    controlnet,
) = load_sd()
|
|
|
|
|
|
|
|
def get_control_mode(controlnet_config: ControlNetReq):
    """Translate requested ControlNet names into integer mode ids.

    Each entry of ``controlnet_config.controlnets`` that matches one of the
    known names is replaced by that name's position in the list below, in
    request order. Entries that match no known name are skipped.

    Args:
        controlnet_config: Object exposing an iterable ``controlnets``.

    Returns:
        list[int]: One index per recognised entry.
    """
    known_modes = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]
    return [
        known_modes.index(requested)
        for requested in controlnet_config.controlnets
        if requested in known_modes
    ]
|
|
|
|
|
|
|
|
def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    """Build the pipeline arguments for ``request`` from the loaded models.

    Finds the loaded model whose ``repo_id`` equals ``request.model``,
    derives the task-specific pipeline from it with ``from_pipe``, then
    applies the request's VAE choice, a fresh scheduler and any LoRA
    adapters.

    Args:
        request: Generation request. Reads ``model``, ``controlnet_config``,
            ``vae``, ``loras`` and ``fast_generation``.

    Returns:
        dict | None: A dict whose ``"pipeline"`` entry is the configured
        pipeline (plus ``"control_mode"`` and ``"controlnet"`` when a
        ControlNet config is present), or ``None`` when no loaded model
        matches ``request.model``.
    """
    for m in models:
        if m["repo_id"] != request.model:
            continue

        pipe_args = {
            "pipeline": m["pipeline"],
        }

        if request.controlnet_config:
            pipe_args["control_mode"] = get_control_mode(request.controlnet_config)
            pipe_args["controlnet"] = [controlnet]

        # Test the most specific request types first, so the right task
        # pipeline is chosen even if the image request classes derive from
        # BaseReq. For unrelated classes the order makes no difference.
        if isinstance(request, BaseInpaintReq):
            pipe_args["pipeline"] = AutoPipelineForInpainting.from_pipe(**pipe_args)
        elif isinstance(request, BaseImg2ImgReq):
            pipe_args["pipeline"] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
        elif isinstance(request, BaseReq):
            pipe_args["pipeline"] = AutoPipelineForText2Image.from_pipe(**pipe_args)

        pipeline = pipe_args["pipeline"]

        # Attach the full FLUX VAE only when the request asks for it.
        pipeline.vae = flux_vae if request.vae else None

        pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
            pipeline.scheduler.config
        )

        # Start from empty lists so the fast-generation branch and the final
        # set_adapters call are well defined when no LoRA is requested.
        adapter_names = []
        adapter_weights = []

        for i, lora in enumerate(request.loras or []):
            adapter_name = f"lora_{i}"
            # Each entry carries its own repo id; read it from the loop item.
            pipeline.load_lora_weights(lora["repo_id"], adapter_name=adapter_name)
            adapter_names.append(adapter_name)
            adapter_weights.append(lora["weight"])

        if request.fast_generation:
            hyper_lora = hf_hub_download(
                "ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"
            )
            hyper_weight = 0.125
            pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
            adapter_names.append("hyper_lora")
            adapter_weights.append(hyper_weight)

        if adapter_names:
            pipeline.set_adapters(adapter_names, adapter_weights)

        return pipe_args

    # No loaded model has the requested repo id.
    return None
|
|
|
|
|
|
|
|
def get_prompt_attention(pipeline, prompt):
    """Encode ``prompt`` for ``pipeline`` with the weighted FLUX.1 embedder.

    Thin wrapper: the result of ``get_weighted_text_embeddings_flux1`` is
    returned unchanged. The caller in this file unpacks it as
    ``(prompt_embeds, pooled_prompt_embeds)``.
    """
    weighted_embeddings = get_weighted_text_embeddings_flux1(pipeline, prompt)
    return weighted_embeddings
|
|
|
|
|
|
|
|
|
|
|
def _make_generators(seed, count, target_device):
    """Return ``count`` seeded ``torch.Generator`` objects on ``target_device``.

    A seed of ``None``, ``0`` or ``-1`` means "no fixed seed": every generator
    then gets its own random 32-bit seed. Otherwise image ``i`` is seeded with
    ``seed + i``.
    """
    randomize = seed in (None, 0, -1)
    generators = []
    for offset in range(count):
        value = random.randint(0, 2**32 - 1) if randomize else seed + offset
        generators.append(torch.Generator(device=target_device).manual_seed(value))
    return generators


def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
    """Generate images for ``request`` and return them.

    Builds the pipeline with ``get_pipe``, encodes the prompt, assembles the
    call arguments (adding ControlNet, image-to-image and inpainting inputs
    when the request carries them), runs the pipeline and optionally passes
    the result through the refiner.

    Args:
        request: Generation request; img2img and inpaint requests also
            supply ``image``, ``strength`` and (inpaint) ``mask_image``.

    Returns:
        The ``.images`` of the pipeline output, or of the refiner output
        when ``request.refiner`` is set.

    Raises:
        gr.Error: If no loaded model matches ``request.model`` or anything
            fails during generation; the original exception is chained.
    """
    pipe_args = get_pipe(request)
    if pipe_args is None:
        # get_pipe found no loaded model whose repo id equals request.model.
        raise gr.Error(f"Error: model {request.model!r} is not loaded")

    pipeline = pipe_args["pipeline"]
    try:
        positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(
            pipeline, request.prompt
        )

        args = {
            "prompt_embeds": positive_prompt_embeds,
            "pooled_prompt_embeds": positive_prompt_pooled,
            "height": request.height,
            "width": request.width,
            "num_images_per_prompt": request.num_images_per_prompt,
            "num_inference_steps": request.num_inference_steps,
            "guidance_scale": request.guidance_scale,
            "generator": _make_generators(
                request.seed, request.num_images_per_prompt, device
            ),
        }

        if request.controlnet_config:
            args["control_mode"] = get_control_mode(request.controlnet_config)
            args["control_images"] = get_controlnet_images(
                request.controlnet_config,
                request.height,
                request.width,
                request.resize_mode,
            )
            args["controlnet_conditioning_scale"] = (
                request.controlnet_config.controlnet_conditioning_scale
            )

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args["image"] = resize_images(
                [request.image], request.height, request.width, request.resize_mode
            )[0]
            args["strength"] = request.strength

        if isinstance(request, BaseInpaintReq):
            args["mask_image"] = resize_images(
                [request.mask_image], request.height, request.width, request.resize_mode
            )[0]

        images = pipeline(**args).images

        if request.refiner:
            images = refiner(
                image=images,
                prompt=request.prompt,
                num_inference_steps=40,
                denoising_start=0.7,
            ).images

        return images
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
    finally:
        # Runs exactly once on both the success and the failure path.
        cleanup(pipeline, request.loras)
|
|
|