Purple11 commited on
Commit
0d2623f
·
1 Parent(s): 3d41704

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +414 -58
app.py CHANGED
@@ -4,75 +4,431 @@ import torch
4
  import numpy as np
5
  from torchvision import transforms
6
 
7
- title = "Remove Bg"
8
- description = "Automatically remove the image background from a profile photo."
9
- article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"
10
-
11
-
12
- def make_transparent_foreground(pic, mask):
13
- # split the image into channels
14
- b, g, r = cv2.split(np.array(pic).astype('uint8'))
15
- # add an alpha channel with and fill all with transparent pixels (max 255)
16
- a = np.ones(mask.shape, dtype='uint8') * 255
17
- # merge the alpha channel back
18
- alpha_im = cv2.merge([b, g, r, a], 4)
19
- # create a transparent background
20
- bg = np.zeros(alpha_im.shape)
21
- # setup the new mask
22
- new_mask = np.stack([mask, mask, mask, mask], axis=2)
23
- # copy only the foreground color pixels from the original image where mask is set
24
- foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
25
-
26
- return foreground
27
-
28
-
29
- def remove_background(input_image):
30
- preprocess = transforms.Compose([
31
- transforms.ToTensor(),
32
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
- ])
34
-
35
- input_tensor = preprocess(input_image)
36
- input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
37
-
38
- # move the input and model to GPU for speed if available
39
- if torch.cuda.is_available():
40
- input_batch = input_batch.to('cuda')
41
- model.to('cuda')
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  with torch.no_grad():
44
- output = model(input_batch)['out'][0]
45
- output_predictions = output.argmax(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # create a binary (black and white) mask of the profile foreground
48
- mask = output_predictions.byte().cpu().numpy()
49
- background = np.zeros(mask.shape)
50
- bin_mask = np.where(mask, 255, background).astype(np.uint8)
51
 
52
- foreground = make_transparent_foreground(input_image, bin_mask)
53
 
54
- return foreground, bin_mask
55
 
56
 
57
- def inference(img):
58
- foreground, _ = remove_background(img)
59
- return foreground
60
 
61
 
62
- torch.hub.download_url_to_file('https://pbs.twimg.com/profile_images/691700243809718272/z7XZUARB_400x400.jpg',
63
- 'demis.jpg')
64
- torch.hub.download_url_to_file('https://hai.stanford.edu/sites/default/files/styles/person_medium/public/2020-03/hai_1512feifei.png?itok=INFuLABp',
65
- 'lifeifei.png')
66
- model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
67
- model.eval()
68
 
69
  gr.Interface(
70
  inference,
71
- gr.inputs.Image(type="pil", label="Input"),
 
72
  gr.outputs.Image(type="pil", label="Output"),
73
- title=title,
74
- description=description,
75
- article=article,
76
- examples=[['demis.jpg'], ['lifeifei.png']],
77
- enable_queue=True
78
  ).launch(debug=False)
 
4
  import numpy as np
5
  from torchvision import transforms
6
 
7
+ # title = "Remove Bg"
8
+ # description = "Automatically remove the image background from a profile photo."
9
+ # article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ import argparse, os
12
+ import cv2
13
+ import torch
14
+ import numpy as np
15
+ import torchvision
16
+ from omegaconf import OmegaConf
17
+ from PIL import Image
18
+ from tqdm import tqdm, trange
19
+ from itertools import islice
20
+ from einops import rearrange
21
+ from torchvision.utils import make_grid
22
+ import time
23
+ from pytorch_lightning import seed_everything
24
+ from torch import autocast
25
+ from contextlib import nullcontext
26
+
27
+ from ldm.util import instantiate_from_config
28
+ from ldm.models.diffusion.ddim import DDIMSampler
29
+ from ldm.modules.diffusionmodules.openaimodel import clear_feature_dic,get_feature_dic
30
+ from ldm.models.seg_module import Segmodule
31
+
32
+ import numpy as np
33
+
34
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
35
+
36
+ def chunk(it, size):
37
+ it = iter(it)
38
+ return iter(lambda: tuple(islice(it, size)), ())
39
+
40
+
41
+ def numpy_to_pil(images):
42
+ """
43
+ Convert a numpy image or a batch of images to a PIL image.
44
+ """
45
+ if images.ndim == 3:
46
+ images = images[None, ...]
47
+ images = (images * 255).round().astype("uint8")
48
+ pil_images = [Image.fromarray(image) for image in images]
49
+
50
+ return pil_images
51
+
52
+
53
+ def load_model_from_config(config, ckpt, verbose=False):
54
+ # print(f"Loading model from {ckpt}")
55
+ pl_sd = torch.load(ckpt, map_location="cpu")
56
+ if "global_step" in pl_sd:
57
+ # print(f"Global Step: {pl_sd['global_step']}")
58
+ sd = pl_sd["state_dict"]
59
+ model = instantiate_from_config(config.model)
60
+ # m, u = model.load_state_dict(sd, strict=False)
61
+ # if len(m) > 0 and verbose:
62
+ # print("missing keys:")
63
+ # print(m)
64
+ # if len(u) > 0 and verbose:
65
+ # print("unexpected keys:")
66
+ # print(u)
67
+
68
+ model.cuda()
69
+ model.eval()
70
+ return model
71
+
72
+
73
+ def put_watermark(img, wm_encoder=None):
74
+ if wm_encoder is not None:
75
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
76
+ img = wm_encoder.encode(img, 'dwtDct')
77
+ img = Image.fromarray(img[:, :, ::-1])
78
+ return img
79
+
80
+
81
+ def load_replacement(x):
82
+ try:
83
+ hwc = x.shape
84
+ y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
85
+ y = (np.array(y)/255.0).astype(x.dtype)
86
+ assert y.shape == x.shape
87
+ return y
88
+ except Exception:
89
+ return x
90
+
91
+ def plot_mask(img, masks, colors=None, alpha=0.8,indexlist=[0,1]) -> np.ndarray:
92
+ H,W= masks.shape[0],masks.shape[1]
93
+ color_list=[[255,97,0],[128,42,42],[220,220,220],[255,153,18],[56,94,15],[127,255,212],[210,180,140],[221,160,221],[255,0,0],[255,128,0],[255,255,0],[128,255,0],[0,255,0],[0,255,128],[0,255,255],[0,128,255],[0,0,255],[128,0,255],[255,0,255],[255,0,128]]*6
94
+ final_color_list=[np.array([[i]*512]*512) for i in color_list]
95
+
96
+ background=np.ones(img.shape)*255
97
+ count=0
98
+ colors=final_color_list[indexlist[count]]
99
+ for mask, color in zip(masks, colors):
100
+ color=final_color_list[indexlist[count]]
101
+ mask = np.stack([mask, mask, mask], -1)
102
+ img = np.where(mask, img * (1 - alpha) + color * alpha,background*0.4+img*0.6 )
103
+ count+=1
104
+ return img.astype(np.uint8)
105
+
106
+ def create_parser():
107
+
108
+ parser = argparse.ArgumentParser()
109
+
110
+ parser.add_argument(
111
+ "--prompt",
112
+ type=str,
113
+ nargs="?",
114
+ default="a photo of a lion on a mountain top at sunset",
115
+ help="the prompt to render"
116
+ )
117
+ parser.add_argument(
118
+ "--category",
119
+ type=str,
120
+ nargs="?",
121
+ default="lion",
122
+ help="the category to ground"
123
+ )
124
+ parser.add_argument(
125
+ "--outdir",
126
+ type=str,
127
+ nargs="?",
128
+ help="dir to write results to",
129
+ default="outputs/txt2img-samples"
130
+ )
131
+ parser.add_argument(
132
+ "--skip_grid",
133
+ action='store_true',
134
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
135
+ )
136
+ parser.add_argument(
137
+ "--skip_save",
138
+ action='store_true',
139
+ help="do not save individual samples. For speed measurements.",
140
+ )
141
+ parser.add_argument(
142
+ "--ddim_steps",
143
+ type=int,
144
+ default=50,
145
+ help="number of ddim sampling steps",
146
+ )
147
+ parser.add_argument(
148
+ "--plms",
149
+ action='store_true',
150
+ help="use plms sampling",
151
+ )
152
+ parser.add_argument(
153
+ "--laion400m",
154
+ action='store_true',
155
+ help="uses the LAION400M model",
156
+ )
157
+ parser.add_argument(
158
+ "--fixed_code",
159
+ action='store_true',
160
+ help="if enabled, uses the same starting code across samples ",
161
+ )
162
+ parser.add_argument(
163
+ "--ddim_eta",
164
+ type=float,
165
+ default=0.0,
166
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
167
+ )
168
+ parser.add_argument(
169
+ "--n_iter",
170
+ type=int,
171
+ default=1,
172
+ help="sample this often",
173
+ )
174
+ parser.add_argument(
175
+ "--H",
176
+ type=int,
177
+ default=512,
178
+ help="image height, in pixel space",
179
+ )
180
+ parser.add_argument(
181
+ "--W",
182
+ type=int,
183
+ default=512,
184
+ help="image width, in pixel space",
185
+ )
186
+ parser.add_argument(
187
+ "--C",
188
+ type=int,
189
+ default=4,
190
+ help="latent channels",
191
+ )
192
+ parser.add_argument(
193
+ "--f",
194
+ type=int,
195
+ default=8,
196
+ help="downsampling factor",
197
+ )
198
+ parser.add_argument(
199
+ "--n_samples",
200
+ type=int,
201
+ default=1,
202
+ help="how many samples to produce for each given prompt. A.k.a. batch size",
203
+ )
204
+ parser.add_argument(
205
+ "--n_rows",
206
+ type=int,
207
+ default=0,
208
+ help="rows in the grid (default: n_samples)",
209
+ )
210
+ parser.add_argument(
211
+ "--scale",
212
+ type=float,
213
+ default=7.5,
214
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
215
+ )
216
+ parser.add_argument(
217
+ "--from-file",
218
+ type=str,
219
+ help="if specified, load prompts from this file",
220
+ )
221
+ parser.add_argument(
222
+ "--config",
223
+ type=str,
224
+ default="configs/stable-diffusion/v1-inference.yaml",
225
+ help="path to config which constructs model",
226
+ )
227
+ parser.add_argument(
228
+ "--sd_ckpt",
229
+ type=str,
230
+ default="stable_diffusion.ckpt",
231
+ help="path to checkpoint of stable diffusion model",
232
+ )
233
+ parser.add_argument(
234
+ "--grounding_ckpt",
235
+ type=str,
236
+ default="grounding_module.pth",
237
+ help="path to checkpoint of grounding module",
238
+ )
239
+ parser.add_argument(
240
+ "--seed",
241
+ type=int,
242
+ default=42,
243
+ help="the seed (for reproducible sampling)",
244
+ )
245
+ parser.add_argument(
246
+ "--precision",
247
+ type=str,
248
+ help="evaluate at this precision",
249
+ choices=["full", "autocast"],
250
+ default="autocast"
251
+ )
252
+ opt = parser.parse_args()
253
+
254
+ return opt
255
+
256
+
257
+ def inference(input_prompt, input_category):
258
+
259
+ opt = create_parser()
260
+
261
+ seed_everything(opt.seed)
262
+
263
+ tic = time.time()
264
+ config = OmegaConf.load(f"{opt.config}")
265
+ model = load_model_from_config(config, f"{opt.sd_ckpt}")
266
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
267
+ model = model.to(device)
268
+ toc = time.time()
269
+ seg_module=Segmodule().to(device)
270
+
271
+ seg_module.load_state_dict(torch.load(opt.grounding_ckpt, map_location="cpu"), strict=True)
272
+ # print('load time:',toc-tic)
273
+ sampler = DDIMSampler(model)
274
+
275
+ os.makedirs(opt.outdir, exist_ok=True)
276
+ outpath = opt.outdir
277
+ batch_size = opt.n_samples
278
+ precision_scope = autocast if opt.precision=="autocast" else nullcontext
279
  with torch.no_grad():
280
+ with precision_scope("cuda"):
281
+ with model.ema_scope():
282
+ prompt = input_prompt
283
+ text = input_category
284
+ trainclass = text
285
+ if not opt.from_file:
286
+ assert prompt is not None
287
+ data = [batch_size * [prompt]]
288
+
289
+ else:
290
+ # print(f"reading prompts from {opt.from_file}")
291
+ with open(opt.from_file, "r") as f:
292
+ data = f.read().splitlines()
293
+ data = list(chunk(data, batch_size))
294
+
295
+ sample_path = os.path.join(outpath, "samples")
296
+ os.makedirs(sample_path, exist_ok=True)
297
+
298
+ start_code = None
299
+ if opt.fixed_code:
300
+ # print('start_code')
301
+ start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
302
+ for n in trange(opt.n_iter, desc="Sampling"):
303
+ for prompts in tqdm(data, desc="data"):
304
+ clear_feature_dic()
305
+ uc = None
306
+ if opt.scale != 1.0:
307
+ uc = model.get_learned_conditioning(batch_size * [""])
308
+ if isinstance(prompts, tuple):
309
+ prompts = list(prompts)
310
+
311
+ c = model.get_learned_conditioning(prompts)
312
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
313
+ samples_ddim, _, _ = sampler.sample(S=opt.ddim_steps,
314
+ conditioning=c,
315
+ batch_size=opt.n_samples,
316
+ shape=shape,
317
+ verbose=False,
318
+ unconditional_guidance_scale=opt.scale,
319
+ unconditional_conditioning=uc,
320
+ eta=opt.ddim_eta,
321
+ x_T=start_code)
322
+
323
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
324
+ diffusion_features = get_feature_dic()
325
+
326
+
327
+ x_sample = torch.clamp((x_samples_ddim[0] + 1.0) / 2.0, min=0.0, max=1.0)
328
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
329
+
330
+ Image.fromarray(x_sample.astype(np.uint8)).save("demo/demo.png")
331
+ img = x_sample.astype(np.uint8)
332
+
333
+ class_name = trainclass
334
+
335
+ query_text ="a photograph of a " + class_name
336
+ c_split = model.cond_stage_model.tokenizer.tokenize(query_text)
337
+
338
+ sen_text_embedding = model.get_learned_conditioning(query_text)
339
+ class_embedding = sen_text_embedding[:, 5:len(c_split)+1, :]
340
+
341
+ if class_embedding.size()[1] > 1:
342
+ class_embedding = torch.unsqueeze(class_embedding.mean(1), 1)
343
+ text_embedding = class_embedding
344
+
345
+ text_embedding = text_embedding.repeat(batch_size, 1, 1)
346
+
347
+
348
+ pred_seg_total = seg_module(diffusion_features, text_embedding)
349
+
350
+
351
+ pred_seg = torch.unsqueeze(pred_seg_total[0,0,:,:], 0).unsqueeze(0)
352
+
353
+ label_pred_prob = torch.sigmoid(pred_seg)
354
+ label_pred_mask = torch.zeros_like(label_pred_prob, dtype=torch.float32)
355
+ label_pred_mask[label_pred_prob > 0.5] = 1
356
+ annotation_pred = label_pred_mask[0][0].cpu()
357
+
358
+ mask = annotation_pred.numpy()
359
+ mask = np.expand_dims(mask, 0)
360
+ done_image_mask = plot_mask(img, mask, alpha=0.9, indexlist=[0])
361
+ # cv2.imwrite(os.path.join("demo/demo_mask.png"), done_image_mask)
362
+
363
+ # torchvision.utils.save_image(annotation_pred, os.path.join("demo/demo_segresult.png"), normalize=True, scale_each=True)
364
+ return x_sample, done_image_mask
365
+
366
+
367
+ # def make_transparent_foreground(pic, mask):
368
+ # # split the image into channels
369
+ # b, g, r = cv2.split(np.array(pic).astype('uint8'))
370
+ # # add an alpha channel with and fill all with transparent pixels (max 255)
371
+ # a = np.ones(mask.shape, dtype='uint8') * 255
372
+ # # merge the alpha channel back
373
+ # alpha_im = cv2.merge([b, g, r, a], 4)
374
+ # # create a transparent background
375
+ # bg = np.zeros(alpha_im.shape)
376
+ # # setup the new mask
377
+ # new_mask = np.stack([mask, mask, mask, mask], axis=2)
378
+ # # copy only the foreground color pixels from the original image where mask is set
379
+ # foreground = np.where(new_mask, alpha_im, bg).astype(np.uint8)
380
+
381
+ # return foreground
382
+
383
+
384
+ # def remove_background(input_image):
385
+ # preprocess = transforms.Compose([
386
+ # transforms.ToTensor(),
387
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
388
+ # ])
389
+
390
+ # input_tensor = preprocess(input_image)
391
+ # input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
392
+
393
+ # # move the input and model to GPU for speed if available
394
+ # if torch.cuda.is_available():
395
+ # input_batch = input_batch.to('cuda')
396
+ # model.to('cuda')
397
+
398
+ # with torch.no_grad():
399
+ # output = model(input_batch)['out'][0]
400
+ # output_predictions = output.argmax(0)
401
 
402
+ # # create a binary (black and white) mask of the profile foreground
403
+ # mask = output_predictions.byte().cpu().numpy()
404
+ # background = np.zeros(mask.shape)
405
+ # bin_mask = np.where(mask, 255, background).astype(np.uint8)
406
 
407
+ # foreground = make_transparent_foreground(input_image, bin_mask)
408
 
409
+ # return foreground, bin_mask
410
 
411
 
412
+ # def inference(img):
413
+ # foreground, _ = remove_background(img)
414
+ # return foreground
415
 
416
 
417
+ # torch.hub.download_url_to_file('https://pbs.twimg.com/profile_images/691700243809718272/z7XZUARB_400x400.jpg',
418
+ # 'demis.jpg')
419
+ # torch.hub.download_url_to_file('https://hai.stanford.edu/sites/default/files/styles/person_medium/public/2020-03/hai_1512feifei.png?itok=INFuLABp',
420
+ # 'lifeifei.png')
421
+ # model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=True)
422
+ # model.eval()
423
 
424
  gr.Interface(
425
  inference,
426
+ gr.inputs.Textbox(label='Prompt', default='a photo of a lion on a mountain top at sunset'),
427
+ gr.inputs.Textbox(label='category', default='lion'),
428
  gr.outputs.Image(type="pil", label="Output"),
429
+ # title=title,
430
+ # description=description,
431
+ # article=article,
432
+ # examples=[['demis.jpg'], ['lifeifei.png']],
433
+ # enable_queue=True
434
  ).launch(debug=False)