DimasMP3
add guard
b02d758
# inference.py
import os, json, time
from typing import List, Dict, Iterable, Any, Optional
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import tensorflow as tf
# ---------- Label utils ----------
_DEFAULT_LABELS = ["Heart", "Oblong", "Oval", "Round", "Square"]
def _load_labels() -> List[str]:
"""Prioritas: models/class_indices.json -> models/idx2class.json -> default."""
p_ci = os.path.join("models", "class_indices.json") # {"Label": idx}
p_i2c = os.path.join("models", "idx2class.json") # {"0": "Label"} atau ["Label", ...]
# 1) class_indices.json (label->idx)
try:
with open(p_ci, "r") as f:
ci = json.load(f)
if isinstance(ci, dict):
labels = [k for k, _ in sorted(ci.items(), key=lambda kv: kv[1])]
print("[LABEL] from class_indices.json ->", labels)
return labels
if isinstance(ci, list):
print("[LABEL] from class_indices.json (list) ->", ci)
return list(ci)
except Exception:
pass
# 2) idx2class.json (idx->label)
try:
with open(p_i2c, "r") as f:
i2c = json.load(f)
if isinstance(i2c, dict):
n = len(i2c)
labels = [i2c[str(i)] if str(i) in i2c else i2c[i] for i in range(n)]
print("[LABEL] from idx2class.json (dict) ->", labels)
return labels
if isinstance(i2c, list):
print("[LABEL] from idx2class.json (list) ->", i2c)
return list(i2c)
except Exception:
pass
print("[LABEL] fallback default ->", _DEFAULT_LABELS)
return list(_DEFAULT_LABELS)
def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path="config.json"):
"""Auto-tulis config.json jika belum ada."""
if os.path.exists(path):
return
ishape = model.input_shape
try:
h = int(ishape[1]); assert h > 0
except Exception as e:
raise AssertionError(f"Input shape tidak valid untuk config: {ishape}") from e
cfg = {
"architectures": ["EfficientNetB4"],
"image_size": h,
"num_labels": len(labels),
"id2label": {str(i): lbl for i, lbl in enumerate(labels)},
"label2id": {lbl: i for i, lbl in enumerate(labels)},
}
with open(path, "w") as f:
json.dump(cfg, f, indent=2)
print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")
# ---------- Model wrapper ----------
class FaceShapeModel:
    """Wrap a saved Keras classifier for single-image inference.

    On construction it loads the labels and the model, reads the (square)
    input size, decides whether pixels must be scaled to [0, 1] before being
    fed to the model, aligns the label list with the model's output width,
    writes ``config.json`` if missing, and runs one warm-up call.
    """

    # Layer classes meaning the model scales/normalises raw pixels itself.
    # Matched on the class name rather than the layer name so that
    # BatchNormalization / LayerNormalization layers (default names
    # "batch_normalization" / "layer_normalization") are not mistaken for
    # input preprocessing.
    _PREPROC_LAYER_CLASSES = frozenset({"Rescaling", "Normalization"})

    # Number of leading layers inspected at each nesting level.
    _PREPROC_SEARCH_HEAD = 12

    def __init__(self, model_path="models/model.keras"):
        self.labels: List[str] = _load_labels()
        full_path = os.path.join(os.getcwd(), model_path)
        print(f"[LOAD] {full_path}")
        self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False)

        # Input size; the model is assumed to take square (H == W) images.
        ishape = self.model.input_shape
        self.img_size: int = int(ishape[1])
        print(f"[MODEL] input img_size = {self.img_size}")

        # Does the model already include its own input preprocessing?
        has_internal_pp = self._has_internal_preprocessing(self.model)
        self.external_rescale: bool = not has_internal_pp
        print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")

        # Keep exactly one label per model output.
        num_out = int(self.model.output_shape[-1])
        if len(self.labels) != num_out:
            print(f"[WARN] labels({len(self.labels)}) != model_out({num_out}) -> menyesuaikan.")
            if len(self.labels) >= num_out:
                self.labels = self.labels[:num_out]
            else:
                # pad with generic labels when too few are known
                self.labels += [f"class_{i}" for i in range(len(self.labels), num_out)]
        _generate_config_if_missing(self.model, self.labels)

        # Warm-up (optional): pay one-time setup cost before the first request;
        # a failure is only reported.
        try:
            _ = self.model(
                tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32),
                training=False,
            )
        except Exception as e:
            print("[WARN] warmup failed:", e)

    @classmethod
    def _has_internal_preprocessing(cls, model) -> bool:
        """Return True if a Rescaling/Normalization layer is found near the input.

        Inspects the first ``_PREPROC_SEARCH_HEAD`` layers of ``model`` and
        recurses into any of them that are themselves models (e.g. a Keras
        Applications backbone used as a single layer), whose inner layers do
        not appear in the outer model's ``layers`` list.
        """
        for layer in list(getattr(model, "layers", []))[: cls._PREPROC_SEARCH_HEAD]:
            if type(layer).__name__ in cls._PREPROC_LAYER_CLASSES:
                return True
            if "rescaling" in str(getattr(layer, "name", "")).lower():
                return True
            if getattr(layer, "layers", None) and cls._has_internal_preprocessing(layer):
                return True
        return False

    # ---- preprocessing ----
    @staticmethod
    def _to_rgb(img: "Image.Image") -> "Image.Image":
        """Return ``img`` in RGB mode (converted only when needed)."""
        return img if img.mode == "RGB" else img.convert("RGB")

    def _preprocess(self, img: "Image.Image") -> np.ndarray:
        """Resize to the model input size and return a (1, H, W, 3) float32 batch."""
        img = self._to_rgb(img).resize((self.img_size, self.img_size))
        x = np.asarray(img, dtype=np.float32)
        if self.external_rescale:
            x = x / 255.0  # no preprocessing layer inside the model: feed [0, 1]
        return np.expand_dims(x, 0)  # (1,H,W,3)

    # ---- predict ----
    def predict_dict(self, img: "Image.Image") -> Dict[str, float]:
        """Return ``{label: score}`` for one image, the mapping ``gr.Label`` expects.

        Scores are the raw model outputs converted to Python floats; they are
        presumably softmax probabilities -- confirm against the model head.
        """
        t0 = time.perf_counter()
        x = self._preprocess(img)
        # A direct call is cheaper than Model.predict() for one small batch
        # (predict() builds a data pipeline on every call).
        probs = np.asarray(self.model(x, training=False))[0]  # (C,)
        # plain Python floats, not numpy scalars
        out = {lbl: float(p) for lbl, p in zip(self.labels, probs)}
        dt = (time.perf_counter() - t0) * 1000.0
        print(f"[INF] {len(self.labels)}-class in {dt:.1f} ms")
        return out
# Module-level singleton, built once at import time: importing this module
# loads the Keras model (and FaceShapeModel.__init__ may write config.json).
_MODEL = FaceShapeModel()
# ---------- Public API ----------
def predict(image: Image.Image) -> Dict[str, float]:
    """Classify one image for a ``gr.Label`` output.

    ``gr.Label`` needs a label -> float mapping (never a string or a list),
    so a missing image is reported as the mapping ``{"Error": 1.0}``.
    """
    return {"Error": 1.0} if image is None else _MODEL.predict_dict(image)
def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
    """Classify several images; return one ``{label: prob}`` dict per input.

    Intended for a ``gr.JSON`` output in the Batch tab. Each element may be a
    PIL image, a path (``str`` / ``bytes`` / ``os.PathLike``), a readable
    binary file object, or a numpy array accepted by ``PIL.Image.fromarray``.
    ``None`` or undecodable elements yield ``{"Error": 1.0}``, so the result
    stays index-aligned with ``images``. ``images`` itself may be ``None``.
    """

    def _as_pil(x: Any):
        """Return ``x`` as a PIL image, or None if it cannot be decoded."""
        if x is None:
            return None
        if isinstance(x, Image.Image):
            return x
        try:
            if isinstance(x, np.ndarray):
                return Image.fromarray(x).convert("RGB")
            # The context manager closes any file Pillow opened from a path;
            # convert() has already decoded the pixels into a new image, so
            # the result stays valid after the file is closed.
            with Image.open(x) as im:
                return im.convert("RGB")
        except Exception as e:  # best effort: one bad item must not fail the batch
            print(f"[WARN] batch item skipped ({type(x).__name__}): {e}")
            return None

    results: List[Dict[str, float]] = []
    # `is None` rather than `images or []`, so array-like batches are not truth-tested
    for x in ([] if images is None else images):
        im = _as_pil(x)
        results.append({"Error": 1.0} if im is None else _MODEL.predict_dict(im))
    return results
# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "predict",
    "predict_batch",
]