Spaces:
Build error
Build error
DimasMP3 commited on
Commit ·
b02d758
1
Parent(s): d2f7145
add guard
Browse files- inference.py +80 -33
inference.py
CHANGED
|
@@ -2,39 +2,56 @@
|
|
| 2 |
import os, json, time
|
| 3 |
from typing import List, Dict, Iterable, Any, Optional
|
| 4 |
import numpy as np
|
| 5 |
-
from PIL import Image
|
|
|
|
| 6 |
import tensorflow as tf
|
| 7 |
|
| 8 |
# ---------- Label utils ----------
|
| 9 |
-
_DEFAULT_LABELS = ["Heart","Oblong","Oval","Round","Square"]
|
| 10 |
|
| 11 |
def _load_labels() -> List[str]:
|
|
|
|
| 12 |
p_ci = os.path.join("models", "class_indices.json") # {"Label": idx}
|
| 13 |
-
p_i2c = os.path.join("models", "idx2class.json") # {"0":"Label"}
|
|
|
|
| 14 |
try:
|
| 15 |
with open(p_ci, "r") as f:
|
| 16 |
ci = json.load(f)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
except Exception:
|
| 21 |
pass
|
|
|
|
| 22 |
try:
|
| 23 |
with open(p_i2c, "r") as f:
|
| 24 |
i2c = json.load(f)
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
except Exception:
|
| 29 |
pass
|
| 30 |
print("[LABEL] fallback default ->", _DEFAULT_LABELS)
|
| 31 |
-
return _DEFAULT_LABELS
|
| 32 |
|
| 33 |
def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path="config.json"):
|
| 34 |
-
|
|
|
|
|
|
|
| 35 |
ishape = model.input_shape
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
| 38 |
cfg = {
|
| 39 |
"architectures": ["EfficientNetB4"],
|
| 40 |
"image_size": h,
|
|
@@ -42,75 +59,105 @@ def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path="
|
|
| 42 |
"id2label": {str(i): lbl for i, lbl in enumerate(labels)},
|
| 43 |
"label2id": {lbl: i for i, lbl in enumerate(labels)},
|
| 44 |
}
|
| 45 |
-
with open(path, "w") as f:
|
|
|
|
| 46 |
print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")
|
| 47 |
|
| 48 |
# ---------- Model wrapper ----------
|
| 49 |
class FaceShapeModel:
|
| 50 |
def __init__(self, model_path="models/model.keras"):
|
| 51 |
-
self.labels = _load_labels()
|
|
|
|
| 52 |
full_path = os.path.join(os.getcwd(), model_path)
|
| 53 |
print(f"[LOAD] {full_path}")
|
| 54 |
-
self.model = tf.keras.models.load_model(full_path, compile=False)
|
| 55 |
|
|
|
|
| 56 |
ishape = self.model.input_shape
|
| 57 |
-
self.img_size = int(ishape[1])
|
| 58 |
print(f"[MODEL] input img_size = {self.img_size}")
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
_generate_config_if_missing(self.model, self.labels)
|
| 66 |
|
|
|
|
| 67 |
try:
|
| 68 |
_ = self.model(tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32))
|
| 69 |
except Exception as e:
|
| 70 |
print("[WARN] warmup failed:", e)
|
| 71 |
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
return img if img.mode == "RGB" else img.convert("RGB")
|
| 74 |
|
| 75 |
def _preprocess(self, img: Image.Image) -> np.ndarray:
|
| 76 |
img = self._to_rgb(img).resize((self.img_size, self.img_size))
|
| 77 |
x = np.asarray(img, dtype=np.float32)
|
| 78 |
-
if self.external_rescale:
|
|
|
|
| 79 |
return np.expand_dims(x, 0) # (1,H,W,3)
|
| 80 |
|
|
|
|
| 81 |
def predict_dict(self, img: Image.Image) -> Dict[str, float]:
|
|
|
|
| 82 |
t0 = time.perf_counter()
|
| 83 |
-
|
|
|
|
|
|
|
| 84 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 85 |
print(f"[INF] {len(self.labels)}-class in {dt:.1f} ms")
|
| 86 |
-
return
|
| 87 |
|
| 88 |
# singleton
|
| 89 |
_MODEL = FaceShapeModel()
|
| 90 |
|
| 91 |
# ---------- Public API ----------
|
| 92 |
def predict(image: Image.Image) -> Dict[str, float]:
|
|
|
|
| 93 |
if image is None:
|
| 94 |
-
return {"Error":
|
| 95 |
return _MODEL.predict_dict(image)
|
| 96 |
|
| 97 |
def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
|
|
|
|
| 98 |
from PIL import Image as _PILImage
|
| 99 |
import os as _os
|
| 100 |
results: List[Dict[str, float]] = []
|
| 101 |
|
| 102 |
def _as_pil(x: Any) -> Optional[_PILImage.Image]:
|
| 103 |
-
if x is None:
|
| 104 |
-
|
|
|
|
|
|
|
| 105 |
if isinstance(x, (str, bytes, _os.PathLike)):
|
| 106 |
-
try:
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
for x in (images or []):
|
| 112 |
im = _as_pil(x)
|
| 113 |
-
results.append({"Error":
|
| 114 |
return results
|
| 115 |
|
| 116 |
__all__ = ["predict", "predict_batch"]
|
|
|
|
| 2 |
import os, json, time
|
| 3 |
from typing import List, Dict, Iterable, Any, Optional
|
| 4 |
import numpy as np
|
| 5 |
+
from PIL import Image, ImageFile
|
| 6 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 7 |
import tensorflow as tf
|
| 8 |
|
| 9 |
# ---------- Label utils ----------
|
| 10 |
+
_DEFAULT_LABELS = ["Heart", "Oblong", "Oval", "Round", "Square"]
|
| 11 |
|
| 12 |
def _load_labels() -> List[str]:
|
| 13 |
+
"""Prioritas: models/class_indices.json -> models/idx2class.json -> default."""
|
| 14 |
p_ci = os.path.join("models", "class_indices.json") # {"Label": idx}
|
| 15 |
+
p_i2c = os.path.join("models", "idx2class.json") # {"0": "Label"} atau ["Label", ...]
|
| 16 |
+
# 1) class_indices.json (label->idx)
|
| 17 |
try:
|
| 18 |
with open(p_ci, "r") as f:
|
| 19 |
ci = json.load(f)
|
| 20 |
+
if isinstance(ci, dict):
|
| 21 |
+
labels = [k for k, _ in sorted(ci.items(), key=lambda kv: kv[1])]
|
| 22 |
+
print("[LABEL] from class_indices.json ->", labels)
|
| 23 |
+
return labels
|
| 24 |
+
if isinstance(ci, list):
|
| 25 |
+
print("[LABEL] from class_indices.json (list) ->", ci)
|
| 26 |
+
return list(ci)
|
| 27 |
except Exception:
|
| 28 |
pass
|
| 29 |
+
# 2) idx2class.json (idx->label)
|
| 30 |
try:
|
| 31 |
with open(p_i2c, "r") as f:
|
| 32 |
i2c = json.load(f)
|
| 33 |
+
if isinstance(i2c, dict):
|
| 34 |
+
n = len(i2c)
|
| 35 |
+
labels = [i2c[str(i)] if str(i) in i2c else i2c[i] for i in range(n)]
|
| 36 |
+
print("[LABEL] from idx2class.json (dict) ->", labels)
|
| 37 |
+
return labels
|
| 38 |
+
if isinstance(i2c, list):
|
| 39 |
+
print("[LABEL] from idx2class.json (list) ->", i2c)
|
| 40 |
+
return list(i2c)
|
| 41 |
except Exception:
|
| 42 |
pass
|
| 43 |
print("[LABEL] fallback default ->", _DEFAULT_LABELS)
|
| 44 |
+
return list(_DEFAULT_LABELS)
|
| 45 |
|
| 46 |
def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path="config.json"):
|
| 47 |
+
"""Auto-tulis config.json jika belum ada."""
|
| 48 |
+
if os.path.exists(path):
|
| 49 |
+
return
|
| 50 |
ishape = model.input_shape
|
| 51 |
+
try:
|
| 52 |
+
h = int(ishape[1]); assert h > 0
|
| 53 |
+
except Exception as e:
|
| 54 |
+
raise AssertionError(f"Input shape tidak valid untuk config: {ishape}") from e
|
| 55 |
cfg = {
|
| 56 |
"architectures": ["EfficientNetB4"],
|
| 57 |
"image_size": h,
|
|
|
|
| 59 |
"id2label": {str(i): lbl for i, lbl in enumerate(labels)},
|
| 60 |
"label2id": {lbl: i for i, lbl in enumerate(labels)},
|
| 61 |
}
|
| 62 |
+
with open(path, "w") as f:
|
| 63 |
+
json.dump(cfg, f, indent=2)
|
| 64 |
print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")
|
| 65 |
|
| 66 |
# ---------- Model wrapper ----------
|
| 67 |
class FaceShapeModel:
|
| 68 |
def __init__(self, model_path="models/model.keras"):
|
| 69 |
+
self.labels: List[str] = _load_labels()
|
| 70 |
+
|
| 71 |
full_path = os.path.join(os.getcwd(), model_path)
|
| 72 |
print(f"[LOAD] {full_path}")
|
| 73 |
+
self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False)
|
| 74 |
|
| 75 |
+
# input size (H=W)
|
| 76 |
ishape = self.model.input_shape
|
| 77 |
+
self.img_size: int = int(ishape[1])
|
| 78 |
print(f"[MODEL] input img_size = {self.img_size}")
|
| 79 |
|
| 80 |
+
# apakah model sudah termasuk preprocessing internal?
|
| 81 |
+
names_lower = [l.name.lower() for l in self.model.layers[:12]]
|
| 82 |
+
has_internal_pp = any(("rescaling" in n) or ("normalization" in n) for n in names_lower)
|
| 83 |
+
self.external_rescale: bool = not has_internal_pp
|
| 84 |
+
print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")
|
| 85 |
+
|
| 86 |
+
# sinkronkan jumlah label dengan jumlah output model
|
| 87 |
+
num_out = int(self.model.output_shape[-1])
|
| 88 |
+
if len(self.labels) != num_out:
|
| 89 |
+
print(f"[WARN] labels({len(self.labels)}) != model_out({num_out}) -> menyesuaikan.")
|
| 90 |
+
if len(self.labels) >= num_out:
|
| 91 |
+
self.labels = self.labels[:num_out]
|
| 92 |
+
else:
|
| 93 |
+
# pad label generik jika kurang
|
| 94 |
+
self.labels += [f"class_{i}" for i in range(len(self.labels), num_out)]
|
| 95 |
|
| 96 |
_generate_config_if_missing(self.model, self.labels)
|
| 97 |
|
| 98 |
+
# warmup (optional)
|
| 99 |
try:
|
| 100 |
_ = self.model(tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32))
|
| 101 |
except Exception as e:
|
| 102 |
print("[WARN] warmup failed:", e)
|
| 103 |
|
| 104 |
+
# ---- preprocessing ----
|
| 105 |
+
@staticmethod
|
| 106 |
+
def _to_rgb(img: Image.Image) -> Image.Image:
|
| 107 |
return img if img.mode == "RGB" else img.convert("RGB")
|
| 108 |
|
| 109 |
def _preprocess(self, img: Image.Image) -> np.ndarray:
|
| 110 |
img = self._to_rgb(img).resize((self.img_size, self.img_size))
|
| 111 |
x = np.asarray(img, dtype=np.float32)
|
| 112 |
+
if self.external_rescale:
|
| 113 |
+
x = x / 255.0
|
| 114 |
return np.expand_dims(x, 0) # (1,H,W,3)
|
| 115 |
|
| 116 |
+
# ---- predict ----
|
| 117 |
def predict_dict(self, img: Image.Image) -> Dict[str, float]:
|
| 118 |
+
"""Return dict {label: prob} untuk gr.Label."""
|
| 119 |
t0 = time.perf_counter()
|
| 120 |
+
probs = self.model.predict(self._preprocess(img), verbose=0)[0] # (C,)
|
| 121 |
+
# pastikan float murni
|
| 122 |
+
out = {lbl: float(p) for lbl, p in zip(self.labels, probs)}
|
| 123 |
dt = (time.perf_counter() - t0) * 1000.0
|
| 124 |
print(f"[INF] {len(self.labels)}-class in {dt:.1f} ms")
|
| 125 |
+
return out
|
| 126 |
|
| 127 |
# singleton
|
| 128 |
_MODEL = FaceShapeModel()
|
| 129 |
|
| 130 |
# ---------- Public API ----------
|
| 131 |
def predict(image: Image.Image) -> Dict[str, float]:
|
| 132 |
+
# NOTE: gr.Label butuh mapping label->float; jangan kirim string/list.
|
| 133 |
if image is None:
|
| 134 |
+
return {"Error": 1.0}
|
| 135 |
return _MODEL.predict_dict(image)
|
| 136 |
|
| 137 |
def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
|
| 138 |
+
"""Kembalikan list of dict label->prob; cocok untuk gr.JSON di tab Batch."""
|
| 139 |
from PIL import Image as _PILImage
|
| 140 |
import os as _os
|
| 141 |
results: List[Dict[str, float]] = []
|
| 142 |
|
| 143 |
def _as_pil(x: Any) -> Optional[_PILImage.Image]:
|
| 144 |
+
if x is None:
|
| 145 |
+
return None
|
| 146 |
+
if isinstance(x, _PILImage.Image):
|
| 147 |
+
return x
|
| 148 |
if isinstance(x, (str, bytes, _os.PathLike)):
|
| 149 |
+
try:
|
| 150 |
+
return _PILImage.open(x).convert("RGB")
|
| 151 |
+
except Exception:
|
| 152 |
+
return None
|
| 153 |
+
try:
|
| 154 |
+
return _PILImage.open(x).convert("RGB")
|
| 155 |
+
except Exception:
|
| 156 |
+
return None
|
| 157 |
|
| 158 |
for x in (images or []):
|
| 159 |
im = _as_pil(x)
|
| 160 |
+
results.append({"Error": 1.0} if im is None else _MODEL.predict_dict(im))
|
| 161 |
return results
|
| 162 |
|
| 163 |
__all__ = ["predict", "predict_batch"]
|