fruit-classification / inference.py
DimasMP3
Re-upload model with LFS fixed
1af914e
# inference.py
"""TensorFlow helpers for the fruit classification Hugging Face Space."""
from __future__ import annotations
import json
import os
import time
from typing import Any, Dict, Iterable, List, Optional
import numpy as np
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import tensorflow as tf
# ------------------- Label utilities -------------------
_LABEL_FILES = [
os.path.join("models", "class_names.json"),
os.path.join("models", "class_indices.json"),
os.path.join("models", "idx2class.json"),
]
_DEFAULT_LABELS = [
"Bean",
"Bitter_Gourd",
"Bottle_Gourd",
"Brinjal",
"Broccoli",
"Cabbage",
"Capsicum",
"Carrot",
"Cauliflower",
"Cucumber",
"Papaya",
"Potato",
"Pumpkin",
"Radish",
"Tomato",
]
def _normalize_labels(seq: Iterable[Any]) -> List[str]:
cleaned: List[str] = []
seen = set()
for label in seq:
if not isinstance(label, str):
continue
label = label.strip()
if not label or label.startswith("."):
continue
if label in seen:
continue
cleaned.append(label)
seen.add(label)
return cleaned
def _load_labels() -> List[str]:
    """Return class labels in model-output order.

    Files in ``_LABEL_FILES`` are tried in order. The first one that yields a
    non-empty label list after ``_normalize_labels`` cleaning wins. Accepted
    JSON layouts:

    * a list of labels, already in output order;
    * ``{label: index}``: labels sorted by ascending integer index;
    * ``{index: label}``: labels sorted by ascending integer index.

    Missing files are skipped. Files that fail to open or parse are reported
    and skipped. If nothing usable is found, a fresh copy of
    ``_DEFAULT_LABELS`` is returned.
    """

    def _is_digits(x: Any) -> bool:
        # True when int(x) succeeds: ints, numeric strings, and also floats/bools.
        try:
            int(x)
            return True
        except (TypeError, ValueError):
            return False

    def _by_index(pairs: Iterable[Any]) -> List[Any]:
        # Order (index, label) pairs by integer index and keep only the labels.
        # sorted() is stable, so equal indices keep their file order.
        return [label for _, label in sorted(pairs, key=lambda item: item[0])]

    for path in _LABEL_FILES:
        if not os.path.exists(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as exc:  # best-effort: a broken file falls through to the next one
            print(f"[LABEL] failed to load {path}: {exc}")
            continue
        name = os.path.basename(path)

        if isinstance(data, list):
            labels = _normalize_labels(data)
            if labels:
                print(f"[LABEL] from {name} -> {labels}")
                return labels

        if isinstance(data, dict) and data:
            # case A: {label: idx}
            if all(_is_digits(v) for v in data.values()):
                labels = _normalize_labels(
                    _by_index((int(idx), lbl) for lbl, idx in data.items())
                )
                if labels:
                    print(f"[LABEL] from {name} (label->idx) -> {labels}")
                    return labels
            # case B: {idx: label}. Sort the actual keys instead of probing
            # 0..len(data)-1. Probing silently dropped labels when indices were
            # 1-based or non-contiguous (e.g. {"1": "a", "2": "b"} lost "b").
            if all(_is_digits(k) for k in data.keys()):
                labels = _normalize_labels(
                    _by_index((int(idx), lbl) for idx, lbl in data.items())
                )
                if labels:
                    print(f"[LABEL] from {name} (idx->label) -> {labels}")
                    return labels

    print("[LABEL] fallback default ->", _DEFAULT_LABELS)
    return list(_DEFAULT_LABELS)
def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path: str = "config.json") -> None:
if os.path.exists(path):
return
ishape = model.input_shape
try:
img_size = int(ishape[1])
except Exception as exc: # pragma: no cover - defensive only
raise AssertionError(f"Invalid input shape for config: {ishape}") from exc
cfg = {
"architectures": ["FruitCNN"],
"image_size": img_size,
"num_labels": len(labels),
"id2label": {str(i): lbl for i, lbl in enumerate(labels)},
"label2id": {lbl: i for i, lbl in enumerate(labels)},
}
with open(path, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=2)
print(f"[CFG] wrote {path} (image_size={img_size}, num_labels={len(labels)})")
# ------------------- Model wrapper -------------------
class FruitClassifier:
    """Keras image classifier that maps a PIL image to ``{label: score}``.

    Construction loads the model without compiling it, aligns the label list
    with the model's output width, writes ``config.json`` if missing, and runs
    one warm-up forward pass.
    """

    def __init__(self, model_path: str = "models/model_cnn.keras") -> None:
        """Load labels and the Keras model at *model_path* (joined onto the CWD)."""
        self.labels = _load_labels()
        full_path = os.path.join(os.getcwd(), model_path)
        print(f"[LOAD] {full_path}")
        self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False)

        # input_shape is read as (batch, height, width, ...); preprocessing and
        # warm-up below build channels-last RGB batches of that spatial size.
        ishape = self.model.input_shape
        self.img_height = int(ishape[1])
        try:
            self.img_width = int(ishape[2])
        except (IndexError, TypeError):
            # No usable width entry: keep the previous square assumption.
            self.img_width = self.img_height
        # Kept for backward compatibility with code reading ``img_size``.
        self.img_size = self.img_height
        if self.img_height == self.img_width:
            size_desc = str(self.img_size)
        else:
            size_desc = f"{self.img_height}x{self.img_width}"
        print(f"[MODEL] input size = {size_desc}")

        # Heuristic: if one of the first 12 layers rescales raw pixels itself,
        # skip the external /255 in _preprocess.
        has_internal_pp = any(
            self._looks_like_input_preprocessing(layer) for layer in self.model.layers[:12]
        )
        self.external_rescale = not has_internal_pp
        print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")

        # Truncate surplus labels, or pad with "class_<i>" placeholders, so
        # there is exactly one label per model output.
        num_outputs = int(self.model.output_shape[-1])
        if num_outputs != len(self.labels):
            print(f"[WARN] labels({len(self.labels)}) != outputs({num_outputs}) -> syncing")
            self.labels = self.labels[:num_outputs] + [
                f"class_{idx}" for idx in range(len(self.labels), num_outputs)
            ]

        _generate_config_if_missing(self.model, self.labels)

        # One dummy forward pass so the first real request does not pay the
        # start-up cost. Failure here is reported but not fatal.
        try:
            _ = self.model(
                tf.zeros((1, self.img_height, self.img_width, 3), dtype=tf.float32),
                training=False,
            )
        except Exception as exc:
            print("[WARN] warmup failed:", exc)

    @staticmethod
    def _looks_like_input_preprocessing(layer: Any) -> bool:
        """Return True if *layer* appears to be a Rescaling/Normalization layer.

        A plain substring test on "normalization" also matched
        ``batch_normalization``/``layer_normalization`` layers. Those do not
        rescale raw pixels, and matching them wrongly disabled the external
        /255 step. Only the layer class or the exact default names count here.
        """
        if type(layer).__name__ in ("Rescaling", "Normalization"):
            return True
        name = layer.name.lower()
        if "rescaling" in name:
            return True
        return name == "normalization" or name.startswith("normalization_")

    @staticmethod
    def _to_rgb(img: Image.Image) -> Image.Image:
        """Return *img* unchanged if already RGB, else an RGB-converted copy."""
        return img if img.mode == "RGB" else img.convert("RGB")

    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Return a float32 batch of shape (1, height, width, 3) for *img*."""
        # PIL's resize takes (width, height).
        resized = self._to_rgb(img).resize((self.img_width, self.img_height))
        arr = np.asarray(resized, dtype=np.float32)
        if self.external_rescale:
            arr = arr / 255.0
        return np.expand_dims(arr, 0)

    def predict_dict(self, img: Image.Image) -> Dict[str, float]:
        """Run the model on *img* and return ``{label: score}`` for every class."""
        t0 = time.perf_counter()
        batch = self._preprocess(img)
        # Keras documents Model.predict as meant for large batched inputs and
        # recommends calling the model directly for a single small batch. This
        # also matches the call path used by the warm-up pass in __init__.
        probs = np.asarray(self.model(batch, training=False))[0]
        result = {label: float(prob) for label, prob in zip(self.labels, probs)}
        dt = (time.perf_counter() - t0) * 1000.0
        print(f"[INF] {len(self.labels)} classes in {dt:.1f} ms")
        return result
# Single shared classifier, constructed when this module is imported. It loads
# the model from disk and runs a warm-up pass; predict() and predict_batch()
# below both use this instance.
_MODEL = FruitClassifier()
# ------------------- Public API -------------------
def predict(image: Optional[Image.Image]) -> Dict[str, float]:
    """Classify one image and return a ``{label: score}`` mapping.

    A missing image (``None``) does not raise; it returns the sentinel
    mapping ``{"Error": 1.0}``.
    """
    return {"Error": 1.0} if image is None else _MODEL.predict_dict(image)
def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
    """Classify several inputs and return one result dict per item, in order.

    Items may be PIL images, which are used as-is, or anything
    ``PIL.Image.open`` accepts, such as a file path or binary file object,
    which is decoded to RGB. ``None`` items and items that cannot be opened
    produce ``{"Error": 1.0}`` in their slot. Passing ``None`` instead of an
    iterable returns an empty list.
    """

    def _as_pil(obj: Any) -> Optional[Image.Image]:
        # Return a PIL image for obj, or None when it is missing or unreadable.
        if obj is None:
            return None
        if isinstance(obj, Image.Image):
            return obj
        try:
            # The context manager releases the file Image.open() acquired.
            # convert() fully decodes into a new image that outlives it.
            with Image.open(obj) as opened:
                return opened.convert("RGB")
        except Exception as exc:  # best-effort: one bad item must not fail the batch
            print(f"[BATCH] skipping unreadable item: {exc}")
            return None

    # `images or []` raised on array-like inputs (ambiguous truth value), so
    # only a missing iterable is treated as empty.
    if images is None:
        return []

    outputs: List[Dict[str, float]] = []
    for item in images:
        pil_img = _as_pil(item)
        outputs.append({"Error": 1.0} if pil_img is None else _MODEL.predict_dict(pil_img))
    return outputs
# Names exported by `from inference import *`: only the two prediction
# entry points; everything else in this module is internal.
__all__ = ["predict", "predict_batch"]