Bobby commited on
Commit
d5d3e44
·
1 Parent(s): 883eb2e

cleaned up preprocess

Browse files
Files changed (1) hide show
  1. preprocess.py +37 -40
preprocess.py CHANGED
@@ -1,30 +1,28 @@
1
- import gc
2
-
3
- import numpy as np
4
  import PIL.Image
5
- import torch
6
  from controlnet_aux import NormalBaeDetector#, CannyDetector
7
 
8
- from controlnet_aux.util import HWC3
9
- import cv2
10
  # from cv_utils import resize_image
11
 
12
  class Preprocessor:
13
  MODEL_ID = "lllyasviel/Annotators"
14
 
15
- def resize_image(input_image, resolution, interpolation=None):
16
- H, W, C = input_image.shape
17
- H = float(H)
18
- W = float(W)
19
- k = float(resolution) / max(H, W)
20
- H *= k
21
- W *= k
22
- H = int(np.round(H / 64.0)) * 64
23
- W = int(np.round(W / 64.0)) * 64
24
- if interpolation is None:
25
- interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
26
- img = cv2.resize(input_image, (W, H), interpolation=interpolation)
27
- return img
28
 
29
 
30
  def __init__(self):
@@ -33,7 +31,6 @@ class Preprocessor:
33
 
34
  def load(self, name: str) -> None:
35
  if name == self.name:
36
- print("NormalBae already loaded")
37
  return
38
  elif name == "NormalBae":
39
  print("Loading NormalBae")
@@ -48,23 +45,23 @@ class Preprocessor:
48
  self.name = name
49
 
50
  def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
51
- if self.name == "Canny":
52
- if "detect_resolution" in kwargs:
53
- detect_resolution = kwargs.pop("detect_resolution")
54
- image = np.array(image)
55
- image = HWC3(image)
56
- image = resize_image(image, resolution=detect_resolution)
57
- image = self.model(image, **kwargs)
58
- return PIL.Image.fromarray(image)
59
- elif self.name == "Midas":
60
- detect_resolution = kwargs.pop("detect_resolution", 512)
61
- image_resolution = kwargs.pop("image_resolution", 512)
62
- image = np.array(image)
63
- image = HWC3(image)
64
- image = resize_image(image, resolution=detect_resolution)
65
- image = self.model(image, **kwargs)
66
- image = HWC3(image)
67
- image = resize_image(image, resolution=image_resolution)
68
- return PIL.Image.fromarray(image)
69
- else:
70
- return self.model(image, **kwargs)
 
1
+ # import numpy as np
 
 
2
  import PIL.Image
3
+ # import torch
4
  from controlnet_aux import NormalBaeDetector#, CannyDetector
5
 
6
+ # from controlnet_aux.util import HWC3
7
+ # import cv2
8
  # from cv_utils import resize_image
9
 
10
  class Preprocessor:
11
  MODEL_ID = "lllyasviel/Annotators"
12
 
13
+ # def resize_image(input_image, resolution, interpolation=None):
14
+ # H, W, C = input_image.shape
15
+ # H = float(H)
16
+ # W = float(W)
17
+ # k = float(resolution) / max(H, W)
18
+ # H *= k
19
+ # W *= k
20
+ # H = int(np.round(H / 64.0)) * 64
21
+ # W = int(np.round(W / 64.0)) * 64
22
+ # if interpolation is None:
23
+ # interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
24
+ # img = cv2.resize(input_image, (W, H), interpolation=interpolation)
25
+ # return img
26
 
27
 
28
  def __init__(self):
 
31
 
32
  def load(self, name: str) -> None:
33
  if name == self.name:
 
34
  return
35
  elif name == "NormalBae":
36
  print("Loading NormalBae")
 
45
  self.name = name
46
 
47
  def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
48
+ # if self.name == "Canny":
49
+ # if "detect_resolution" in kwargs:
50
+ # detect_resolution = kwargs.pop("detect_resolution")
51
+ # image = np.array(image)
52
+ # image = HWC3(image)
53
+ # image = resize_image(image, resolution=detect_resolution)
54
+ # image = self.model(image, **kwargs)
55
+ # return PIL.Image.fromarray(image)
56
+ # elif self.name == "Midas":
57
+ # detect_resolution = kwargs.pop("detect_resolution", 512)
58
+ # image_resolution = kwargs.pop("image_resolution", 512)
59
+ # image = np.array(image)
60
+ # image = HWC3(image)
61
+ # image = resize_image(image, resolution=detect_resolution)
62
+ # image = self.model(image, **kwargs)
63
+ # image = HWC3(image)
64
+ # image = resize_image(image, resolution=image_resolution)
65
+ # return PIL.Image.fromarray(image)
66
+ # else:
67
+ return self.model(image, **kwargs)