Spaces:
Sleeping
Sleeping
| # inference.py | |
| """TensorFlow helpers for the fruit classification Hugging Face Space.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import time | |
| from typing import Any, Dict, Iterable, List, Optional | |
| import numpy as np | |
| from PIL import Image, ImageFile | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
| import tensorflow as tf | |
| # ------------------- Label utilities ------------------- | |
# Candidate label files, listed in the order the label loader probes them.
_LABEL_FILES = [
    os.path.join("models", filename)
    for filename in ("class_names.json", "class_indices.json", "idx2class.json")
]

# Built-in label list used when none of the label files yields usable labels.
_DEFAULT_LABELS = [
    "Bean", "Bitter_Gourd", "Bottle_Gourd", "Brinjal", "Broccoli",
    "Cabbage", "Capsicum", "Carrot", "Cauliflower", "Cucumber",
    "Papaya", "Potato", "Pumpkin", "Radish", "Tomato",
]
def _normalize_labels(seq: Iterable[Any]) -> List[str]:
    """Return the usable labels from *seq*, stripped and de-duplicated.

    Items that are not strings, strings that are empty after stripping,
    names starting with "." and repeats of an already kept label are
    dropped. The first-seen order of the surviving labels is preserved.
    """
    # A dict keeps insertion order, so it doubles as an ordered set here.
    kept = {}
    for raw in seq:
        if isinstance(raw, str):
            name = raw.strip()
            if name and not name.startswith("."):
                kept.setdefault(name, None)
    return list(kept)
def _load_labels() -> List[str]:
    """Load the class labels from the first usable label file.

    Each path in ``_LABEL_FILES`` is tried in order. A file is accepted when
    it parses as JSON and yields at least one label after normalization.
    Three JSON shapes are understood:

    * a list of labels, used in the given order;
    * a ``{label: index}`` object, ordered by ascending index;
    * an ``{index: label}`` object, ordered by ascending index.

    Returns:
        The normalized labels from the first usable file, or a copy of
        ``_DEFAULT_LABELS`` when no file is usable.
    """

    def _is_digits(x: Any) -> bool:
        # True when int() accepts x. OverflowError covers float infinity,
        # which json.load produces for the non-standard "Infinity" literal.
        try:
            int(x)
        except (TypeError, ValueError, OverflowError):
            return False
        return True

    for path in _LABEL_FILES:
        if not os.path.exists(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as exc:
            # Best effort: an unreadable or malformed file must not prevent
            # the remaining candidates (or the defaults) from being used.
            print(f"[LABEL] failed to load {path}: {exc}")
            continue

        if isinstance(data, list):
            labels = _normalize_labels(data)
            if labels:
                print(f"[LABEL] from {os.path.basename(path)} -> {labels}")
                return labels

        if isinstance(data, dict) and data:
            # case A: {label: idx}
            if all(_is_digits(v) for v in data.values()):
                sorted_pairs = sorted(
                    ((lbl, int(idx)) for lbl, idx in data.items()),
                    key=lambda item: item[1],
                )
                labels = _normalize_labels(lbl for lbl, _ in sorted_pairs)
                if labels:
                    print(f"[LABEL] from {os.path.basename(path)} (label->idx) -> {labels}")
                    return labels
            # case B: {idx: label}
            if all(_is_digits(k) for k in data):
                # Order by the numeric value of the key rather than probing
                # 0..n-1, so 1-based or sparse index maps keep every label
                # instead of silently losing the ones outside that range.
                ordered = [data[key] for key in sorted(data, key=int)]
                labels = _normalize_labels(ordered)
                if labels:
                    print(f"[LABEL] from {os.path.basename(path)} (idx->label) -> {labels}")
                    return labels

    print("[LABEL] fallback default ->", _DEFAULT_LABELS)
    return list(_DEFAULT_LABELS)
def _generate_config_if_missing(model: tf.keras.Model, labels: List[str], path: str = "config.json") -> None:
    """Write a small JSON config for *model* unless *path* already exists.

    The config records a fixed architecture name, the image size (taken from
    the second entry of ``model.input_shape``), the number of labels and the
    two index/label lookup tables.

    Args:
        model: Object exposing ``input_shape``.
        labels: Class labels, in index order.
        path: Destination file; left untouched when it already exists.

    Raises:
        AssertionError: If ``model.input_shape[1]`` cannot be read as an int.
    """
    if os.path.exists(path):
        return

    input_shape = model.input_shape
    try:
        side = int(input_shape[1])
    except Exception as exc:  # pragma: no cover - defensive only
        raise AssertionError(f"Invalid input shape for config: {input_shape}") from exc

    # Build both lookup directions in a single pass over the labels.
    id2label = {}
    label2id = {}
    for index, name in enumerate(labels):
        id2label[str(index)] = name
        label2id[name] = index

    label_count = len(labels)
    payload = {
        "architectures": ["FruitCNN"],
        "image_size": side,
        "num_labels": label_count,
        "id2label": id2label,
        "label2id": label2id,
    }
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
    print(f"[CFG] wrote {path} (image_size={side}, num_labels={label_count})")
| # ------------------- Model wrapper ------------------- | |
class FruitClassifier:
    """Keras image classifier bundled with its labels and preprocessing.

    Attributes set by ``__init__``:
        labels: Class labels, trimmed or padded to the model's output width.
        model: The loaded Keras model.
        img_size: Input size, read from ``model.input_shape[1]``; images are
            resized to ``(img_size, img_size)``.
        external_rescale: Whether ``_preprocess`` divides pixels by 255.
    """

    def __init__(self, model_path: str = "models/model_cnn.keras") -> None:
        """Load labels and model, align them, and warm the model up.

        Args:
            model_path: Model file, resolved against the current working
                directory.
        """
        self.labels = _load_labels()
        full_path = os.path.join(os.getcwd(), model_path)
        print(f"[LOAD] {full_path}")
        self.model: tf.keras.Model = tf.keras.models.load_model(full_path, compile=False)

        ishape = self.model.input_shape
        self.img_size = int(ishape[1])
        print(f"[MODEL] input size = {self.img_size}")

        # Heuristic: if one of the first 12 layers looks like a preprocessing
        # layer (by name), the model rescales internally and we must not.
        # NOTE(review): the substring "normalization" also matches names such
        # as "batch_normalization" - confirm that is the intended behaviour.
        names_lower = [layer.name.lower() for layer in self.model.layers[:12]]
        has_internal_pp = any("rescaling" in n or "normalization" in n for n in names_lower)
        self.external_rescale = not has_internal_pp
        print(f"[MODEL] internal_preproc={has_internal_pp} -> external_rescale={self.external_rescale}")

        # Keep exactly one label per model output: drop surplus labels or pad
        # with placeholder names.
        num_outputs = int(self.model.output_shape[-1])
        if num_outputs != len(self.labels):
            print(f"[WARN] labels({len(self.labels)}) != outputs({num_outputs}) -> syncing")
            if len(self.labels) >= num_outputs:
                self.labels = self.labels[:num_outputs]
            else:
                for idx in range(len(self.labels), num_outputs):
                    self.labels.append(f"class_{idx}")

        _generate_config_if_missing(self.model, self.labels)

        # Warm-up call on a blank image; a failure here is logged, not raised.
        try:
            _ = self.model(tf.zeros((1, self.img_size, self.img_size, 3), dtype=tf.float32))
        except Exception as exc:
            print("[WARN] warmup failed:", exc)

    @staticmethod
    def _to_rgb(img: Image.Image) -> Image.Image:
        """Return *img* unchanged when it is already RGB, else an RGB copy.

        Declared static because it takes no ``self``; without the decorator
        the ``self._to_rgb(img)`` call in ``_preprocess`` raises ``TypeError``.
        """
        return img if img.mode == "RGB" else img.convert("RGB")

    def _preprocess(self, img: Image.Image) -> np.ndarray:
        """Turn *img* into a float32 batch of one, sized for the model.

        The image is converted to RGB, resized to ``img_size`` x ``img_size``
        and, when ``external_rescale`` is set, divided by 255.
        """
        img = self._to_rgb(img).resize((self.img_size, self.img_size))
        arr = np.asarray(img, dtype=np.float32)
        if self.external_rescale:
            arr = arr / 255.0
        return np.expand_dims(arr, 0)

    def predict_dict(self, img: Image.Image) -> Dict[str, float]:
        """Run the model on *img* and return a ``{label: score}`` mapping."""
        t0 = time.perf_counter()
        probs = self.model.predict(self._preprocess(img), verbose=0)[0]
        result = {label: float(prob) for label, prob in zip(self.labels, probs)}
        dt = (time.perf_counter() - t0) * 1000.0
        print(f"[INF] {len(self.labels)} classes in {dt:.1f} ms")
        return result
# Shared classifier instance, constructed once when this module is imported
# and used by the public prediction functions.
_MODEL = FruitClassifier()
| # ------------------- Public API ------------------- | |
def predict(image: Optional[Image.Image]) -> Dict[str, float]:
    """Classify a single image with the shared model.

    Args:
        image: The image to classify, or ``None`` when no image was supplied.

    Returns:
        The ``{label: score}`` mapping from the shared classifier, or
        ``{"Error": 1.0}`` when *image* is ``None``.
    """
    if image is not None:
        return _MODEL.predict_dict(image)
    return {"Error": 1.0}
def predict_batch(images: Iterable[Any]) -> List[Dict[str, float]]:
    """Classify several images, one result per input item.

    Args:
        images: Items that are either PIL images or something
            ``PIL.Image.open`` accepts. ``None`` or an empty iterable gives
            an empty result.

    Returns:
        A list aligned with *images*. Each entry is the ``{label: score}``
        mapping from the shared classifier, or ``{"Error": 1.0}`` when the
        item is ``None`` or cannot be opened as an image.
    """
    from PIL import Image as PILImage

    def _as_pil(obj: Any) -> Optional[PILImage.Image]:
        # Coerce obj to a PIL image; None signals "not usable".
        if obj is None:
            return None
        if isinstance(obj, PILImage.Image):
            return obj
        try:
            # convert() returns an independent, fully loaded copy, so the
            # opened source image can be closed right away instead of being
            # left for the garbage collector to release.
            with PILImage.open(obj) as opened:
                return opened.convert("RGB")
        except Exception:
            # Deliberate best effort: one unreadable item must not abort the
            # batch; it is reported as an error entry instead.
            return None

    outputs: List[Dict[str, float]] = []
    for item in images or []:
        pil_img = _as_pil(item)
        outputs.append({"Error": 1.0} if pil_img is None else _MODEL.predict_dict(pil_img))
    return outputs
# Names exported by `from <module> import *`; everything else is internal.
__all__ = ["predict", "predict_batch"]