# inference.py """TensorFlow helpers for the fruit classification Hugging Face Space.""" from __future__ import annotations import json import os import time from typing import Any, Dict, Iterable, List, Optional import numpy as np from PIL import Image, ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True import tensorflow as tf # ------------------- Label utilities ------------------- _LABEL_FILES = [ os.path.join("models", "class_names.json"), os.path.join("models", "class_indices.json"), os.path.join("models", "idx2class.json"), ] _DEFAULT_LABELS = [ "Bean", "Bitter_Gourd", "Bottle_Gourd", "Brinjal", "Broccoli", "Cabbage", "Capsicum", "Carrot", "Cauliflower", "Cucumber", "Papaya", "Potato", "Pumpkin", "Radish", "Tomato", ] def _normalize_labels(seq: Iterable[Any]) -> List[str]: cleaned: List[str] = [] seen = set() for label in seq: if not isinstance(label, str): continue label = label.strip() if not label or label.startswith("."): continue if label in seen: continue cleaned.append(label) seen.add(label) return cleaned def _load_labels() -> List[str]: def _is_digits(x: Any) -> bool: try: int(x) return True except (TypeError, ValueError): return False for path in _LABEL_FILES: if not os.path.exists(path): continue try: with open(path, "r", encoding="utf-8") as f: data = json.load(f) except Exception as exc: print(f"[LABEL] failed to load {path}: {exc}") continue if isinstance(data, list): labels = _normalize_labels(data) if labels: print(f"[LABEL] from {os.path.basename(path)} -> {labels}") return labels if isinstance(data, dict) and data: # case A: {label: idx} if all(_is_digits(v) for v in data.values()): sorted_pairs = sorted( ((lbl, int(idx)) for lbl, idx in data.items()), key=lambda item: item[1], ) labels = _normalize_labels(lbl for lbl, _ in sorted_pairs) if labels: print(f"[LABEL] from {os.path.basename(path)} (label->idx) -> {labels}") return labels # case B: {idx: label} if all(_is_digits(k) for k in data.keys()): size = len(data) ordered = [data.get(str(i), data.get(i)) for i in range(size)] labels = _normalize_labels(ordered) if labels: print(f"[LABEL] from {os.path.basename(path)} (idx->label) -> {labels}") return labels print("[LABEL] fallback default ->", _DEFAULT_LABELS) return list(_DEFAULT_LABELS) def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path: str = "config.json") -> None: if os.path.exists(path): return ishape = model.input_shape try: img_size = int(ishape[1]) except Exception as exc: # pragma: no cover - defensive only raise AssertionError(f"Invalid input shape for config: {ishape}") from exc cfg = { "architectures": ["FruitCNN"], "image_size": img_size, "num_labels": len(labels), "id2label": {str(i): lbl for i, lbl in enumerate(labels)}, "label2id": {lbl: i for i, lbl in enumerate(labels)}, } with open(path, "w", encoding="utf-8") as f: json.dump(cfg, f, indent=2) print(f"[CFG] wrote {path} (image_size={img_size}, num_labels={len(labels)})") # ------------------- Model wrapper ------------------- class FruitClassifier: def __init__(self, model_path: str = "models/model_cnn.keras") -> None: self.labels = _load_labels() full_path = os.path.join(os.getcwd(), model_path) print(f"[LOAD] {full_path}") self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False) ishape = self.model.input_shape self.img_size = int(ishape[1]) print(f"[MODEL] input size = {self.img_size}") names_lower = [layer.name.lower() for layer in self.model.layers[:12]] has_internal_pp = any("rescaling" in n or "normalization" in n for n in names_lower) self.external_rescale = not has_internal_pp print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}") num_outputs = int(self.model.output_shape[-1]) if num_outputs != len(self.labels): print(f"[WARN] labels({len(self.labels)}) != outputs({num_outputs}) -> syncing") if len(self.labels) >= num_outputs: self.labels = self.labels[:num_outputs] else: for idx in range(len(self.labels), num_outputs): self.labels.append(f"class_{idx}") _generate_config_if_missing(self.model, self.labels) try: _ = self.model(tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32)) except Exception as exc: print("[WARN] warmup failed:", exc) @staticmethod def _to_rgb(img: Image.Image) -> Image.Image: return img if img.mode == "RGB" else img.convert("RGB") def _preprocess(self, img: Image.Image) -> np.ndarray: img = self._to_rgb(img).resize((self.img_size, self.img_size)) arr = np.asarray(img, dtype=np.float32) if self.external_rescale: arr = arr / 255.0 return np.expand_dims(arr, 0) def predict_dict(self, img: Image.Image) -> Dict[str, float]: t0 = time.perf_counter() probs = self.model.predict(self._preprocess(img), verbose=0)[0] result = {label: float(prob) for label, prob in zip(self.labels, probs)} dt = (time.perf_counter() - t0) * 1000.0 print(f"[INF] {len(self.labels)} classes in {dt:.1f} ms") return result _MODEL = FruitClassifier() # ------------------- Public API ------------------- def predict(image: Optional[Image.Image]) -> Dict[str, float]: if image is None: return {"Error": 1.0} return _MODEL.predict_dict(image) def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]: from PIL import Image as PILImage def _as_pil(obj: Any) -> Optional[PILImage.Image]: if obj is None: return None if isinstance(obj, PILImage.Image): return obj try: return PILImage.open(obj).convert("RGB") except Exception: return None outputs: List[Dict[str, float]] = [] for item in images or []: pil_img = _as_pil(item) outputs.append({"Error": 1.0} if pil_img is None else _MODEL.predict_dict(pil_img)) return outputs __all__ = ["predict", "predict_batch"]