File size: 3,577 Bytes
0c32ef6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from diffusers import (
  ControlNetModel,
  StableDiffusionImg2ImgPipeline,
  StableDiffusionControlNetImg2ImgPipeline,
)
from compel import Compel
from PIL import Image
import cv2
import gc
import gradio
import numpy
import torch

# Hugging Face model ids and runtime settings.
base_model = "SimianLuo/LCM_Dreamshaper_v7"
controlnet_model = "lllyasviel/control_v11p_sd15_canny"
device = "cuda"
dtype = torch.float16
# predict() resizes every input image to width x height before inference.
width = 512
height = 512

# ControlNet weights, loaded directly in `dtype`; predict() feeds this model
# cv2.Canny edge maps as its control image.
controlnet = ControlNetModel.from_pretrained(
  controlnet_model, torch_dtype=dtype
)

# Img2img pipeline guided by the ControlNet above (safety checker disabled).
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
  base_model, controlnet=controlnet, safety_checker=None
).to(dtype=dtype)
# Keep sub-models on the CPU and move each one to `device` only when needed.
pipe.enable_model_cpu_offload(device=device)
pipe.unet.to(memory_format=torch.channels_last)

# Prompt encoder for `pipe`; long prompts are not truncated.
compel_proc = Compel(
  tokenizer=pipe.tokenizer,
  text_encoder=pipe.text_encoder,
  truncate_long_prompts=False,
)

# Plain img2img pipeline (no ControlNet); loads its own copy of the base model.
pipe_no_controlnet = StableDiffusionImg2ImgPipeline.from_pretrained(
  base_model, safety_checker=None
).to(dtype=dtype)
# Offload this pipeline to the same device as `pipe`.
pipe_no_controlnet.enable_model_cpu_offload(device=device)

# Prompt encoder for `pipe_no_controlnet`; long prompts are not truncated.
compel_proc_no_controlnet = Compel(
  tokenizer=pipe_no_controlnet.tokenizer,
  text_encoder=pipe_no_controlnet.text_encoder,
  truncate_long_prompts=False,
)

def _canny_control_image(
  image: Image.Image,
  lower_threshold: int,
  higher_threshold: int,
) -> Image.Image:
  """Return a 3-channel PIL image holding the Canny edges of *image*.

  The single-channel edge map produced by ``cv2.Canny`` is replicated into
  three identical channels so it can be passed as the ControlNet
  ``control_image``.
  """
  edges = cv2.Canny(numpy.array(image), lower_threshold, higher_threshold)
  # Canny yields a 2-D (H, W) map; repeat it along a new last axis -> (H, W, 3).
  edges = numpy.repeat(edges[:, :, None], 3, axis=2)
  return Image.fromarray(edges)

def predict(
  prompt: str,
  image: Image.Image,
  use_controlnet: bool,
  generator: int,
  num_inference_steps: int,
  strength: float,
  guidance_scale: float,
  controlnet_conditioning_scale: float,
  canny_lower_threshold: int,
  canny_higher_threshold: int,
):
  """Run one img2img generation, optionally guided by a Canny ControlNet.

  Args:
    prompt: Text prompt, encoded with Compel (long prompts are not truncated).
    image: Input image, or None (then nothing is generated).
    use_controlnet: If True, run ``pipe`` with a Canny edge map of *image* as
      the control image; otherwise run ``pipe_no_controlnet``.
    generator: Integer seed for the random noise.
    num_inference_steps: Number of denoising steps, forwarded to the pipeline.
    strength: Img2img strength, forwarded to the pipeline.
    guidance_scale: Guidance scale, forwarded to the pipeline.
    controlnet_conditioning_scale: ControlNet weight (ControlNet path only).
    canny_lower_threshold: First hysteresis threshold for ``cv2.Canny``.
    canny_higher_threshold: Second hysteresis threshold for ``cv2.Canny``.

  Returns:
    The first generated PIL image, or None if *image* is None or the pipeline
    returned no images.
  """
  if image is None:
    return None

  # A private generator, seeded the same way torch.manual_seed seeds the
  # default CPU generator, avoids mutating process-wide RNG state per request.
  rng = torch.Generator(device="cpu").manual_seed(int(generator))
  # Sliders may hand back floats; the pipelines expect an integer step count.
  num_inference_steps = int(num_inference_steps)

  # TODO: Keep the original ratio?
  # convert("RGB") guards against RGBA/grayscale inputs when predict() is called
  # directly; the pipelines and the 3-channel edge map assume RGB input.
  image = image.convert("RGB").resize((width, height))

  if use_controlnet:
    prompt_embeds = compel_proc(prompt)
    control_image = _canny_control_image(
      image, canny_lower_threshold, canny_higher_threshold
    )
    try:
      results = pipe(
        control_image=control_image,
        control_guidance_end=1.0,
        control_guidance_start=0.0,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        generator=rng,
        guidance_scale=guidance_scale,
        image=image,
        num_inference_steps=num_inference_steps,
        output_type="pil",
        prompt_embeds=prompt_embeds,
        strength=strength,
      )
    finally:
      # Release the edge image's pixel buffer even if the pipeline raises.
      control_image.close()
  else:
    prompt_embeds = compel_proc_no_controlnet(prompt)
    results = pipe_no_controlnet(
      generator=rng,
      guidance_scale=guidance_scale,
      image=image,
      num_inference_steps=num_inference_steps,
      output_type="pil",
      prompt_embeds=prompt_embeds,
      strength=strength,
    )

  gc.collect()

  return results.images[0] if results.images else None

# Gradio front end. The input components are matched positionally to
# predict()'s parameters, and the single output displays the returned PIL image.
app = gradio.Interface(
  fn=predict,
  inputs=[
    # prompt
    gradio.Textbox(value="Kirisame Marisa, Cute, Smiling, High quality, Realistic"),
    # image (delivered to predict() as a PIL image)
    gradio.Image(type="pil"),
    # use_controlnet
    gradio.Checkbox(value=True),
    # generator: seed in [0, 2**31 - 1]
    gradio.Slider(minimum=0, maximum=2147483647, value=2159232, step=1),
    # num_inference_steps
    gradio.Slider(minimum=2, maximum=15, value=4, step=1),
    # strength
    gradio.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.01),
    # guidance_scale
    gradio.Slider(minimum=0.0, maximum=5.0, value=0.2, step=0.01),
    # controlnet_conditioning_scale
    gradio.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.01),
    # canny_lower_threshold
    gradio.Slider(minimum=0, maximum=255, value=100, step=1),
    # canny_higher_threshold
    gradio.Slider(minimum=0, maximum=255, value=200, step=1),
  ],
  outputs=gradio.Image(type="pil"),
)
app.launch()