meat / inference.py
thienphuc12339's picture
Upload 5 files
fd1c26e verified
# inference.py
from pathlib import Path
import cv2
import numpy as np
# Ưu tiên tflite-runtime; fallback sang tensorflow.lite nếu cần
try:
from tflite_runtime.interpreter import Interpreter
except Exception:
import tensorflow as tf
Interpreter = lambda model_path: tf.lite.Interpreter(model_path=model_path)
MODEL_PATH = Path("models/best.tflite") # đổi tên nếu khác
CLASS_NAMES = ["FRESH", "HALF-FRESH", "SPOILED"] # sửa theo mô hình bạn
if not MODEL_PATH.exists():
raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
_interpreter = Interpreter(model_path=str(MODEL_PATH))
_interpreter.allocate_tensors()
_IN = _interpreter.get_input_details()[0]
_OUT = _interpreter.get_output_details()[0]
_, H, W, _ = _IN["shape"] # ví dụ (1,224,224,3)
def _preprocess(path: str) -> np.ndarray:
img = cv2.imread(path)
if img is None:
raise ValueError("Không đọc được ảnh")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
# Chuẩn hóa tùy theo dtype của tensor vào
if _IN["dtype"] == np.uint8 and "quantization" in _IN:
s, z = _IN["quantization"] # scale, zero-point
x = (img.astype(np.float32) / 255.0) / (s or 1.0) + (z or 0.0)
x = x.astype(np.uint8)
else:
x = img.astype(np.float32) / 255.0
return x[None] # (1,H,W,3)
def predict(image_path: str):
x = _preprocess(image_path)
_interpreter.set_tensor(_IN["index"], x)
_interpreter.invoke()
y = _interpreter.get_tensor(_OUT["index"])[0]
# Hậu xử lý cho đầu ra lượng tử hóa
if _OUT["dtype"] == np.uint8 and "quantization" in _OUT:
s, z = _OUT["quantization"]
y = (y.astype(np.float32) - (z or 0.0)) * (s or 1.0)
idx = int(y.argmax())
return CLASS_NAMES[idx], float(y[idx]) # (label, confidence)