AIOmarRehan commited on
Commit
3929462
·
verified ·
1 Parent(s): 88452e6

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +26 -0
pipeline.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ from tensorflow.keras.models import load_model
4
+
5
+ class Pipeline:
6
+ def __init__(self, model_path="unet_model.h5"):
7
+ self.model = load_model(model_path, compile=False)
8
+
9
+ def preprocess(self, image):
10
+ image = image.convert("L")
11
+ image = image.resize((176, 192)) # width, height
12
+ arr = np.array(image) / 255.0
13
+ if arr.ndim == 2:
14
+ arr = np.expand_dims(arr, axis=-1)
15
+ return np.expand_dims(arr, axis=0)
16
+
17
+ def postprocess(self, pred):
18
+ pred = pred[0]
19
+ if pred.ndim == 3 and pred.shape[-1] == 1:
20
+ pred = np.squeeze(pred, axis=-1)
21
+ return Image.fromarray((pred * 255).astype(np.uint8))
22
+
23
+ def __call__(self, image):
24
+ x = self.preprocess(image)
25
+ pred = self.model.predict(x)
26
+ return self.postprocess(pred)