# inference.py from pathlib import Path import cv2 import numpy as np # Ưu tiên tflite-runtime; fallback sang tensorflow.lite nếu cần try: from tflite_runtime.interpreter import Interpreter except Exception: import tensorflow as tf Interpreter = lambda model_path: tf.lite.Interpreter(model_path=model_path) MODEL_PATH = Path("models/best.tflite") # đổi tên nếu khác CLASS_NAMES = ["FRESH", "HALF-FRESH", "SPOILED"] # sửa theo mô hình bạn if not MODEL_PATH.exists(): raise FileNotFoundError(f"Model not found: {MODEL_PATH}") _interpreter = Interpreter(model_path=str(MODEL_PATH)) _interpreter.allocate_tensors() _IN = _interpreter.get_input_details()[0] _OUT = _interpreter.get_output_details()[0] _, H, W, _ = _IN["shape"] # ví dụ (1,224,224,3) def _preprocess(path: str) -> np.ndarray: img = cv2.imread(path) if img is None: raise ValueError("Không đọc được ảnh") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) # Chuẩn hóa tùy theo dtype của tensor vào if _IN["dtype"] == np.uint8 and "quantization" in _IN: s, z = _IN["quantization"] # scale, zero-point x = (img.astype(np.float32) / 255.0) / (s or 1.0) + (z or 0.0) x = x.astype(np.uint8) else: x = img.astype(np.float32) / 255.0 return x[None] # (1,H,W,3) def predict(image_path: str): x = _preprocess(image_path) _interpreter.set_tensor(_IN["index"], x) _interpreter.invoke() y = _interpreter.get_tensor(_OUT["index"])[0] # Hậu xử lý cho đầu ra lượng tử hóa if _OUT["dtype"] == np.uint8 and "quantization" in _OUT: s, z = _OUT["quantization"] y = (y.astype(np.float32) - (z or 0.0)) * (s or 1.0) idx = int(y.argmax()) return CLASS_NAMES[idx], float(y[idx]) # (label, confidence)