davelop commited on
Commit
7443712
·
verified ·
1 Parent(s): a1c1ebb

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +103 -0
inference.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import cv2
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from sklearn.cluster import KMeans
6
+ from collections import Counter
7
+ from scipy.spatial import KDTree
8
+ from webcolors import hex_to_rgb, rgb_to_hex
9
+ from PIL import Image
10
+
11
+
12
+ model = tf.keras.models.load_model("model.h5")
13
+
14
+ classes = [
15
+ "background", "skin", "left eyebrow", "right eyebrow",
16
+ "left eye", "right eye", "nose", "upper lip", "inner mouth",
17
+ "lower lip", "hair"
18
+ ]
19
+
20
+
21
+ def face_skin_extract(pred, image_x):
22
+
23
+ output = np.zeros_like(image_x, dtype=np.uint8)
24
+ mask = (pred == 1)
25
+ output[mask] = image_x[mask]
26
+ return output
27
+
28
+ def extract_dom_color_kmeans(img):
29
+
30
+ mask = ~np.all(img == [0, 0, 0], axis=-1)
31
+ non_black_pixels = img[mask]
32
+
33
+ k_cluster = KMeans(n_clusters=3, n_init="auto")
34
+ k_cluster.fit(non_black_pixels)
35
+
36
+ n_pixels = len(k_cluster.labels_)
37
+ counter = Counter(k_cluster.labels_)
38
+ perc = {i: np.round(counter[i] / n_pixels, 2) for i in counter}
39
+
40
+ val = list(perc.values())
41
+ val.sort()
42
+ res = val[-1]
43
+
44
+ dominant_cluster_index = list(perc.keys())[list(perc.values()).index(res)]
45
+ rgb_list = k_cluster.cluster_centers_[dominant_cluster_index]
46
+
47
+ return rgb_list
48
+
49
+ def closest_tone_match(rgb_tuple):
50
+ skin_tones = {
51
+ 'Monk 10': '#292420',
52
+ 'Monk 9': '#3a312a',
53
+ 'Monk 8': '#604134',
54
+ 'Monk 7': '#825c43',
55
+ 'Monk 6': '#a07e56',
56
+ 'Monk 5': '#d7bd96',
57
+ 'Monk 4': '#eadaba',
58
+ 'Monk 3': '#f7ead0',
59
+ 'Monk 2': '#f3e7db',
60
+ 'Monk 1': '#f6ede4'
61
+ }
62
+
63
+ rgb_values = []
64
+ names = []
65
+ for monk in skin_tones:
66
+ names.append(monk)
67
+ rgb_values.append(hex_to_rgb(skin_tones[monk]))
68
+
69
+ kdt_db = KDTree(rgb_values)
70
+ distance, index = kdt_db.query(rgb_tuple)
71
+ monk_hex = skin_tones[names[index]]
72
+ derived_hex = rgb_to_hex((int(rgb_tuple[0]), int(rgb_tuple[1]), int(rgb_tuple[2])))
73
+ return names[index], monk_hex, derived_hex
74
+
75
+
76
+ def inference(inputs: bytes) -> dict:
77
+
78
+ nparr = np.frombuffer(inputs, np.uint8)
79
+ image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
80
+
81
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
82
+
83
+ image_x = cv2.resize(image, (512, 512))
84
+ image_norm = image_x / 255.0
85
+ image_norm = np.expand_dims(image_norm, axis=0).astype(np.float32)
86
+
87
+ pred = model.predict(image_norm)[0]
88
+ pred = np.argmax(pred, axis=-1).astype(np.int32)
89
+
90
+ face_skin = face_skin_extract(pred, image_x)
91
+
92
+ dominant_color_rgb = extract_dom_color_kmeans(face_skin) # This is an RGB tuple (floats)
93
+
94
+ monk_tone, monk_hex, derived_hex = closest_tone_match(
95
+ (dominant_color_rgb[0], dominant_color_rgb[1], dominant_color_rgb[2])
96
+ )
97
+
98
+ return {
99
+ "derived_hex_code": derived_hex,
100
+ "monk_hex": monk_hex,
101
+ "monk_skin_tone": monk_tone,
102
+ "dominant_rgb": dominant_color_rgb.tolist()
103
+ }