import numpy as np
import torch
import cv2 as cv
import random
import os
import spaces
import gradio as gr

from rembg import remove
from PIL import Image
from transformers import pipeline
from controlnet_aux import MLSDdetector, HEDdetector, NormalBaeDetector, LineartDetector
from peft import PeftModel, LoraConfig
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    DPMSolverMultistepScheduler,
    PNDMScheduler,
    ControlNetModel,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg, retrieve_timesteps
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import load_image, make_image_grid
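
# This app builds Stable Diffusion pipelines with optional ControlNet
# conditioning, optional IP-Adapter image prompting, an optional local PEFT
# LoRA ("Maria_Lashina_LoRA"), and optional background removal via rembg.
# The globals below fix the device, dtype, and model choices used throughout.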

device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision on GPU keeps memory usage down; fall back to fp32 on CPU.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

MAX_SEED = np.iinfo(np.int32).max

default_model = 'CompVis/stable-diffusion-v1-4'
LoRA_path = 'new_model'
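
# 'new_model' is expected to hold PEFT LoRA adapters saved in two subfolders,
# matching how get_pipe() loads them below:
#   new_model/unet          - LoRA adapter applied to the UNet
#   new_model/text_encoder  - LoRA adapter applied to the text encoder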

# Mapping from mode name to the corresponding ControlNet v1.1 checkpoint on the Hub.
CONTROLNET_MODE = {
    "Canny Edge Detection": "lllyasviel/control_v11p_sd15_canny",
    "Pixel to Pixel": "lllyasviel/control_v11e_sd15_ip2p",
    "HED edge detection (soft edge)": "lllyasviel/control_v11p_sd15_softedge",
    "Midas depth estimation": "lllyasviel/control_v11f1p_sd15_depth",
    "Surface Normal Estimation": "lllyasviel/control_v11p_sd15_normalbae",
    "Scribble-Based Generation": "lllyasviel/control_v11p_sd15_scribble",
    "Line Art Generation": "lllyasviel/control_v11p_sd15_lineart",
}
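
# Each mode above pairs with a matching preprocessing step in
# prepare_controlnet_image(), e.g. "Canny Edge Detection" feeds a Canny edge
# map to lllyasviel/control_v11p_sd15_canny.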


def get_pipe(model_id, use_controlnet, controlnet_mode, use_ip_adapter):
    """Assemble the Stable Diffusion pipeline for the requested configuration."""
    # The custom LoRA option runs on top of the base SD 1.4 checkpoint.
    base_model = default_model if model_id == 'Maria_Lashina_LoRA' else model_id

    if use_controlnet:
        print('Pipe with ControlNet' + (' and IPAdapter' if use_ip_adapter else ''))

        # Load the ControlNet in the same dtype as the base pipeline to avoid
        # dtype mismatches at inference time.
        controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MODE[controlnet_mode],
            torch_dtype=torch_dtype,
            cache_dir="./models_cache",
        )

        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model,
            torch_dtype=torch_dtype,
            controlnet=controlnet,
            safety_checker=None,
        ).to(device)
    else:
        print('Pipe with IpAdapter' if use_ip_adapter else 'Pipe with only SD')

        pipe = StableDiffusionPipeline.from_pretrained(
            base_model,
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)

    if use_ip_adapter:
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter-plus_sd15.bin",
        )

    if model_id == 'Maria_Lashina_LoRA':
        # Attach the locally trained PEFT LoRA adapters to both the UNet and
        # the text encoder.
        adapter_name = 'cartoonish mouse'
        unet_sub_dir = os.path.join(LoRA_path, "unet")
        text_encoder_sub_dir = os.path.join(LoRA_path, "text_encoder")

        pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)

    return pipe


def prepare_controlnet_image(controlnet_image, mode):
    """Convert the reference image (a NumPy array from the Gradio image input)
    into the conditioning image expected by the selected ControlNet mode."""
    # Note: the annotator models below are instantiated on every call; caching
    # them at module level would avoid repeated downloads and initialisation.
    if mode == "Canny Edge Detection":
        image = cv.Canny(controlnet_image, 80, 160)
        image = np.repeat(image[:, :, None], 3, axis=2)
        image = Image.fromarray(image)

    elif mode == "Pixel to Pixel":
        # InstructPix2Pix conditioning uses the unmodified input image.
        image = Image.fromarray(controlnet_image).convert('RGB')

    elif mode == "HED edge detection (soft edge)":
        processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
        image = processor(controlnet_image)

    elif mode == "Midas depth estimation":
        depth_estimator = pipeline('depth-estimation')
        image = depth_estimator(Image.fromarray(controlnet_image))['depth']
        image = np.array(image)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        image = Image.fromarray(image)

    elif mode == "Surface Normal Estimation":
        processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
        image = processor(controlnet_image)

    elif mode == "Scribble-Based Generation":
        processor = HEDdetector.from_pretrained('lllyasviel/Annotators')
        image = processor(controlnet_image, scribble=True)

    elif mode == "Line Art Generation":
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        image = processor(controlnet_image)

    else:
        image = controlnet_image

    return image
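

# Gradio callback: rebuild the pipeline for the requested configuration, run
# generation, optionally strip the background, and return the image together
# with the seed that was actually used.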
def infer(
    model_id,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    lora_scale,
    num_inference_steps,
    use_controlnet,
    control_strength,
    controlnet_mode,
    controlnet_image,
    use_ip_adapter,
    ip_adapter_scale,
    ip_adapter_image,
    delete_background,
    progress=gr.Progress(track_tqdm=True),  # surfaces the diffusers tqdm bar in the UI
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    pipe = get_pipe(model_id, use_controlnet, controlnet_mode, use_ip_adapter)

    # Arguments shared by every pipeline configuration.
    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "cross_attention_kwargs": {"scale": lora_scale},
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": generator,
    }

    if use_controlnet:
        # Turn the reference image into the conditioning image for the chosen mode.
        pipe_kwargs["image"] = prepare_controlnet_image(controlnet_image, controlnet_mode)
        pipe_kwargs["controlnet_conditioning_scale"] = control_strength

    if use_ip_adapter:
        pipe.set_ip_adapter_scale(ip_adapter_scale)
        pipe_kwargs["ip_adapter_image"] = ip_adapter_image

    image = pipe(**pipe_kwargs).images[0]

    if delete_background:
        image = remove(image)

    return image, seed