sand74 commited on
Commit
bd03ad6
·
verified ·
1 Parent(s): 03fcb80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -1
app.py CHANGED
@@ -5,6 +5,8 @@ import cv2
5
  import PIL
6
  from controlnet_aux import OpenposeDetector
7
  from transformers import pipeline
 
 
8
 
9
 
10
  #import spaces #[uncomment to use ZeroGPU]
@@ -79,6 +81,12 @@ def is_lora(model_name):
79
  return model_name == "sand74/changpu_lora"
80
 
81
 
 
 
 
 
 
 
82
  #@spaces.GPU #[uncomment to use ZeroGPU]
83
  def infer(
84
  model_id,
@@ -98,6 +106,7 @@ def infer(
98
  use_ip_adapter=False,
99
  ip_adapter_image=None,
100
  ip_adapter_scale=None,
 
101
  progress=gr.Progress(track_tqdm=True),
102
  ):
103
  if randomize_seed:
@@ -119,8 +128,10 @@ def infer(
119
  else:
120
  base_model_id = model_id
121
 
 
 
122
  if not use_controlnet:
123
- pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype)
124
  else:
125
  controlnet_image = cv2.resize(controlnet_image, (width, height), interpolation=cv2.INTER_AREA)
126
  controlnet = ControlNetModel.from_pretrained(
@@ -158,6 +169,9 @@ def infer(
158
 
159
  image = pipe(**pipe_params).images[0]
160
 
 
 
 
161
  return image, seed
162
 
163
 
@@ -195,6 +209,7 @@ with gr.Blocks(css=css) as demo:
195
  run_button = gr.Button("Run", scale=0, variant="primary")
196
 
197
  result = gr.Image(label="Result", show_label=False)
 
198
 
199
  with gr.Group(visible=True) as lora_section:
200
  title = gr.Markdown(" ### LoRA section")
@@ -360,6 +375,7 @@ with gr.Blocks(css=css) as demo:
360
  use_ip_adapter,
361
  ip_adapter_image,
362
  ip_adapter_scale,
 
363
  ],
364
  outputs=[result, seed],
365
  )
 
5
  import PIL
6
  from controlnet_aux import OpenposeDetector
7
  from transformers import pipeline
8
+ from rembg import remove
9
+ from diffusers.models import AutoencoderKL
10
 
11
 
12
  #import spaces #[uncomment to use ZeroGPU]
 
81
  return model_name == "sand74/changpu_lora"
82
 
83
 
84
+ def remove_background(image):
85
+ image = remove(image)
86
+ return image
87
+
88
+
89
+
90
  #@spaces.GPU #[uncomment to use ZeroGPU]
91
  def infer(
92
  model_id,
 
106
  use_ip_adapter=False,
107
  ip_adapter_image=None,
108
  ip_adapter_scale=None,
109
+ rm_background=True,
110
  progress=gr.Progress(track_tqdm=True),
111
  ):
112
  if randomize_seed:
 
128
  else:
129
  base_model_id = model_id
130
 
131
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch_dtype)
132
+
133
  if not use_controlnet:
134
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch_dtype, vae=vae)
135
  else:
136
  controlnet_image = cv2.resize(controlnet_image, (width, height), interpolation=cv2.INTER_AREA)
137
  controlnet = ControlNetModel.from_pretrained(
 
169
 
170
  image = pipe(**pipe_params).images[0]
171
 
172
+ if rm_background:
173
+ image = remove(image)
174
+
175
  return image, seed
176
 
177
 
 
209
  run_button = gr.Button("Run", scale=0, variant="primary")
210
 
211
  result = gr.Image(label="Result", show_label=False)
212
+ rm_background = gr.Checkbox(label="Remove background?", scale=1, value=True)
213
 
214
  with gr.Group(visible=True) as lora_section:
215
  title = gr.Markdown(" ### LoRA section")
 
375
  use_ip_adapter,
376
  ip_adapter_image,
377
  ip_adapter_scale,
378
+ rm_background,
379
  ],
380
  outputs=[result, seed],
381
  )