# inference.py
import os, json, time
from typing import List, Dict, Iterable, Any, Optional
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import tensorflow as tf
# ---------- Label utils ----------
# Fallback class names, returned by _load_labels() when neither
# models/class_indices.json nor models/idx2class.json yields a label list.
_DEFAULT_LABELS = ["Heart", "Oblong", "Oval", "Round", "Square"]
def _read_label_json(path: str) -> Any:
    """Return the parsed JSON stored at *path*, or ``None`` if it cannot be used.

    A missing file is the normal "try the next source" case and is silent.
    Any other read or parse failure is reported on stdout so that a broken
    label file does not go unnoticed.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[LABEL] cannot read {path}: {e}")
        return None


def _load_labels() -> List[str]:
    """Load the class labels, ordered by model output index.

    Sources are tried in priority order:

    1. ``models/class_indices.json`` -- ``{"Label": idx}`` or a plain list.
    2. ``models/idx2class.json`` -- ``{"0": "Label"}`` or a plain list.
    3. ``_DEFAULT_LABELS``.

    A source that is missing, unreadable, or has an unexpected structure is
    skipped and the next one is tried; this function never raises because of
    a bad label file.

    Returns:
        A new list of label strings.
    """
    p_ci = os.path.join("models", "class_indices.json")  # {"Label": idx}
    p_i2c = os.path.join("models", "idx2class.json")  # {"0": "Label"} or ["Label", ...]

    # 1) class_indices.json (label -> idx)
    ci = _read_label_json(p_ci)
    if isinstance(ci, dict):
        try:
            labels = [k for k, _ in sorted(ci.items(), key=lambda kv: kv[1])]
        except TypeError as e:
            # Index values that cannot be ordered against each other.
            print(f"[LABEL] ignoring {p_ci}: indices are not sortable ({e})")
        else:
            print("[LABEL] from class_indices.json ->", labels)
            return labels
    elif isinstance(ci, list):
        print("[LABEL] from class_indices.json (list) ->", ci)
        return list(ci)

    # 2) idx2class.json (idx -> label)
    i2c = _read_label_json(p_i2c)
    if isinstance(i2c, dict):
        try:
            # JSON object keys are always strings, so only str(i) can match.
            labels = [i2c[str(i)] for i in range(len(i2c))]
        except KeyError as e:
            print(f"[LABEL] ignoring {p_i2c}: keys are not 0..{len(i2c) - 1} (missing {e})")
        else:
            print("[LABEL] from idx2class.json (dict) ->", labels)
            return labels
    elif isinstance(i2c, list):
        print("[LABEL] from idx2class.json (list) ->", i2c)
        return list(i2c)

    # 3) built-in fallback
    print("[LABEL] fallback default ->", _DEFAULT_LABELS)
    return list(_DEFAULT_LABELS)
def _generate_config_if_missing(model: "tf.keras.Model", labels: List[str], path="config.json"):
    """Write a small JSON config describing the model, unless *path* already exists.

    The config records the architecture name, the image size taken from
    ``model.input_shape[1]``, the number of labels, and the ``id2label`` /
    ``label2id`` mappings derived from *labels*.

    Args:
        model: Object exposing ``input_shape``; index 1 is used as the image size.
        labels: Class names ordered by output index.
        path: Destination file. An existing file is never overwritten.

    Raises:
        AssertionError: If ``model.input_shape[1]`` is not a positive integer.
    """
    if os.path.exists(path):
        return

    ishape = model.input_shape
    try:
        h = int(ishape[1])
    except (TypeError, ValueError, LookupError) as e:
        raise AssertionError(f"Input shape tidak valid untuk config: {ishape}") from e
    if h <= 0:
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        # AssertionError is kept so the exception type seen by callers is unchanged.
        raise AssertionError(f"Input shape tidak valid untuk config: {ishape}")

    cfg = {
        "architectures": ["EfficientNetB4"],
        "image_size": h,
        "num_labels": len(labels),
        "id2label": {str(i): lbl for i, lbl in enumerate(labels)},
        "label2id": {lbl: i for i, lbl in enumerate(labels)},
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")
# ---------- Model wrapper ----------
class FaceShapeModel:
    """Keras image classifier bundled with its labels and input preprocessing.

    Attributes set by ``__init__``:
        labels: Class names, adjusted to have one entry per model output.
        model: The loaded Keras model (loaded with ``compile=False``).
        img_size: ``int(model.input_shape[1])``; images are resized to
            ``img_size x img_size``, i.e. a square input is assumed.
        external_rescale: When true, pixels are divided by 255 before inference.
    """

    def __init__(self, model_path="models/model.keras"):
        """Load labels and the model, derive preprocessing settings, and warm up.

        Args:
            model_path: Model file, resolved against the current working directory.
        """
        self.labels: List[str] = _load_labels()
        full_path = os.path.join(os.getcwd(), model_path)
        print(f"[LOAD] {full_path}")
        self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False)

        # input size (H=W)
        ishape = self.model.input_shape
        self.img_size: int = int(ishape[1])
        print(f"[MODEL] input img_size = {self.img_size}")

        # Does the model already contain its own preprocessing layers?
        # Only the names of the first 12 layers are inspected.
        # NOTE(review): the substring "normalization" also matches names such as
        # "batch_normalization" -- confirm that this is the intended heuristic.
        names_lower = [l.name.lower() for l in self.model.layers[:12]]
        has_internal_pp = any(("rescaling" in n) or ("normalization" in n) for n in names_lower)
        self.external_rescale: bool = not has_internal_pp
        print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")

        # Keep the number of labels in sync with the number of model outputs.
        num_out = int(self.model.output_shape[-1])
        self.labels = self._align_labels(self.labels, num_out)

        _generate_config_if_missing(self.model, self.labels)

        # Warm-up (best effort). It goes through `predict`, the same entry point
        # used by predict_dict, so the function that `predict` builds on first
        # use is created here instead of during the first real request.
        try:
            warm = np.zeros((1, self.img_size, self.img_size, 3), dtype=np.float32)
            self.model.predict(warm, verbose=0)
        except Exception as e:
            print("[WARN] warmup failed:", e)

    @staticmethod
    def _align_labels(labels: List[str], num_out: int) -> List[str]:
        """Return a copy of *labels* with exactly *num_out* entries.

        Extra labels are dropped from the end; missing ones are filled with
        generic ``class_<index>`` names. A warning is printed on any mismatch.
        """
        if len(labels) == num_out:
            return list(labels)
        print(f"[WARN] labels({len(labels)}) != model_out({num_out}) -> menyesuaikan.")
        if len(labels) > num_out:
            return list(labels[:num_out])
        # pad with generic labels when there are too few
        return list(labels) + [f"class_{i}" for i in range(len(labels), num_out)]

    # ---- preprocessing ----
    @staticmethod
    def _to_rgb(img: "Image.Image") -> "Image.Image":
        """Return *img* unchanged if it is already RGB, else an RGB conversion."""
        return img if img.mode == "RGB" else img.convert("RGB")

    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Convert *img* to a float32 batch of one, shaped ``(1, H, W, 3)``."""
        img = self._to_rgb(img).resize((self.img_size, self.img_size))
        x = np.asarray(img, dtype=np.float32)
        if self.external_rescale:
            x = x / 255.0
        return np.expand_dims(x, 0)  # (1,H,W,3)

    # ---- predict ----
    def predict_dict(self, img: "Image.Image") -> Dict[str, float]:
        """Return dict {label: prob} for gr.Label."""
        t0 = time.perf_counter()
        probs = self.model.predict(self._preprocess(img), verbose=0)[0]  # (C,)
        # make sure the values are plain Python floats
        out = {lbl: float(p) for lbl, p in zip(self.labels, probs)}
        dt = (time.perf_counter() - t0) * 1000.0
        print(f"[INF] {len(self.labels)}-class in {dt:.1f} ms")
        return out
# Module-level singleton: constructed once, when this module is imported,
# and shared by predict() and predict_batch().
_MODEL = FaceShapeModel()
# ---------- Public API ----------
def predict(image: Image.Image) -> Dict[str, float]:
    """Classify a single image and return a ``{label: probability}`` mapping.

    A missing image (``None``) yields ``{"Error": 1.0}``.
    """
    # NOTE: gr.Label needs a label->float mapping; never return a string or list.
    return {"Error": 1.0} if image is None else _MODEL.predict_dict(image)
def _as_pil(x: Any) -> Optional["Image.Image"]:
    """Coerce one batch item to a PIL image, or return ``None`` if that fails.

    ``None`` stays ``None`` and an existing PIL image is returned as is.
    Anything else (path, bytes path, ``os.PathLike``, file-like object, ...)
    is handed to ``Image.open`` and converted to RGB.
    """
    if x is None:
        return None
    if isinstance(x, Image.Image):
        return x
    try:
        # The context manager releases the file Pillow opened for this item;
        # convert() returns a separate image, so it stays usable afterwards.
        with Image.open(x) as im:
            return im.convert("RGB")
    except Exception as e:
        # Best effort: a bad item becomes an error entry, not a failed batch.
        print(f"[WARN] batch item of type {type(x).__name__} is not a readable image: {e}")
        return None


def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
    """Return one ``{label: prob}`` dict per input item; suitable for gr.JSON.

    Args:
        images: Items that are PIL images or anything ``Image.open`` accepts.
            ``None`` or an empty input gives an empty list.

    Returns:
        A list aligned with *images*. Items that are ``None`` or cannot be
        opened as an image are reported as ``{"Error": 1.0}``.
    """
    results: List[Dict[str, float]] = []
    for x in (images or []):
        im = _as_pil(x)
        results.append({"Error": 1.0} if im is None else _MODEL.predict_dict(im))
    return results
# Names exported by a star-import of this module: the two prediction entry points.
__all__ = ["predict", "predict_batch"]
|