Create inference.py
Browse files- inference.py +103 -0
inference.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from sklearn.cluster import KMeans
|
| 6 |
+
from collections import Counter
|
| 7 |
+
from scipy.spatial import KDTree
|
| 8 |
+
from webcolors import hex_to_rgb, rgb_to_hex
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
model = tf.keras.models.load_model("model.h5")
|
| 13 |
+
|
| 14 |
+
classes = [
|
| 15 |
+
"background", "skin", "left eyebrow", "right eyebrow",
|
| 16 |
+
"left eye", "right eye", "nose", "upper lip", "inner mouth",
|
| 17 |
+
"lower lip", "hair"
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def face_skin_extract(pred, image_x):
|
| 22 |
+
|
| 23 |
+
output = np.zeros_like(image_x, dtype=np.uint8)
|
| 24 |
+
mask = (pred == 1)
|
| 25 |
+
output[mask] = image_x[mask]
|
| 26 |
+
return output
|
| 27 |
+
|
| 28 |
+
def extract_dom_color_kmeans(img):
|
| 29 |
+
|
| 30 |
+
mask = ~np.all(img == [0, 0, 0], axis=-1)
|
| 31 |
+
non_black_pixels = img[mask]
|
| 32 |
+
|
| 33 |
+
k_cluster = KMeans(n_clusters=3, n_init="auto")
|
| 34 |
+
k_cluster.fit(non_black_pixels)
|
| 35 |
+
|
| 36 |
+
n_pixels = len(k_cluster.labels_)
|
| 37 |
+
counter = Counter(k_cluster.labels_)
|
| 38 |
+
perc = {i: np.round(counter[i] / n_pixels, 2) for i in counter}
|
| 39 |
+
|
| 40 |
+
val = list(perc.values())
|
| 41 |
+
val.sort()
|
| 42 |
+
res = val[-1]
|
| 43 |
+
|
| 44 |
+
dominant_cluster_index = list(perc.keys())[list(perc.values()).index(res)]
|
| 45 |
+
rgb_list = k_cluster.cluster_centers_[dominant_cluster_index]
|
| 46 |
+
|
| 47 |
+
return rgb_list
|
| 48 |
+
|
| 49 |
+
def closest_tone_match(rgb_tuple):
|
| 50 |
+
skin_tones = {
|
| 51 |
+
'Monk 10': '#292420',
|
| 52 |
+
'Monk 9': '#3a312a',
|
| 53 |
+
'Monk 8': '#604134',
|
| 54 |
+
'Monk 7': '#825c43',
|
| 55 |
+
'Monk 6': '#a07e56',
|
| 56 |
+
'Monk 5': '#d7bd96',
|
| 57 |
+
'Monk 4': '#eadaba',
|
| 58 |
+
'Monk 3': '#f7ead0',
|
| 59 |
+
'Monk 2': '#f3e7db',
|
| 60 |
+
'Monk 1': '#f6ede4'
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
rgb_values = []
|
| 64 |
+
names = []
|
| 65 |
+
for monk in skin_tones:
|
| 66 |
+
names.append(monk)
|
| 67 |
+
rgb_values.append(hex_to_rgb(skin_tones[monk]))
|
| 68 |
+
|
| 69 |
+
kdt_db = KDTree(rgb_values)
|
| 70 |
+
distance, index = kdt_db.query(rgb_tuple)
|
| 71 |
+
monk_hex = skin_tones[names[index]]
|
| 72 |
+
derived_hex = rgb_to_hex((int(rgb_tuple[0]), int(rgb_tuple[1]), int(rgb_tuple[2])))
|
| 73 |
+
return names[index], monk_hex, derived_hex
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def inference(inputs: bytes) -> dict:
|
| 77 |
+
|
| 78 |
+
nparr = np.frombuffer(inputs, np.uint8)
|
| 79 |
+
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
| 80 |
+
|
| 81 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 82 |
+
|
| 83 |
+
image_x = cv2.resize(image, (512, 512))
|
| 84 |
+
image_norm = image_x / 255.0
|
| 85 |
+
image_norm = np.expand_dims(image_norm, axis=0).astype(np.float32)
|
| 86 |
+
|
| 87 |
+
pred = model.predict(image_norm)[0]
|
| 88 |
+
pred = np.argmax(pred, axis=-1).astype(np.int32)
|
| 89 |
+
|
| 90 |
+
face_skin = face_skin_extract(pred, image_x)
|
| 91 |
+
|
| 92 |
+
dominant_color_rgb = extract_dom_color_kmeans(face_skin) # This is an RGB tuple (floats)
|
| 93 |
+
|
| 94 |
+
monk_tone, monk_hex, derived_hex = closest_tone_match(
|
| 95 |
+
(dominant_color_rgb[0], dominant_color_rgb[1], dominant_color_rgb[2])
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
return {
|
| 99 |
+
"derived_hex_code": derived_hex,
|
| 100 |
+
"monk_hex": monk_hex,
|
| 101 |
+
"monk_skin_tone": monk_tone,
|
| 102 |
+
"dominant_rgb": dominant_color_rgb.tolist()
|
| 103 |
+
}
|