Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,19 +3,48 @@ import tensorflow as tf
|
|
| 3 |
import numpy as np
|
| 4 |
from tensorflow.keras.applications.efficientnet import preprocess_input
|
| 5 |
from PIL import Image
|
|
|
|
| 6 |
|
| 7 |
# 1. Định nghĩa hằng số
|
| 8 |
MODEL_PATH = "best_model.h5" # Tên file mô hình đã upload
|
| 9 |
IMG_SIZE = (224, 224)
|
| 10 |
-
CLASS_NAMES = ['bad', 'good', 'very_good']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
# 2. Tải mô hình
|
| 13 |
-
|
| 14 |
try:
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
except Exception as e:
|
| 17 |
# Nếu mô hình không tải được (ví dụ: lỗi cấu trúc), in ra lỗi
|
| 18 |
-
print(f"Lỗi tải mô hình: {e}")
|
| 19 |
model = None
|
| 20 |
|
| 21 |
def predict_guava_quality(img_input):
|
|
@@ -52,5 +81,7 @@ demo = gr.Interface(
|
|
| 52 |
description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
|
| 53 |
)
|
| 54 |
|
| 55 |
-
# 4. Chạy Gradio App (Hugging Face Spaces sẽ tự
|
| 56 |
-
|
|
|
|
|
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
from tensorflow.keras.applications.efficientnet import preprocess_input
|
| 5 |
from PIL import Image
|
| 6 |
+
from tensorflow.keras.layers import InputLayer
|
| 7 |
|
| 8 |
# 1. Định nghĩa hằng số
|
| 9 |
MODEL_PATH = "best_model.h5" # Tên file mô hình đã upload
|
| 10 |
IMG_SIZE = (224, 224)
|
| 11 |
+
CLASS_NAMES = ['bad', 'good', 'very_good']
|
| 12 |
+
|
| 13 |
+
# =================================================================
|
| 14 |
+
# KHẮC PHỤC LỖI DESERIALIZATION (batch_shape)
|
| 15 |
+
# Nếu mô hình được lưu bằng Keras cũ, ta cần định nghĩa lại InputLayer
|
| 16 |
+
# để chấp nhận tham số batch_shape khi tải.
|
| 17 |
+
class FixedInputLayer(InputLayer):
|
| 18 |
+
"""
|
| 19 |
+
Tạo InputLayer cố định để xử lý lỗi "Unrecognized keyword arguments: ['batch_shape']"
|
| 20 |
+
khi tải mô hình đã lưu bằng TF/Keras cũ hơn.
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, **kwargs):
|
| 23 |
+
if 'batch_shape' in kwargs:
|
| 24 |
+
# Chuyển batch_shape thành input_shape
|
| 25 |
+
kwargs['input_shape'] = kwargs['batch_shape'][1:]
|
| 26 |
+
del kwargs['batch_shape']
|
| 27 |
+
super().__init__(**kwargs)
|
| 28 |
+
|
| 29 |
+
# =================================================================
|
| 30 |
|
| 31 |
# 2. Tải mô hình
|
| 32 |
+
model = None
|
| 33 |
try:
|
| 34 |
+
# Tắt log TensorFlow khi chạy trên nền tảng
|
| 35 |
+
tf.get_logger().setLevel('ERROR')
|
| 36 |
+
|
| 37 |
+
# SỬ DỤNG custom_objects KHI TẢI MÔ HÌNH
|
| 38 |
+
model = tf.keras.models.load_model(
|
| 39 |
+
MODEL_PATH,
|
| 40 |
+
custom_objects={
|
| 41 |
+
'InputLayer': FixedInputLayer # Sử dụng lớp FixedInputLayer để thay thế
|
| 42 |
+
}
|
| 43 |
+
)
|
| 44 |
+
print("✅ Mô hình đã được tải thành công.")
|
| 45 |
except Exception as e:
|
| 46 |
# Nếu mô hình không tải được (ví dụ: lỗi cấu trúc), in ra lỗi
|
| 47 |
+
print(f"❌ Lỗi tải mô hình: {e}")
|
| 48 |
model = None
|
| 49 |
|
| 50 |
def predict_guava_quality(img_input):
|
|
|
|
| 81 |
description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
|
| 82 |
)
|
| 83 |
|
| 84 |
+
# 4. Chạy Gradio App (Hugging Face Spaces sẽ tự chạy file này)
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
+
# Server port và name cần thiết để Gradio chạy trong môi trường container
|
| 87 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|