DimasMP3 commited on
Commit
b02d758
·
1 Parent(s): d2f7145

add guard

Browse files
Files changed (1) hide show
  1. inference.py +80 -33
inference.py CHANGED
@@ -2,39 +2,56 @@
2
  import os, json, time
3
  from typing import List, Dict, Iterable, Any, Optional
4
  import numpy as np
5
- from PIL import Image
 
6
  import tensorflow as tf
7
 
8
  # ---------- Label utils ----------
9
- _DEFAULT_LABELS = ["Heart","Oblong","Oval","Round","Square"]
10
 
11
  def _load_labels() -> List[str]:
 
12
  p_ci = os.path.join("models", "class_indices.json") # {"Label": idx}
13
- p_i2c = os.path.join("models", "idx2class.json") # {"0":"Label"}
 
14
  try:
15
  with open(p_ci, "r") as f:
16
  ci = json.load(f)
17
- labels = [k for k,_ in sorted(ci.items(), key=lambda kv: kv[1])]
18
- print("[LABEL] from class_indices.json ->", labels)
19
- return labels
 
 
 
 
20
  except Exception:
21
  pass
 
22
  try:
23
  with open(p_i2c, "r") as f:
24
  i2c = json.load(f)
25
- labels = [i2c[str(i)] for i in range(len(i2c))]
26
- print("[LABEL] from idx2class.json ->", labels)
27
- return labels
 
 
 
 
 
28
  except Exception:
29
  pass
30
  print("[LABEL] fallback default ->", _DEFAULT_LABELS)
31
- return _DEFAULT_LABELS
32
 
33
  def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path="config.json"):
34
- if os.path.exists(path): return
 
 
35
  ishape = model.input_shape
36
- h = int(ishape[1])
37
- assert h > 0, f"Input shape aneh: {ishape}"
 
 
38
  cfg = {
39
  "architectures": ["EfficientNetB4"],
40
  "image_size": h,
@@ -42,75 +59,105 @@ def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path="
42
  "id2label": {str(i): lbl for i, lbl in enumerate(labels)},
43
  "label2id": {lbl: i for i, lbl in enumerate(labels)},
44
  }
45
- with open(path, "w") as f: json.dump(cfg, f, indent=2)
 
46
  print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")
47
 
48
  # ---------- Model wrapper ----------
49
  class FaceShapeModel:
50
  def __init__(self, model_path="models/model.keras"):
51
- self.labels = _load_labels()
 
52
  full_path = os.path.join(os.getcwd(), model_path)
53
  print(f"[LOAD] {full_path}")
54
- self.model = tf.keras.models.load_model(full_path, compile=False)
55
 
 
56
  ishape = self.model.input_shape
57
- self.img_size = int(ishape[1])
58
  print(f"[MODEL] input img_size = {self.img_size}")
59
 
60
- names_lower = [l.name.lower() for l in self.model.layers[:10]]
61
- has_pp = any(("rescaling" in n) or ("normalization" in n) for n in names_lower)
62
- self.external_rescale = not has_pp
63
- print(f"[MODEL] internal_preproc={has_pp} -> external_rescale={self.external_rescale}")
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  _generate_config_if_missing(self.model, self.labels)
66
 
 
67
  try:
68
  _ = self.model(tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32))
69
  except Exception as e:
70
  print("[WARN] warmup failed:", e)
71
 
72
- def _to_rgb(self, img: Image.Image) -> Image.Image:
 
 
73
  return img if img.mode == "RGB" else img.convert("RGB")
74
 
75
  def _preprocess(self, img: Image.Image) -> np.ndarray:
76
  img = self._to_rgb(img).resize((self.img_size, self.img_size))
77
  x = np.asarray(img, dtype=np.float32)
78
- if self.external_rescale: x = x / 255.0
 
79
  return np.expand_dims(x, 0) # (1,H,W,3)
80
 
 
81
  def predict_dict(self, img: Image.Image) -> Dict[str, float]:
 
82
  t0 = time.perf_counter()
83
- p = self.model.predict(self._preprocess(img), verbose=0)[0]
 
 
84
  dt = (time.perf_counter() - t0) * 1000.0
85
  print(f"[INF] {len(self.labels)}-class in {dt:.1f} ms")
86
- return {lbl: float(prob) for lbl, prob in zip(self.labels, p)}
87
 
88
  # singleton
89
  _MODEL = FaceShapeModel()
90
 
91
  # ---------- Public API ----------
92
  def predict(image: Image.Image) -> Dict[str, float]:
 
93
  if image is None:
94
- return {"Error": "No image"}
95
  return _MODEL.predict_dict(image)
96
 
97
  def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
 
98
  from PIL import Image as _PILImage
99
  import os as _os
100
  results: List[Dict[str, float]] = []
101
 
102
  def _as_pil(x: Any) -> Optional[_PILImage.Image]:
103
- if x is None: return None
104
- if isinstance(x, _PILImage.Image): return x
 
 
105
  if isinstance(x, (str, bytes, _os.PathLike)):
106
- try: return _PILImage.open(x).convert("RGB")
107
- except Exception: return None
108
- try: return _PILImage.open(x).convert("RGB")
109
- except Exception: return None
 
 
 
 
110
 
111
  for x in (images or []):
112
  im = _as_pil(x)
113
- results.append({"Error": "Invalid image"} if im is None else _MODEL.predict_dict(im))
114
  return results
115
 
116
  __all__ = ["predict", "predict_batch"]
 
2
  import os, json, time
3
  from typing import List, Dict, Iterable, Any, Optional
4
  import numpy as np
5
+ from PIL import Image, ImageFile
6
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
7
  import tensorflow as tf
8
 
9
  # ---------- Label utils ----------
10
+ _DEFAULT_LABELS = ["Heart", "Oblong", "Oval", "Round", "Square"]
11
 
12
  def _load_labels() -> List[str]:
13
+ """Prioritas: models/class_indices.json -> models/idx2class.json -> default."""
14
  p_ci = os.path.join("models", "class_indices.json") # {"Label": idx}
15
+ p_i2c = os.path.join("models", "idx2class.json") # {"0": "Label"} atau ["Label", ...]
16
+ # 1) class_indices.json (label->idx)
17
  try:
18
  with open(p_ci, "r") as f:
19
  ci = json.load(f)
20
+ if isinstance(ci, dict):
21
+ labels = [k for k, _ in sorted(ci.items(), key=lambda kv: kv[1])]
22
+ print("[LABEL] from class_indices.json ->", labels)
23
+ return labels
24
+ if isinstance(ci, list):
25
+ print("[LABEL] from class_indices.json (list) ->", ci)
26
+ return list(ci)
27
  except Exception:
28
  pass
29
+ # 2) idx2class.json (idx->label)
30
  try:
31
  with open(p_i2c, "r") as f:
32
  i2c = json.load(f)
33
+ if isinstance(i2c, dict):
34
+ n = len(i2c)
35
+ labels = [i2c[str(i)] if str(i) in i2c else i2c[i] for i in range(n)]
36
+ print("[LABEL] from idx2class.json (dict) ->", labels)
37
+ return labels
38
+ if isinstance(i2c, list):
39
+ print("[LABEL] from idx2class.json (list) ->", i2c)
40
+ return list(i2c)
41
  except Exception:
42
  pass
43
  print("[LABEL] fallback default ->", _DEFAULT_LABELS)
44
+ return list(_DEFAULT_LABELS)
45
 
46
  def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path="config.json"):
47
+ """Auto-tulis config.json jika belum ada."""
48
+ if os.path.exists(path):
49
+ return
50
  ishape = model.input_shape
51
+ try:
52
+ h = int(ishape[1]); assert h > 0
53
+ except Exception as e:
54
+ raise AssertionError(f"Input shape tidak valid untuk config: {ishape}") from e
55
  cfg = {
56
  "architectures": ["EfficientNetB4"],
57
  "image_size": h,
 
59
  "id2label": {str(i): lbl for i, lbl in enumerate(labels)},
60
  "label2id": {lbl: i for i, lbl in enumerate(labels)},
61
  }
62
+ with open(path, "w") as f:
63
+ json.dump(cfg, f, indent=2)
64
  print(f"[CFG] wrote {path} (image_size={h}, num_labels={len(labels)})")
65
 
66
  # ---------- Model wrapper ----------
67
  class FaceShapeModel:
68
  def __init__(self, model_path="models/model.keras"):
69
+ self.labels: List[str] = _load_labels()
70
+
71
  full_path = os.path.join(os.getcwd(), model_path)
72
  print(f"[LOAD] {full_path}")
73
+ self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False)
74
 
75
+ # input size (H=W)
76
  ishape = self.model.input_shape
77
+ self.img_size: int = int(ishape[1])
78
  print(f"[MODEL] input img_size = {self.img_size}")
79
 
80
+ # apakah model sudah termasuk preprocessing internal?
81
+ names_lower = [l.name.lower() for l in self.model.layers[:12]]
82
+ has_internal_pp = any(("rescaling" in n) or ("normalization" in n) for n in names_lower)
83
+ self.external_rescale: bool = not has_internal_pp
84
+ print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")
85
+
86
+ # sinkronkan jumlah label dengan jumlah output model
87
+ num_out = int(self.model.output_shape[-1])
88
+ if len(self.labels) != num_out:
89
+ print(f"[WARN] labels({len(self.labels)}) != model_out({num_out}) -> menyesuaikan.")
90
+ if len(self.labels) >= num_out:
91
+ self.labels = self.labels[:num_out]
92
+ else:
93
+ # pad label generik jika kurang
94
+ self.labels += [f"class_{i}" for i in range(len(self.labels), num_out)]
95
 
96
  _generate_config_if_missing(self.model, self.labels)
97
 
98
+ # warmup (optional)
99
  try:
100
  _ = self.model(tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32))
101
  except Exception as e:
102
  print("[WARN] warmup failed:", e)
103
 
104
+ # ---- preprocessing ----
105
+ @staticmethod
106
+ def _to_rgb(img: Image.Image) -> Image.Image:
107
  return img if img.mode == "RGB" else img.convert("RGB")
108
 
109
  def _preprocess(self, img: Image.Image) -> np.ndarray:
110
  img = self._to_rgb(img).resize((self.img_size, self.img_size))
111
  x = np.asarray(img, dtype=np.float32)
112
+ if self.external_rescale:
113
+ x = x / 255.0
114
  return np.expand_dims(x, 0) # (1,H,W,3)
115
 
116
+ # ---- predict ----
117
  def predict_dict(self, img: Image.Image) -> Dict[str, float]:
118
+ """Return dict {label: prob} untuk gr.Label."""
119
  t0 = time.perf_counter()
120
+ probs = self.model.predict(self._preprocess(img), verbose=0)[0] # (C,)
121
+ # pastikan float murni
122
+ out = {lbl: float(p) for lbl, p in zip(self.labels, probs)}
123
  dt = (time.perf_counter() - t0) * 1000.0
124
  print(f"[INF] {len(self.labels)}-class in {dt:.1f} ms")
125
+ return out
126
 
127
  # singleton
128
  _MODEL = FaceShapeModel()
129
 
130
  # ---------- Public API ----------
131
  def predict(image: Image.Image) -> Dict[str, float]:
132
+ # NOTE: gr.Label butuh mapping label->float; jangan kirim string/list.
133
  if image is None:
134
+ return {"Error": 1.0}
135
  return _MODEL.predict_dict(image)
136
 
137
  def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
138
+ """Kembalikan list of dict label->prob; cocok untuk gr.JSON di tab Batch."""
139
  from PIL import Image as _PILImage
140
  import os as _os
141
  results: List[Dict[str, float]] = []
142
 
143
  def _as_pil(x: Any) -> Optional[_PILImage.Image]:
144
+ if x is None:
145
+ return None
146
+ if isinstance(x, _PILImage.Image):
147
+ return x
148
  if isinstance(x, (str, bytes, _os.PathLike)):
149
+ try:
150
+ return _PILImage.open(x).convert("RGB")
151
+ except Exception:
152
+ return None
153
+ try:
154
+ return _PILImage.open(x).convert("RGB")
155
+ except Exception:
156
+ return None
157
 
158
  for x in (images or []):
159
  im = _as_pil(x)
160
+ results.append({"Error": 1.0} if im is None else _MODEL.predict_dict(im))
161
  return results
162
 
163
  __all__ = ["predict", "predict_batch"]