Spaces:
Runtime error
Runtime error
Update sd/utils/utils.py
Browse files- sd/utils/utils.py +67 -77
sd/utils/utils.py
CHANGED
|
@@ -1,78 +1,68 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from diffusers import (ControlNetModel,
|
| 3 |
-
StableDiffusionXLControlNetImg2ImgPipeline,
|
| 4 |
-
AutoencoderKL,
|
| 5 |
-
T2IAdapter,
|
| 6 |
-
StableDiffusionXLAdapterPipeline,
|
| 7 |
-
EulerAncestralDiscreteScheduler)
|
| 8 |
-
|
| 9 |
-
from controlnet_aux.pidi import PidiNetDetector
|
| 10 |
-
|
| 11 |
-
from PIL import Image
|
| 12 |
-
import os
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def
|
| 40 |
-
if
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
load_lora(pipe, lora_path)
|
| 68 |
-
return pipe
|
| 69 |
-
|
| 70 |
-
elif adapter != None:
|
| 71 |
-
pipe=StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
|
| 72 |
-
adapter=adapter,
|
| 73 |
-
vae=vae,
|
| 74 |
-
scheduler=scheduler,
|
| 75 |
-
torch_dtype=torch.float16,
|
| 76 |
-
variant="fp16")
|
| 77 |
-
load_lora(pipe, lora_path)
|
| 78 |
return pipe
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import (ControlNetModel,
|
| 3 |
+
StableDiffusionXLControlNetImg2ImgPipeline,
|
| 4 |
+
AutoencoderKL,
|
| 5 |
+
T2IAdapter,
|
| 6 |
+
StableDiffusionXLAdapterPipeline,
|
| 7 |
+
EulerAncestralDiscreteScheduler)
|
| 8 |
+
|
| 9 |
+
from controlnet_aux.pidi import PidiNetDetector
|
| 10 |
+
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_vae(model_name="madebyollin/sdxl-vae-fp16-fix"):
|
| 16 |
+
return AutoencoderKL.from_pretrained(model_name, torch_dtype=torch.float16)
|
| 17 |
+
|
| 18 |
+
def get_controlnet(model_name="diffusers/controlnet-canny-sdxl-1.0"):
|
| 19 |
+
return ControlNetModel.from_pretrained(model_name, torch_dtype=torch.float16)
|
| 20 |
+
|
| 21 |
+
def get_adapter(model_name="Adapter/t2iadapter", subfolder="sketch_sdxl_1.0",
|
| 22 |
+
adapter_type="full_adapter_xl"):
|
| 23 |
+
if adapter_type == "full_adapter_xl":
|
| 24 |
+
return T2IAdapter.from_pretrained(model_name,
|
| 25 |
+
subfolder=subfolder,
|
| 26 |
+
torch_dtype=torch.float16,
|
| 27 |
+
adapter_type=adapter_type)
|
| 28 |
+
|
| 29 |
+
def get_scheduler(model_name, scheduler_type="discrete"):
|
| 30 |
+
if scheduler_type == "discrete":
|
| 31 |
+
return EulerAncestralDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_detector(model_name="lllyasviel/Annotators", model_type='pidi'):
|
| 35 |
+
if model_type == 'pidi':
|
| 36 |
+
return PidiNetDetector.from_pretrained(model_name)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_lora(pipe, lora_path=None):
|
| 40 |
+
if lora_path != None:
|
| 41 |
+
try:
|
| 42 |
+
lora_dir='./'+'/'.join(lora_path.split("/")[:-1])
|
| 43 |
+
lora_name=lora_path.split("/")[-1]
|
| 44 |
+
pipe.load_lora_weights(lora_dir, weight_name=lora_name)
|
| 45 |
+
except Exception as ex:
|
| 46 |
+
print(ex)
|
| 47 |
+
#return pipe
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_pipe(vae, model_name, controlnet=None, adapter=None, scheduler=None, lora_path=None):
|
| 51 |
+
if controlnet!=None:
|
| 52 |
+
pipe=StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(model_name,
|
| 53 |
+
controlnet=controlnet,
|
| 54 |
+
vae=vae,
|
| 55 |
+
torch_dtype=torch.float16)
|
| 56 |
+
|
| 57 |
+
load_lora(pipe, lora_path)
|
| 58 |
+
return pipe
|
| 59 |
+
|
| 60 |
+
elif adapter != None:
|
| 61 |
+
pipe=StableDiffusionXLAdapterPipeline.from_pretrained(model_name,
|
| 62 |
+
adapter=adapter,
|
| 63 |
+
vae=vae,
|
| 64 |
+
scheduler=scheduler,
|
| 65 |
+
torch_dtype=torch.float16,
|
| 66 |
+
variant="fp16")
|
| 67 |
+
load_lora(pipe, lora_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
return pipe
|