File size: 6,292 Bytes
318b10c
 
ac7382e
a722bd4
b02d758
 
a722bd4
 
318b10c
b02d758
ac7382e
 
b02d758
318b10c
b02d758
 
d5412a0
 
ac7382e
b02d758
 
 
 
 
 
 
d5412a0
 
b02d758
d5412a0
 
ac7382e
b02d758
 
 
 
 
 
 
 
d5412a0
ac7382e
 
b02d758
ac7382e
318b10c
b02d758
 
 
486e475
b02d758
 
 
 
486e475
 
ac7382e
486e475
 
ac7382e
486e475
b02d758
 
ac7382e
 
318b10c
a722bd4
318b10c
b02d758
 
d5412a0
 
b02d758
d5412a0
b02d758
d5412a0
b02d758
d5412a0
 
b02d758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486e475
ac7382e
d5412a0
b02d758
a722bd4
d5412a0
a722bd4
d5412a0
 
b02d758
 
 
ac7382e
 
486e475
318b10c
486e475
b02d758
 
318b10c
a722bd4
b02d758
ac7382e
b02d758
3794fed
b02d758
 
 
ac7382e
 
b02d758
8c7f090
318b10c
ac7382e
 
318b10c
 
b02d758
318b10c
b02d758
318b10c
 
 
b02d758
318b10c
 
 
 
 
b02d758
 
 
 
318b10c
b02d758
 
 
 
 
 
 
 
318b10c
 
 
b02d758
318b10c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# inference.py
import os, json, time
from typing import List, Dict, Iterable, Any, Optional
import numpy as np
from PIL import Image, ImageFile
# Pillow module-level flag: load truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
import tensorflow as tf

# ---------- Label utils ----------
_DEFAULT_LABELS = ["Heart", "Oblong", "Oval", "Round", "Square"]


def _read_json_or_none(path: str) -> Any:
    """Return the parsed JSON document at *path*, or ``None`` when unusable.

    A missing file is an expected situation and is skipped silently. Any other
    read or parse failure is reported on stdout, so that a broken label file
    does not silently degrade to another label source.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        print(f"[LABEL] WARN cannot read {path}: {exc!r}")
        return None


def _load_labels() -> List[str]:
    """Load the class labels, ordered by model output index.

    Sources are tried in priority order:

    1. ``models/class_indices.json`` -- ``{"Label": idx}`` or a list of labels.
    2. ``models/idx2class.json`` -- ``{"0": "Label"}`` or a list of labels.
    3. The built-in ``_DEFAULT_LABELS``.

    Loading is best-effort: a source that is missing, unreadable or malformed
    is skipped (with a warning on stdout unless it is simply absent) and the
    next source is tried.

    Returns:
        A new list of labels; for dict sources, position ``i`` holds the label
        mapped to index ``i``.
    """
    p_ci = os.path.join("models", "class_indices.json")   # {"Label": idx}
    p_i2c = os.path.join("models", "idx2class.json")      # {"0": "Label"} or ["Label", ...]

    # 1) class_indices.json (label -> idx)
    ci = _read_json_or_none(p_ci)
    if isinstance(ci, dict):
        try:
            labels = sorted(ci, key=ci.get)
        except TypeError as exc:
            # Index values that cannot be compared with each other.
            print(f"[LABEL] WARN cannot order {p_ci} by index: {exc!r}")
        else:
            print("[LABEL] from class_indices.json ->", labels)
            return labels
    elif isinstance(ci, list):
        print("[LABEL] from class_indices.json (list) ->", ci)
        return list(ci)

    # 2) idx2class.json (idx -> label)
    i2c = _read_json_or_none(p_i2c)
    if isinstance(i2c, dict):
        try:
            # JSON object keys are always strings, so look up "0", "1", ...
            labels = [i2c[str(i)] for i in range(len(i2c))]
        except KeyError as exc:
            print(
                f"[LABEL] WARN {p_i2c}: keys are not the contiguous indices "
                f"0..{len(i2c) - 1} (missing {exc})"
            )
        else:
            print("[LABEL] from idx2class.json (dict) ->", labels)
            return labels
    elif isinstance(i2c, list):
        print("[LABEL] from idx2class.json (list) ->", i2c)
        return list(i2c)

    # 3) nothing usable on disk
    print("[LABEL] fallback default ->", _DEFAULT_LABELS)
    return list(_DEFAULT_LABELS)

def _generate_config_if_missing(model: "tf.keras.Model", labels: List[str], path: str = "config.json") -> None:
    """Write a config JSON describing the model if *path* does not exist yet.

    Args:
        model: Loaded model; only ``model.input_shape`` is read. Index 1 of
            that shape is stored as the image size.
        labels: Class labels in output-index order; used to build the
            ``id2label`` / ``label2id`` maps and ``num_labels``.
        path: Destination file. An existing file is never overwritten.

    Raises:
        AssertionError: If the image size cannot be read from the input shape
            or is not a positive integer. Nothing is written in that case.
    """
    if os.path.exists(path):
        return

    ishape = model.input_shape
    try:
        h = int(ishape[1])
    except (TypeError, ValueError, IndexError, KeyError) as e:
        raise AssertionError(f"Input shape tidak valid untuk config: {ishape}") from e
    # Explicit check rather than an `assert` statement: asserts are removed
    # when Python runs with -O, which would let a non-positive size through.
    if h <= 0:
        raise AssertionError(f"Input shape tidak valid untuk config: {ishape}")

    cfg = {
        "architectures": ["EfficientNetB4"],
        "image_size": h,
        "num_labels": len(labels),
        "id2label": {str(i): lbl for i, lbl in enumerate(labels)},
        "label2id": {lbl: i for i, lbl in enumerate(labels)},
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")

# ---------- Model wrapper ----------
class FaceShapeModel:
    """Keras classifier bundled with its label list and input preprocessing."""

    def __init__(self, model_path="models/model.keras"):
        """Load the labels and the model, then derive the preprocessing setup.

        Args:
            model_path: Model file location, joined onto the current working
                directory before loading.
        """
        self.labels: List[str] = _load_labels()

        model_file = os.path.join(os.getcwd(), model_path)
        print(f"[LOAD] {model_file}")
        self.model: tf.keras.Model = tf.keras.models.load_model(model_file, compile=False)

        # Input side length: index 1 of the input shape is used for both the
        # height and the width (H == W is assumed).
        self.img_size: int = int(self.model.input_shape[1])
        print(f"[MODEL] input img_size = {self.img_size}")

        # Does the model already carry its own preprocessing? Heuristic: one of
        # the first 12 layers has "rescaling" or "normalization" in its name.
        # NOTE(review): "normalization" also matches names such as
        # "batch_normalization" -- confirm that this is intended.
        has_internal_pp = False
        for layer in self.model.layers[:12]:
            lowered = layer.name.lower()
            if "rescaling" in lowered or "normalization" in lowered:
                has_internal_pp = True
                break
        self.external_rescale: bool = not has_internal_pp
        print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")

        # Keep the number of labels in sync with the number of model outputs.
        num_out = int(self.model.output_shape[-1])
        n_labels = len(self.labels)
        if n_labels != num_out:
            print(f"[WARN] labels({n_labels}) != model_out({num_out}) -> menyesuaikan.")
            if n_labels < num_out:
                # Too few labels: pad with generic names.
                self.labels.extend(f"class_{i}" for i in range(n_labels, num_out))
            else:
                # Too many labels: drop the surplus.
                self.labels = self.labels[:num_out]

        _generate_config_if_missing(self.model, self.labels)

        # Warm-up call (optional); a failure is only reported.
        try:
            dummy = tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32)
            self.model(dummy)
        except Exception as e:
            print("[WARN] warmup failed:", e)

    # ---- preprocessing ----
    @staticmethod
    def _to_rgb(img: Image.Image) -> Image.Image:
        """Return *img* itself when already RGB, otherwise an RGB conversion."""
        if img.mode != "RGB":
            return img.convert("RGB")
        return img

    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Turn an image into a float32 batch of shape (1, H, W, 3).

        Pixel values are divided by 255 only when ``external_rescale`` is set.
        """
        side = self.img_size
        resized = self._to_rgb(img).resize((side, side))
        pixels = np.asarray(resized, dtype=np.float32)
        if self.external_rescale:
            pixels = pixels / 255.0
        return pixels[np.newaxis, ...]  # (1,H,W,3)

    # ---- predict ----
    def predict_dict(self, img: Image.Image) -> Dict[str, float]:
        """Return a ``{label: prob}`` dict, the form consumed by gr.Label."""
        started = time.perf_counter()
        batch = self._preprocess(img)
        probs = self.model.predict(batch, verbose=0)[0]  # (C,)
        # Convert every score to a plain Python float.
        scores: Dict[str, float] = {}
        for label, prob in zip(self.labels, probs):
            scores[label] = float(prob)
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        print(f"[INF] {len(self.labels)}-class in {elapsed_ms:.1f} ms")
        return scores

# Module-level singleton: the model is constructed once, when this module is
# imported, and shared by predict() and predict_batch().
_MODEL = FaceShapeModel()

# ---------- Public API ----------
def predict(image: Image.Image) -> Dict[str, float]:
    """Classify one image and return a ``{label: probability}`` mapping.

    A missing image (``None``) yields ``{"Error": 1.0}`` instead of raising.
    """
    # NOTE: gr.Label needs a label->float mapping; never send a string/list.
    return {"Error": 1.0} if image is None else _MODEL.predict_dict(image)

def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
    """Return a list of label->prob dicts; suitable for gr.JSON in the Batch tab.

    Args:
        images: Items to classify. Each item may be a PIL image, anything
            ``Image.open`` accepts (a path or a binary file-like object), or
            ``None``. A falsy *images* value is treated as an empty batch.

    Returns:
        One dict per input item, in input order. Items that are ``None`` or
        cannot be opened as an image yield ``{"Error": 1.0}``.
    """

    def _as_pil(x: Any) -> "Optional[Image.Image]":
        """Coerce one batch item to a PIL image, or ``None`` when unusable."""
        if x is None:
            return None
        if isinstance(x, Image.Image):
            return x
        try:
            # The `with` block releases the source image (and the file handle
            # Pillow opened for it) as soon as convert() has produced an
            # independent RGB copy.
            with Image.open(x) as src:
                return src.convert("RGB")
        except Exception as exc:
            # Best-effort: one unreadable item must not abort the whole batch,
            # but say why it was skipped.
            print(f"[WARN] predict_batch: cannot open {type(x).__name__} item: {exc!r}")
            return None

    results: List[Dict[str, float]] = []
    for item in (images or []):
        im = _as_pil(item)
        results.append({"Error": 1.0} if im is None else _MODEL.predict_dict(im))
    return results

# Names exported by `from <module> import *`: the two prediction entry points.
__all__ = ["predict", "predict_batch"]