File size: 2,772 Bytes
ed73811
 
 
 
47d1a17
ed73811
fdf1e29
ed73811
 
 
 
 
 
47d1a17
 
ed73811
47d1a17
 
ed73811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47d1a17
ed73811
 
 
 
 
 
 
 
 
47d1a17
ed73811
47d1a17
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import base64
import io
import os
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from PIL import Image

# Load the segmentation model once at import time.
# NOTE(review): the original statement was cut off right after `os.path.join`
# (and spelled the loader `load.model`). The file name is assumed to match the
# "tf_model.h5" that PreTrainedPipeline loads, and the directory is assumed to
# be the one containing this script -- confirm both before deploying.
model = tf.keras.models.load_model(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "tf_model.h5")
)

# Gradio image components for the demo. These use the legacy `gr.inputs` /
# `gr.outputs` namespaces, matching the original; the original's `gr.output`
# (singular) is not a Gradio module.
inputs = gr.inputs.Image()
output = gr.outputs.Image()


def predict(image_input):
    """Convert *image_input* to a NumPy array; inference is not implemented yet.

    Args:
        image_input: The image to process, in any form ``np.array`` accepts.

    Returns:
        None. The original body ended in a placeholder ``pass``, so no
        prediction is produced here.
    """
    # Convert the argument itself; the original converted the module-level
    # Gradio `inputs` component instead of the image passed in.
    img = np.array(image_input)  # noqa: F841 - kept for the pending inference step
    # TODO: run the model on `img` and return the predicted image.
    return None


class PreTrainedPipeline:
    """Keras image-segmentation pipeline that returns one PNG mask per class."""

    def __init__(self, path: str):
        """Load the Keras model stored as ``tf_model.h5`` inside *path*.

        Args:
            path: Directory that contains ``tf_model.h5``.
        """
        # `tf.keras`, not bare `keras`: this file only imports `tensorflow as tf`.
        self.model = tf.keras.models.load_model(os.path.join(path, "tf_model.h5"))

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Segment *inputs* and return one entry per class of the model output.

        Args:
            inputs: The image to segment; it is converted with ``np.array``.

        Returns:
            A list with one dict per class index ``cls`` along the last axis
            of the model output. ``"label"`` is ``f"LABEL_{cls}"``, ``"mask"``
            is a base64-encoded PNG that is white where ``cls`` wins the
            per-pixel argmax and black elsewhere, and ``"score"`` is the
            constant ``1.0``.
        """
        # convert img to numpy array, resize and normalize to make the prediction
        img = np.array(inputs)
        im = tf.image.resize(img, (128, 128))
        im = tf.cast(im, tf.float32) / 255.0
        pred_mask = self.model.predict(im[tf.newaxis, ...])

        # Best class for each pixel of the single image in the batch, as a 2-D
        # array. Taking the whole map here avoids deriving the row count from
        # the length of one row, which was only right for square outputs.
        class_map = np.argmax(pred_mask, axis=-1)[0]

        # widget needs the below format, for each class we return label and mask string
        labels = []
        for cls in range(pred_mask.shape[-1]):
            # One vectorised comparison replaces the per-pixel Python loop.
            # uint8 is required: 255 does not fit in the int8 used originally.
            mask = np.where(class_map == cls, 255, 0).astype(np.uint8)
            labels.append({
                "label": f"LABEL_{cls}",
                "mask": self._encode_png(mask),
                "score": 1.0,
            })
        return labels

    @staticmethod
    def _encode_png(mask: np.ndarray) -> str:
        """Return the 2-D uint8 array *mask* as a base64-encoded PNG string."""
        # A 2-D uint8 array maps to Pillow's 8-bit grayscale mode "L".
        with io.BytesIO() as out:
            Image.fromarray(mask).save(out, format="PNG")
            return base64.b64encode(out.getvalue()).decode("utf-8")