Commit
·
daf9c75
1
Parent(s):
07dc8e6
Refactor ControlNetReq class to remove unused import and add controlnets, control_images, and controlnet_conditioning_scale attributes
Browse files- modules/events/flux_events.py +9 -8
- modules/helpers/flux_helpers.py +2 -54
- modules/pipelines/common_pipelines.py +19 -0
- modules/pipelines/flux_pipelines.py +51 -12
- tabs/image_tab.py +2 -2
modules/events/flux_events.py
CHANGED
|
@@ -5,14 +5,15 @@ import spaces
|
|
| 5 |
import gradio as gr
|
| 6 |
from huggingface_hub import ModelCard
|
| 7 |
|
| 8 |
-
from modules.helpers.
|
| 9 |
-
from
|
|
|
|
| 10 |
|
| 11 |
loras = flux_loras
|
| 12 |
|
| 13 |
|
| 14 |
# Event functions
|
| 15 |
-
def update_fast_generation(
|
| 16 |
if fast_generation:
|
| 17 |
return (
|
| 18 |
gr.update(
|
|
@@ -125,7 +126,7 @@ def update_selected_lora(custom_lora):
|
|
| 125 |
)
|
| 126 |
|
| 127 |
|
| 128 |
-
def add_to_enabled_loras(
|
| 129 |
lora_data = loras
|
| 130 |
try:
|
| 131 |
selected_lora = int(selected_lora)
|
|
@@ -233,7 +234,7 @@ def generate_image(
|
|
| 233 |
"vae": vae,
|
| 234 |
"controlnet_config": None,
|
| 235 |
}
|
| 236 |
-
base_args =
|
| 237 |
|
| 238 |
if len(enabled_loras) > 0:
|
| 239 |
base_args.loras = []
|
|
@@ -252,7 +253,7 @@ def generate_image(
|
|
| 252 |
image = img2img_image
|
| 253 |
strength = float(img2img_strength)
|
| 254 |
|
| 255 |
-
base_args =
|
| 256 |
**base_args.__dict__,
|
| 257 |
image=image,
|
| 258 |
strength=strength
|
|
@@ -263,7 +264,7 @@ def generate_image(
|
|
| 263 |
strength = float(inpaint_strength)
|
| 264 |
|
| 265 |
if image and mask_image:
|
| 266 |
-
base_args =
|
| 267 |
**base_args.__dict__,
|
| 268 |
image=image,
|
| 269 |
mask_image=mask_image,
|
|
@@ -289,7 +290,7 @@ def generate_image(
|
|
| 289 |
base_args.controlnet_config.control_images.append(depth_image)
|
| 290 |
base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
|
| 291 |
else:
|
| 292 |
-
base_args =
|
| 293 |
|
| 294 |
return gr.update(
|
| 295 |
value=gen_img(base_args),
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
from huggingface_hub import ModelCard
|
| 7 |
|
| 8 |
+
from modules.helpers.common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq
|
| 9 |
+
from modules.helpers.flux_helpers import gen_img
|
| 10 |
+
from config import flux_loras
|
| 11 |
|
| 12 |
loras = flux_loras
|
| 13 |
|
| 14 |
|
| 15 |
# Event functions
|
| 16 |
+
def update_fast_generation(fast_generation):
|
| 17 |
if fast_generation:
|
| 18 |
return (
|
| 19 |
gr.update(
|
|
|
|
| 126 |
)
|
| 127 |
|
| 128 |
|
| 129 |
+
def add_to_enabled_loras(selected_lora, enabled_loras):
|
| 130 |
lora_data = loras
|
| 131 |
try:
|
| 132 |
selected_lora = int(selected_lora)
|
|
|
|
| 234 |
"vae": vae,
|
| 235 |
"controlnet_config": None,
|
| 236 |
}
|
| 237 |
+
base_args = BaseReq(**base_args)
|
| 238 |
|
| 239 |
if len(enabled_loras) > 0:
|
| 240 |
base_args.loras = []
|
|
|
|
| 253 |
image = img2img_image
|
| 254 |
strength = float(img2img_strength)
|
| 255 |
|
| 256 |
+
base_args = BaseImg2ImgReq(
|
| 257 |
**base_args.__dict__,
|
| 258 |
image=image,
|
| 259 |
strength=strength
|
|
|
|
| 264 |
strength = float(inpaint_strength)
|
| 265 |
|
| 266 |
if image and mask_image:
|
| 267 |
+
base_args = BaseInpaintReq(
|
| 268 |
**base_args.__dict__,
|
| 269 |
image=image,
|
| 270 |
mask_image=mask_image,
|
|
|
|
| 290 |
base_args.controlnet_config.control_images.append(depth_image)
|
| 291 |
base_args.controlnet_config.controlnet_conditioning_scale.append(float(depth_strength))
|
| 292 |
else:
|
| 293 |
+
base_args = BaseReq(**base_args.__dict__)
|
| 294 |
|
| 295 |
return gr.update(
|
| 296 |
value=gen_img(base_args),
|
modules/helpers/flux_helpers.py
CHANGED
|
@@ -6,10 +6,6 @@ from diffusers import (
|
|
| 6 |
AutoPipelineForText2Image,
|
| 7 |
AutoPipelineForImage2Image,
|
| 8 |
AutoPipelineForInpainting,
|
| 9 |
-
DiffusionPipeline,
|
| 10 |
-
AutoencoderKL,
|
| 11 |
-
FluxControlNetModel,
|
| 12 |
-
FluxMultiControlNetModel,
|
| 13 |
)
|
| 14 |
from huggingface_hub import hf_hub_download
|
| 15 |
from diffusers.schedulers import *
|
|
@@ -17,56 +13,8 @@ from huggingface_hub import hf_hub_download
|
|
| 17 |
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
|
| 18 |
|
| 19 |
from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_controlnet_images, resize_images
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def load_sd():
|
| 23 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
-
|
| 26 |
-
# Models
|
| 27 |
-
models = [
|
| 28 |
-
{
|
| 29 |
-
"repo_id": "black-forest-labs/FLUX.1-dev",
|
| 30 |
-
"loader": "flux",
|
| 31 |
-
"compute_type": torch.bfloat16,
|
| 32 |
-
}
|
| 33 |
-
]
|
| 34 |
-
|
| 35 |
-
for model in models:
|
| 36 |
-
try:
|
| 37 |
-
model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
|
| 38 |
-
model['repo_id'],
|
| 39 |
-
vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
|
| 40 |
-
torch_dtype=model['compute_type'],
|
| 41 |
-
safety_checker=None,
|
| 42 |
-
variant="fp16"
|
| 43 |
-
).to(device)
|
| 44 |
-
except:
|
| 45 |
-
model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
|
| 46 |
-
model['repo_id'],
|
| 47 |
-
vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
|
| 48 |
-
torch_dtype=model['compute_type'],
|
| 49 |
-
safety_checker=None
|
| 50 |
-
).to(device)
|
| 51 |
-
|
| 52 |
-
model["pipeline"].enable_model_cpu_offload()
|
| 53 |
-
|
| 54 |
-
# VAE n Refiner
|
| 55 |
-
flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
|
| 56 |
-
sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
|
| 57 |
-
refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
|
| 58 |
-
refiner.enable_model_cpu_offload()
|
| 59 |
-
|
| 60 |
-
# ControlNet
|
| 61 |
-
controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
|
| 62 |
-
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
| 63 |
-
torch_dtype=torch.bfloat16
|
| 64 |
-
).to(device)])
|
| 65 |
-
|
| 66 |
-
return device, models, flux_vae, sdxl_vae, refiner, controlnet
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
device, models, flux_vae, sdxl_vae, refiner, controlnet = load_sd()
|
| 70 |
|
| 71 |
|
| 72 |
def get_control_mode(controlnet_config: ControlNetReq):
|
|
|
|
| 6 |
AutoPipelineForText2Image,
|
| 7 |
AutoPipelineForImage2Image,
|
| 8 |
AutoPipelineForInpainting,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
)
|
| 10 |
from huggingface_hub import hf_hub_download
|
| 11 |
from diffusers.schedulers import *
|
|
|
|
| 13 |
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1
|
| 14 |
|
| 15 |
from .common_helpers import ControlNetReq, BaseReq, BaseImg2ImgReq, BaseInpaintReq, cleanup, get_controlnet_images, resize_images
|
| 16 |
+
from modules.pipelines.flux_pipelines import device, models, flux_vae, controlnet
|
| 17 |
+
from modules.pipelines.common_pipelines import refiner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def get_control_mode(controlnet_config: ControlNetReq):
|
modules/pipelines/common_pipelines.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import (
|
| 3 |
+
DiffusionPipeline,
|
| 4 |
+
AutoencoderKL,
|
| 5 |
+
)
|
| 6 |
+
from diffusers.schedulers import *
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_common():
|
| 10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
|
| 12 |
+
# VAE n Refiner
|
| 13 |
+
sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
|
| 14 |
+
refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
|
| 15 |
+
refiner.enable_model_cpu_offload()
|
| 16 |
+
|
| 17 |
+
return refiner, sdxl_vae
|
| 18 |
+
|
| 19 |
+
refiner, sdxl_vae = load_common()
|
modules/pipelines/flux_pipelines.py
CHANGED
|
@@ -1,19 +1,58 @@
|
|
| 1 |
-
# modules/pipelines/flux_pipelines.py
|
| 2 |
|
| 3 |
import torch
|
| 4 |
-
from diffusers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
def load_flux():
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
-
return device, models, flux_vae, controlnet
|
| 10 |
|
| 11 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
import torch
|
| 14 |
-
from diffusers import AutoPipelineForText2Image, AutoencoderKL
|
| 15 |
|
| 16 |
-
|
| 17 |
-
# Load SDXL models and pipelines
|
| 18 |
-
# ...
|
| 19 |
-
return device, models, sdxl_vae, controlnet
|
|
|
|
|
|
|
| 1 |
|
| 2 |
import torch
|
| 3 |
+
from diffusers import (
|
| 4 |
+
AutoPipelineForText2Image,
|
| 5 |
+
DiffusionPipeline,
|
| 6 |
+
AutoencoderKL,
|
| 7 |
+
FluxControlNetModel,
|
| 8 |
+
FluxMultiControlNetModel,
|
| 9 |
+
)
|
| 10 |
+
from diffusers.schedulers import *
|
| 11 |
+
|
| 12 |
+
|
| 13 |
|
| 14 |
def load_flux():
|
| 15 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 17 |
|
| 18 |
+
# Models
|
| 19 |
+
models = [
|
| 20 |
+
{
|
| 21 |
+
"repo_id": "black-forest-labs/FLUX.1-dev",
|
| 22 |
+
"loader": "flux",
|
| 23 |
+
"compute_type": torch.bfloat16,
|
| 24 |
+
}
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
for model in models:
|
| 28 |
+
try:
|
| 29 |
+
model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
|
| 30 |
+
model['repo_id'],
|
| 31 |
+
vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
|
| 32 |
+
torch_dtype=model['compute_type'],
|
| 33 |
+
safety_checker=None,
|
| 34 |
+
variant="fp16"
|
| 35 |
+
).to(device)
|
| 36 |
+
except:
|
| 37 |
+
model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
|
| 38 |
+
model['repo_id'],
|
| 39 |
+
vae=AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device),
|
| 40 |
+
torch_dtype=model['compute_type'],
|
| 41 |
+
safety_checker=None
|
| 42 |
+
).to(device)
|
| 43 |
+
|
| 44 |
+
model["pipeline"].enable_model_cpu_offload()
|
| 45 |
+
|
| 46 |
+
# VAE n Refiner
|
| 47 |
+
flux_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
|
| 48 |
+
|
| 49 |
+
# ControlNet
|
| 50 |
+
controlnet = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
|
| 51 |
+
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
| 52 |
+
torch_dtype=torch.bfloat16
|
| 53 |
+
).to(device)])
|
| 54 |
+
|
| 55 |
+
return device, models, flux_vae, controlnet
|
| 56 |
|
|
|
|
|
|
|
| 57 |
|
| 58 |
+
device, models, flux_vae, controlnet = load_flux()
|
|
|
|
|
|
|
|
|
tabs/image_tab.py
CHANGED
|
@@ -144,13 +144,13 @@ def flux_tab():
|
|
| 144 |
|
| 145 |
# Events
|
| 146 |
# Base Options
|
| 147 |
-
fast_generation.change(update_fast_generation, [
|
| 148 |
|
| 149 |
|
| 150 |
# Lora Gallery
|
| 151 |
lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
|
| 152 |
custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
|
| 153 |
-
add_lora.click(add_to_enabled_loras, [
|
| 154 |
enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
|
| 155 |
|
| 156 |
for i in range(6):
|
|
|
|
| 144 |
|
| 145 |
# Events
|
| 146 |
# Base Options
|
| 147 |
+
fast_generation.change(update_fast_generation, [fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
|
| 148 |
|
| 149 |
|
| 150 |
# Lora Gallery
|
| 151 |
lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
|
| 152 |
custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
|
| 153 |
+
add_lora.click(add_to_enabled_loras, [selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
|
| 154 |
enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
|
| 155 |
|
| 156 |
for i in range(6):
|