vivekgr92 commited on
Commit
e80140c
·
1 Parent(s): af27abe

Add utils

Browse files
Files changed (1) hide show
  1. utils.py +65 -0
utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image, ImageOps
5
+ from torchvision.transforms import ToPILImage, ToTensor
6
+ totensor = ToTensor()
7
+ topil = ToPILImage()
8
+
9
+
10
+
11
+ def resize_and_crop(img, size, crop_type="center"):
12
+ '''Resize and crop the image to the given size.'''
13
+ if crop_type == "top":
14
+ center = (0, 0)
15
+ elif crop_type == "center":
16
+ center = (0.5, 0.5)
17
+ else:
18
+ raise ValueError
19
+
20
+ resize = list(size)
21
+ if size[0] is None:
22
+ resize[0] = img.size[0]
23
+ if size[1] is None:
24
+ resize[1] = img.size[1]
25
+ return ImageOps.fit(img, resize, centering=center)
26
+
27
+
28
+ def recover_image(image, init_image, mask, background=False):
29
+ image = totensor(image)
30
+ mask = totensor(mask)[0]
31
+ init_image = totensor(init_image)
32
+
33
+ if background:
34
+ result = mask * init_image + (1 - mask) * image
35
+ else:
36
+ result = mask * image + (1 - mask) * init_image
37
+ return topil(result)
38
+
39
+
40
+ def preprocess(image):
41
+ w, h = image.size
42
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
43
+ image = image.resize((w, h), resample=Image.LANCZOS)
44
+ image = np.array(image).astype(np.float32) / 255.0
45
+ image = image[None].transpose(0, 3, 1, 2)
46
+ image = torch.from_numpy(image)
47
+ return 2.0 * image - 1.0
48
+
49
+
50
+ def prepare_mask_and_masked_image(image, mask):
51
+
52
+ image = np.array(image.convert("RGB"))
53
+ image = image[None].transpose(0, 3, 1, 2)
54
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
55
+
56
+ mask = np.array(mask.convert("L"))
57
+ mask = mask.astype(np.float32) / 255.0
58
+ mask = mask[None, None]
59
+ mask[mask < 0.5] = 0
60
+ mask[mask >= 0.5] = 1
61
+ mask = torch.from_numpy(mask)
62
+
63
+ masked_image = image * (mask < 0.5)
64
+
65
+ return mask, masked_image