pilotj commited on
Commit
93bd096
·
verified ·
1 Parent(s): 4da67c1

Create image_utils.py

Browse files
Files changed (1) hide show
  1. inference/utils/image_utils.py +16 -0
inference/utils/image_utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+
5
+ def preprocess_controlnet_image(image: Image.Image, device, dtype=torch.float32):
6
+ image = image.resize((512, 512))
7
+ img_array = np.array(image).astype(np.float32) / 255.0
8
+
9
+ if img_array.ndim == 2:
10
+ img_array = img_array[None, None, :, :]
11
+ elif img_array.shape[2] == 3:
12
+ img_array = img_array.transpose(2, 0, 1)[None, :, :, :]
13
+ else:
14
+ raise ValueError("Unexpected image shape.")
15
+
16
+ return torch.tensor(img_array, device=device, dtype=dtype)