Mayo commited on
Commit
0cc86a0
·
unverified ·
1 Parent(s): cde9fe7

chore: inference python scripts

Browse files
.gitignore CHANGED
@@ -37,3 +37,7 @@ runs/
37
  # model checkpoints
38
  models/
39
  .env
 
 
 
 
 
37
  # model checkpoints
38
  models/
39
  .env
40
+
41
+ # experiments
42
+ BallonsTranslator/
43
+ carve-lama/
scripts/ctd_inference.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
+ img = cv2.imread('data/1746025823_segment.png')
5
+
6
+ kernel = np.ones((3,3),np.uint8)
7
+ h, w = img.shape[0], img.shape[1]
8
+ seedpnt = (int(w/2), int(h/2))
9
+ difres = 10
10
+
11
+ # convert to grayscale
12
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
13
+
14
+ # ballon_mask = img - 127
15
+ ballon_mask = 127 - img
16
+ ballon_mask = img
17
+ ballon_mask = cv2.dilate(ballon_mask, kernel,iterations = 1)
18
+ # ballon_area, _, _, rect = cv2.floodFill(ballon_mask, mask=None, seedPoint=seedpnt, flags=4, newVal=(30), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
19
+ ballon_mask = 30 - ballon_mask
20
+ retval, ballon_mask = cv2.threshold(ballon_mask, 1, 255, cv2.THRESH_BINARY)
21
+ ballon_mask = cv2.bitwise_not(ballon_mask, ballon_mask)
22
+
23
+ # box_kernel = int(np.sqrt(ballon_area) / 30)
24
+ # if box_kernel > 1:
25
+ # box_kernel = np.ones((box_kernel,box_kernel),np.uint8)
26
+ # ballon_mask = cv2.dilate(ballon_mask, box_kernel, iterations = 1)
27
+ # ballon_mask = cv2.erode(ballon_mask, box_kernel, iterations = 1)
28
+
29
+ cv2.imshow('ballon_mask', ballon_mask)
30
+ #cv2.imshow('img', img)
31
+ cv2.waitKey(0)
scripts/inference_inpaint_onnx.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import onnxruntime
4
+ import torch
5
+ import io
6
+ import requests
7
+ from PIL import Image
8
+
9
+ def get_image(image):
10
+ if isinstance(image, Image.Image):
11
+ img = np.array(image)
12
+ elif isinstance(image, np.ndarray):
13
+ img = image.copy()
14
+ else:
15
+ raise Exception("Input image should be either PIL Image or numpy array!")
16
+
17
+ if img.ndim == 3:
18
+ img = np.transpose(img, (2, 0, 1)) # chw
19
+ elif img.ndim == 2:
20
+ img = img[np.newaxis, ...]
21
+
22
+ assert img.ndim == 3
23
+
24
+ img = img.astype(np.float32) / 255
25
+ return img
26
+
27
+
28
+ def ceil_modulo(x, mod):
29
+ if x % mod == 0:
30
+ return x
31
+ return (x // mod + 1) * mod
32
+
33
+
34
+ def scale_image(img, factor, interpolation=cv2.INTER_AREA):
35
+ if img.shape[0] == 1:
36
+ img = img[0]
37
+ else:
38
+ img = np.transpose(img, (1, 2, 0))
39
+
40
+ img = cv2.resize(img, dsize=None, fx=factor, fy=factor, interpolation=interpolation)
41
+
42
+ if img.ndim == 2:
43
+ img = img[None, ...]
44
+ else:
45
+ img = np.transpose(img, (2, 0, 1))
46
+ return img
47
+
48
+
49
+ def pad_img_to_modulo(img, mod):
50
+ channels, height, width = img.shape
51
+ out_height = ceil_modulo(height, mod)
52
+ out_width = ceil_modulo(width, mod)
53
+ return np.pad(
54
+ img,
55
+ ((0, 0), (0, out_height - height), (0, out_width - width)),
56
+ mode="symmetric",
57
+ )
58
+
59
+
60
+ def prepare_img_and_mask(image, mask, device, pad_out_to_modulo=8, scale_factor=None):
61
+ out_image = get_image(image)
62
+ out_mask = get_image(mask)
63
+
64
+ if scale_factor is not None:
65
+ out_image = scale_image(out_image, scale_factor)
66
+ out_mask = scale_image(out_mask, scale_factor, interpolation=cv2.INTER_NEAREST)
67
+
68
+ if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
69
+ out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
70
+ out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)
71
+
72
+ out_image = torch.from_numpy(out_image).unsqueeze(0).to(device)
73
+ out_mask = torch.from_numpy(out_mask).unsqueeze(0).to(device)
74
+
75
+ out_mask = (out_mask > 0) * 1
76
+
77
+ return out_image, out_mask
78
+
79
+
80
+ def open_image(image):
81
+ if isinstance(image, str):
82
+ if image.startswith("http://") or image.startswith("https://"):
83
+ image = Image.open(io.BytesIO(requests.get(image).content))
84
+ else:
85
+ image = Image.open(image)
86
+ return image
87
+
88
+
89
+
90
+ sess_options = onnxruntime.SessionOptions()
91
+ model = onnxruntime.InferenceSession('models/lama_manga.onnx', sess_options=sess_options)
92
+
93
+ image_url = "https://huggingface.co/Carve/LaMa-ONNX/resolve/main/image.jpg" # @param {type:"string"}
94
+ mask_url = "https://huggingface.co/Carve/LaMa-ONNX/resolve/main/mask.png" # @param {type:"string"}
95
+
96
+ image = open_image(image_url).resize((512, 512))
97
+ mask = open_image(mask_url).convert("L").resize((512, 512))
98
+
99
+ image, mask = prepare_img_and_mask(image, mask, 'cpu')
100
+ # Run the model
101
+ outputs = model.run(None,
102
+ {'image': image.numpy().astype(np.float32),
103
+ 'mask': mask.numpy().astype(np.float32)})
104
+
105
+
106
+ output = outputs[0][0] * 256
107
+ # Postprocess the outputs
108
+ output = output.transpose(1, 2, 0)
109
+ output = output.astype(np.uint8)
110
+ output = Image.fromarray(output)
111
+ output.show()