# inference.py
"""Face-shape classifier inference used by the Gradio front end.

A Keras model is loaded once at import time (``_MODEL``); the public API is
``predict`` (one image -> ``{label: score}``) and ``predict_batch``.
"""
import io
import json
import os
import time
from itertools import islice
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
import tensorflow as tf
from PIL import Image, ImageFile

# Global Pillow flag: decode truncated / partially-uploaded files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# ---------- Label utils ----------
_DEFAULT_LABELS = ["Heart", "Oblong", "Oval", "Round", "Square"]


def _read_json(path: str) -> Any:
    """Return the parsed JSON document at *path*, or ``None`` if unavailable.

    A missing file is expected and stays silent. A file that exists but cannot
    be read or parsed is reported, then treated as missing so the caller can
    fall back to the next label source.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        print(f"[LABEL] cannot read {path}: {e}")
        return None


def _load_labels() -> List[str]:
    """Return the class labels ordered by model output index.

    Priority: models/class_indices.json -> models/idx2class.json -> default.
    """
    p_ci = os.path.join("models", "class_indices.json")  # {"Label": idx} or ["Label", ...]
    p_i2c = os.path.join("models", "idx2class.json")  # {"0": "Label"} or ["Label", ...]

    # 1) class_indices.json (label -> idx)
    ci = _read_json(p_ci)
    if isinstance(ci, dict):
        try:
            labels = [k for k, _ in sorted(ci.items(), key=lambda kv: kv[1])]
        except TypeError as e:  # unorderable (mixed-type) index values
            print(f"[LABEL] invalid {p_ci}: {e}")
        else:
            print("[LABEL] from class_indices.json ->", labels)
            return labels
    elif isinstance(ci, list):
        print("[LABEL] from class_indices.json (list) ->", ci)
        return list(ci)

    # 2) idx2class.json (idx -> label). JSON object keys are always strings,
    # so indices are looked up as "0", "1", ...; they must be contiguous from 0.
    i2c = _read_json(p_i2c)
    if isinstance(i2c, dict):
        keys = [str(i) for i in range(len(i2c))]
        missing = [k for k in keys if k not in i2c]
        if missing:
            print(f"[LABEL] invalid {p_i2c}: missing indices {missing}")
        else:
            labels = [i2c[k] for k in keys]
            print("[LABEL] from idx2class.json (dict) ->", labels)
            return labels
    elif isinstance(i2c, list):
        print("[LABEL] from idx2class.json (list) ->", i2c)
        return list(i2c)

    print("[LABEL] fallback default ->", _DEFAULT_LABELS)
    return list(_DEFAULT_LABELS)


def _generate_config_if_missing(
    model: tf.keras.Model,
    labels: List[str],
    path: str = "config.json",
    architecture: str = "EfficientNetB4",
) -> None:
    """Write a minimal config file describing *model* unless *path* already exists.

    Args:
        model: loaded Keras model; ``model.input_shape[1]`` is stored as ``image_size``.
        labels: class labels in output-index order.
        path: destination file (relative paths resolve against the working directory).
        architecture: value stored under ``"architectures"``.

    Raises:
        AssertionError: the model input height is not a positive integer. Raised
            explicitly (not via ``assert``) so the check survives ``python -O``.
    """
    if os.path.exists(path):
        return
    ishape = model.input_shape
    try:
        h = int(ishape[1])
    except (TypeError, ValueError, IndexError) as e:  # e.g. dynamic (None) height
        raise AssertionError(f"Input shape tidak valid untuk config: {ishape}") from e
    if h <= 0:
        raise AssertionError(f"Input shape tidak valid untuk config: {ishape}")

    cfg = {
        "architectures": [architecture],
        "image_size": h,
        "num_labels": len(labels),
        "id2label": {str(i): lbl for i, lbl in enumerate(labels)},
        "label2id": {lbl: i for i, lbl in enumerate(labels)},
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")


# ---------- Model wrapper ----------
# Keras layer classes that rescale / normalise raw pixels inside the graph.
_PREPROC_LAYER_TYPES = ("Rescaling", "Normalization")
# How many leading layers (nested sub-models flattened) are searched for them.
_PREPROC_SCAN_LAYERS = 12


def _iter_layers(layers: Iterable[Any]) -> Iterable[Any]:
    """Yield *layers* depth-first, descending into nested sub-models."""
    for layer in layers:
        yield layer
        sub = getattr(layer, "layers", None)
        if isinstance(sub, (list, tuple)):
            yield from _iter_layers(sub)


def _is_preproc_layer(layer: Any) -> bool:
    """Heuristically decide whether *layer* performs input pixel preprocessing."""
    cls = type(layer).__name__
    if cls in _PREPROC_LAYER_TYPES:
        return True
    if cls.endswith("Normalization"):
        # BatchNormalization / LayerNormalization normalise features, not pixels;
        # their default names (e.g. "batch_normalization") must not match below.
        return False
    name = layer.name.lower()
    return "rescaling" in name or "normalization" in name


def _has_internal_preprocessing(model: tf.keras.Model) -> bool:
    """Return True if one of the model's leading layers already preprocesses pixels.

    Nested sub-models (e.g. a backbone wrapped inside a functional model) are
    searched too. A Lambda wrapping a preprocess function is only detected if
    its name contains "rescaling"/"normalization".
    """
    leading = islice(_iter_layers(model.layers), _PREPROC_SCAN_LAYERS)
    return any(_is_preproc_layer(layer) for layer in leading)


class FaceShapeModel:
    """Keras face-shape classifier: loading, preprocessing and prediction."""

    def __init__(self, model_path="models/model.keras"):
        """Load labels and the model, then derive preprocessing settings.

        Args:
            model_path: model file, joined onto the current working directory
                (an absolute path is used as-is by ``os.path.join``).
        """
        self.labels: List[str] = _load_labels()

        full_path = os.path.join(os.getcwd(), model_path)
        print(f"[LOAD] {full_path}")
        self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False)

        # Input size; assumes a square (H == W) input.
        ishape = self.model.input_shape
        self.img_size: int = int(ishape[1])
        print(f"[MODEL] input img_size = {self.img_size}")

        # If the model already rescales pixels internally, feed raw 0..255 values;
        # otherwise divide by 255 ourselves.
        has_internal_pp = _has_internal_preprocessing(self.model)
        self.external_rescale: bool = not has_internal_pp
        print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")

        # Make the label count match the model's output width.
        num_out = int(self.model.output_shape[-1])
        if len(self.labels) != num_out:
            print(f"[WARN] labels({len(self.labels)}) != model_out({num_out}) -> menyesuaikan.")
            if len(self.labels) >= num_out:
                self.labels = self.labels[:num_out]
            else:
                # Pad with generic names when too few labels are known.
                self.labels += [f"class_{i}" for i in range(len(self.labels), num_out)]

        _generate_config_if_missing(self.model, self.labels)

        # Optional warmup; failures are reported but not fatal.
        try:
            _ = self.model(
                tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32),
                training=False,
            )
        except Exception as e:
            print("[WARN] warmup failed:", e)

    # ---- preprocessing ----
    @staticmethod
    def _to_rgb(img: Image.Image) -> Image.Image:
        """Return *img* in RGB mode (unchanged if it already is)."""
        return img if img.mode == "RGB" else img.convert("RGB")

    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Convert a PIL image into a float32 batch of shape (1, H, W, 3)."""
        # Resized with Pillow's default resampling filter.
        img = self._to_rgb(img).resize((self.img_size, self.img_size))
        x = np.asarray(img, dtype=np.float32)
        if self.external_rescale:
            x = x / 255.0
        return np.expand_dims(x, 0)  # (1,H,W,3)

    # ---- predict ----
    def predict_dict(self, img: Image.Image) -> Dict[str, float]:
        """Classify one image and return ``{label: score}`` for ``gr.Label``.

        Scores are the raw values of the model's final layer; they are
        probabilities only if that layer ends in a softmax (not visible here).
        """
        t0 = time.perf_counter()
        # Calling the model directly avoids the per-call setup overhead of
        # Model.predict(), which dominates for a single-image batch.
        probs = np.asarray(self.model(self._preprocess(img), training=False))[0]  # (C,)
        # Plain Python floats so the mapping serialises cleanly.
        out = {lbl: float(p) for lbl, p in zip(self.labels, probs)}
        dt = (time.perf_counter() - t0) * 1000.0
        print(f"[INF] {len(self.labels)}-class in {dt:.1f} ms")
        return out


# Singleton: the model is loaded once, at import time.
_MODEL = FaceShapeModel()


# ---------- Public API ----------
def _as_pil(x: Any) -> Optional[Image.Image]:
    """Best-effort conversion of a UI input to a PIL image; ``None`` if impossible.

    Accepts a PIL image (returned unchanged), a numpy array, a file path
    (str / bytes / os.PathLike), raw encoded image bytes, a readable binary
    file object, or an ``(image, caption)`` pair as produced by gallery inputs.
    """
    if x is None:
        return None
    if isinstance(x, Image.Image):
        return x
    if isinstance(x, (tuple, list)):
        return _as_pil(x[0]) if x else None
    try:
        if isinstance(x, np.ndarray):
            return Image.fromarray(x).convert("RGB")
        if isinstance(x, bytearray) or (isinstance(x, bytes) and not os.path.isfile(x)):
            x = io.BytesIO(x)  # encoded image content rather than a filesystem path
        # `with` releases any file Pillow opened itself; convert() returns a
        # fully-loaded copy that does not depend on the source file.
        with Image.open(x) as im:
            return im.convert("RGB")
    except Exception as e:  # deliberate best-effort: caller reports {"Error": 1.0}
        print(f"[WARN] cannot read image input ({type(x).__name__}): {e}")
        return None


def predict(image: Any) -> Dict[str, float]:
    """Classify one image for ``gr.Label``.

    Returns ``{label: score}``, or ``{"Error": 1.0}`` when *image* is missing
    or unreadable. gr.Label needs a label->float mapping; never return a
    string or list here.
    """
    im = _as_pil(image)
    if im is None:
        return {"Error": 1.0}
    return _MODEL.predict_dict(im)


def predict_batch(images: Optional[Iterable[Any]]) -> List[Dict[str, float]]:
    """Classify every item of *images*; the result suits gr.JSON in the Batch tab.

    Each position holds ``{label: score}``, or ``{"Error": 1.0}`` for an item
    that could not be read. ``None`` yields an empty list.
    """
    if images is None:
        return []
    results: List[Dict[str, float]] = []
    for x in images:
        im = _as_pil(x)
        results.append({"Error": 1.0} if im is None else _MODEL.predict_dict(im))
    return results


__all__ = ["predict", "predict_batch"]