Spaces:
Sleeping
Sleeping
| # inference.py | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| # Ưu tiên tflite-runtime; fallback sang tensorflow.lite nếu cần | |
| try: | |
| from tflite_runtime.interpreter import Interpreter | |
| except Exception: | |
| import tensorflow as tf | |
| Interpreter = lambda model_path: tf.lite.Interpreter(model_path=model_path) | |
| MODEL_PATH = Path("models/best.tflite") # đổi tên nếu khác | |
| CLASS_NAMES = ["FRESH", "HALF-FRESH", "SPOILED"] # sửa theo mô hình bạn | |
| if not MODEL_PATH.exists(): | |
| raise FileNotFoundError(f"Model not found: {MODEL_PATH}") | |
| _interpreter = Interpreter(model_path=str(MODEL_PATH)) | |
| _interpreter.allocate_tensors() | |
| _IN = _interpreter.get_input_details()[0] | |
| _OUT = _interpreter.get_output_details()[0] | |
| _, H, W, _ = _IN["shape"] # ví dụ (1,224,224,3) | |
| def _preprocess(path: str) -> np.ndarray: | |
| img = cv2.imread(path) | |
| if img is None: | |
| raise ValueError("Không đọc được ảnh") | |
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
| img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) | |
| # Chuẩn hóa tùy theo dtype của tensor vào | |
| if _IN["dtype"] == np.uint8 and "quantization" in _IN: | |
| s, z = _IN["quantization"] # scale, zero-point | |
| x = (img.astype(np.float32) / 255.0) / (s or 1.0) + (z or 0.0) | |
| x = x.astype(np.uint8) | |
| else: | |
| x = img.astype(np.float32) / 255.0 | |
| return x[None] # (1,H,W,3) | |
| def predict(image_path: str): | |
| x = _preprocess(image_path) | |
| _interpreter.set_tensor(_IN["index"], x) | |
| _interpreter.invoke() | |
| y = _interpreter.get_tensor(_OUT["index"])[0] | |
| # Hậu xử lý cho đầu ra lượng tử hóa | |
| if _OUT["dtype"] == np.uint8 and "quantization" in _OUT: | |
| s, z = _OUT["quantization"] | |
| y = (y.astype(np.float32) - (z or 0.0)) * (s or 1.0) | |
| idx = int(y.argmax()) | |
| return CLASS_NAMES[idx], float(y[idx]) # (label, confidence) | |