DimasMP3 commited on
Commit ·
5093d6d
1
Parent(s): e6e8622
add
Browse files- tools/gen_config.py +28 -0
tools/gen_config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, json, tensorflow as tf
|
| 2 |
+
|
| 3 |
+
# 1) baca label dari class_indices.json (prioritas)
|
| 4 |
+
labels = ["Heart","Oblong","Oval","Round","Square"]
|
| 5 |
+
ci_path = os.path.join("models", "class_indices.json")
|
| 6 |
+
if os.path.exists(ci_path):
|
| 7 |
+
with open(ci_path, "r") as f:
|
| 8 |
+
ci = json.load(f) # {"Heart":0,...}
|
| 9 |
+
labels = [k for k,_ in sorted(ci.items(), key=lambda kv: kv[1])]
|
| 10 |
+
|
| 11 |
+
# 2) ambil ukuran input langsung dari model.keras
|
| 12 |
+
m = tf.keras.models.load_model(os.path.join("models","model.keras"), compile=False)
|
| 13 |
+
h, w = m.input_shape[1], m.input_shape[2]
|
| 14 |
+
assert h == w and h is not None, f"Input shape aneh: {m.input_shape}"
|
| 15 |
+
img_size = int(h)
|
| 16 |
+
|
| 17 |
+
# 3) tulis config.json
|
| 18 |
+
cfg = {
|
| 19 |
+
"architectures": ["EfficientNetB4"],
|
| 20 |
+
"image_size": img_size,
|
| 21 |
+
"num_labels": len(labels),
|
| 22 |
+
"id2label": {str(i): lbl for i, lbl in enumerate(labels)},
|
| 23 |
+
"label2id": {lbl: i for i, lbl in enumerate(labels)},
|
| 24 |
+
}
|
| 25 |
+
with open("config.json", "w") as f:
|
| 26 |
+
json.dump(cfg, f, indent=2)
|
| 27 |
+
|
| 28 |
+
print("Wrote config.json with image_size =", img_size, "and labels =", labels)
|