Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,65 +3,89 @@ import tensorflow as tf
|
|
| 3 |
import numpy as np
|
| 4 |
from tensorflow.keras.applications.efficientnet import preprocess_input
|
| 5 |
from PIL import Image
|
| 6 |
-
from tensorflow.keras.layers import InputLayer
|
| 7 |
-
|
| 8 |
-
from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom, RandomContrast
|
| 9 |
|
| 10 |
# 1. Định nghĩa hằng số
|
| 11 |
-
MODEL_PATH = "best_model.h5"
|
| 12 |
IMG_SIZE = (224, 224)
|
| 13 |
CLASS_NAMES = ['bad', 'good', 'very_good']
|
| 14 |
|
| 15 |
# =================================================================
|
| 16 |
-
# KHẮC PHỤC LỖI
|
| 17 |
-
#
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class FixedInputLayer(InputLayer):
|
| 20 |
-
"""Xử lý lỗi batch_shape."""
|
| 21 |
def __init__(self, **kwargs):
|
| 22 |
if 'batch_shape' in kwargs:
|
| 23 |
kwargs['input_shape'] = kwargs['batch_shape'][1:]
|
| 24 |
del kwargs['batch_shape']
|
|
|
|
|
|
|
|
|
|
| 25 |
super().__init__(**kwargs)
|
| 26 |
|
|
|
|
| 27 |
def fix_augmentation_layer(LayerClass):
|
| 28 |
-
"""Tạo lớp cố định để loại bỏ tham số 'data_format'."""
|
| 29 |
class FixedLayer(LayerClass):
|
| 30 |
def __init__(self, **kwargs):
|
|
|
|
| 31 |
if 'data_format' in kwargs:
|
| 32 |
del kwargs['data_format']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
super().__init__(**kwargs)
|
| 34 |
return FixedLayer
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# =================================================================
|
| 43 |
|
| 44 |
# 2. Tải mô hình
|
| 45 |
model = None
|
| 46 |
try:
|
| 47 |
-
# Tắt log TensorFlow
|
| 48 |
tf.get_logger().setLevel('ERROR')
|
| 49 |
|
| 50 |
-
#
|
| 51 |
model = tf.keras.models.load_model(
|
| 52 |
MODEL_PATH,
|
| 53 |
-
custom_objects=
|
| 54 |
-
'InputLayer': FixedInputLayer, # Fix InputLayer
|
| 55 |
-
'RandomFlip': FixedRandomFlip, # Fix RandomFlip
|
| 56 |
-
'RandomRotation': FixedRandomRotation, # Fix RandomRotation
|
| 57 |
-
'RandomZoom': FixedRandomZoom, # Fix RandomZoom
|
| 58 |
-
'RandomContrast': FixedRandomContrast # Fix RandomContrast
|
| 59 |
-
# Bổ sung các lớp Augmentation khác nếu cần
|
| 60 |
-
}
|
| 61 |
)
|
| 62 |
print("✅ Mô hình đã được tải thành công.")
|
| 63 |
except Exception as e:
|
| 64 |
-
# Nếu mô hình không tải được (ví dụ: lỗi cấu trúc), in ra lỗi
|
| 65 |
print(f"❌ Lỗi tải mô hình: {e}")
|
| 66 |
model = None
|
| 67 |
|
|
@@ -69,25 +93,30 @@ def predict_guava_quality(img_input):
|
|
| 69 |
if model is None:
|
| 70 |
return "❌ Lỗi: Không thể tải mô hình.", 0.0
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
|
|
|
| 89 |
|
| 90 |
-
# 3.
|
| 91 |
demo = gr.Interface(
|
| 92 |
fn=predict_guava_quality,
|
| 93 |
inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
|
|
@@ -99,7 +128,6 @@ demo = gr.Interface(
|
|
| 99 |
description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
|
| 100 |
)
|
| 101 |
|
| 102 |
-
# 4. Chạy
|
| 103 |
if __name__ == "__main__":
|
| 104 |
-
# Server port và name cần thiết để Gradio chạy trong môi trường container
|
| 105 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
from tensorflow.keras.applications.efficientnet import preprocess_input
|
| 5 |
from PIL import Image
|
| 6 |
+
from tensorflow.keras.layers import InputLayer, RandomFlip, RandomRotation, RandomZoom, RandomContrast
|
| 7 |
+
from tensorflow.keras.layers import RandomWidth, RandomHeight
|
|
|
|
| 8 |
|
| 9 |
# 1. Định nghĩa hằng số
|
| 10 |
+
MODEL_PATH = "best_model.h5"
|
| 11 |
IMG_SIZE = (224, 224)
|
| 12 |
CLASS_NAMES = ['bad', 'good', 'very_good']
|
| 13 |
|
| 14 |
# =================================================================
|
| 15 |
+
# KHẮC PHỤC LỖI TƯƠNG THÍCH PHIÊN BẢN (KERAS 3 -> KERAS 2)
|
| 16 |
+
# =================================================================
|
| 17 |
|
| 18 |
+
# 1. Mock DTypePolicy (Xử lý lỗi: Unknown dtype policy: 'DTypePolicy')
|
| 19 |
+
class MockDTypePolicy:
|
| 20 |
+
"""
|
| 21 |
+
Lớp giả lập để thay thế DTypePolicy của Keras 3.
|
| 22 |
+
Giúp tránh lỗi deserialization khi chạy trên môi trường cũ.
|
| 23 |
+
"""
|
| 24 |
+
def __init__(self, **kwargs):
|
| 25 |
+
pass # Không làm gì cả, chỉ cần tồn tại để không báo lỗi
|
| 26 |
+
|
| 27 |
+
@classmethod
|
| 28 |
+
def from_config(cls, config):
|
| 29 |
+
return cls(**config)
|
| 30 |
+
|
| 31 |
+
def get_config(self):
|
| 32 |
+
return {}
|
| 33 |
+
|
| 34 |
+
# 2. Xử lý InputLayer (Xử lý lỗi: batch_shape)
|
| 35 |
class FixedInputLayer(InputLayer):
|
|
|
|
| 36 |
def __init__(self, **kwargs):
|
| 37 |
if 'batch_shape' in kwargs:
|
| 38 |
kwargs['input_shape'] = kwargs['batch_shape'][1:]
|
| 39 |
del kwargs['batch_shape']
|
| 40 |
+
# Xóa dtype nếu nó là dạng dictionary (config của Keras 3)
|
| 41 |
+
if 'dtype' in kwargs and isinstance(kwargs['dtype'], dict):
|
| 42 |
+
del kwargs['dtype']
|
| 43 |
super().__init__(**kwargs)
|
| 44 |
|
| 45 |
+
# 3. Xử lý Augmentation Layers (Xử lý lỗi: data_format & dtype policy)
|
| 46 |
def fix_augmentation_layer(LayerClass):
|
|
|
|
| 47 |
class FixedLayer(LayerClass):
|
| 48 |
def __init__(self, **kwargs):
|
| 49 |
+
# Xóa tham số data_format gây lỗi
|
| 50 |
if 'data_format' in kwargs:
|
| 51 |
del kwargs['data_format']
|
| 52 |
+
|
| 53 |
+
# Xóa tham số dtype gây lỗi (DTypePolicy)
|
| 54 |
+
if 'dtype' in kwargs:
|
| 55 |
+
del kwargs['dtype']
|
| 56 |
+
|
| 57 |
super().__init__(**kwargs)
|
| 58 |
return FixedLayer
|
| 59 |
|
| 60 |
+
# 4. Đăng ký tất cả Custom Objects
|
| 61 |
+
CUSTOM_OBJECTS = {
|
| 62 |
+
'InputLayer': FixedInputLayer,
|
| 63 |
+
# Đăng ký lớp giả lập DTypePolicy
|
| 64 |
+
'DTypePolicy': MockDTypePolicy,
|
| 65 |
+
# Augmentation Layers đã được vá lỗi
|
| 66 |
+
'RandomFlip': fix_augmentation_layer(RandomFlip),
|
| 67 |
+
'RandomRotation': fix_augmentation_layer(RandomRotation),
|
| 68 |
+
'RandomZoom': fix_augmentation_layer(RandomZoom),
|
| 69 |
+
'RandomContrast': fix_augmentation_layer(RandomContrast),
|
| 70 |
+
'RandomWidth': fix_augmentation_layer(RandomWidth),
|
| 71 |
+
'RandomHeight': fix_augmentation_layer(RandomHeight)
|
| 72 |
+
}
|
| 73 |
|
| 74 |
# =================================================================
|
| 75 |
|
| 76 |
# 2. Tải mô hình
|
| 77 |
model = None
|
| 78 |
try:
|
| 79 |
+
# Tắt log TensorFlow
|
| 80 |
tf.get_logger().setLevel('ERROR')
|
| 81 |
|
| 82 |
+
# Tải mô hình với danh sách custom objects đầy đủ
|
| 83 |
model = tf.keras.models.load_model(
|
| 84 |
MODEL_PATH,
|
| 85 |
+
custom_objects=CUSTOM_OBJECTS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
)
|
| 87 |
print("✅ Mô hình đã được tải thành công.")
|
| 88 |
except Exception as e:
|
|
|
|
| 89 |
print(f"❌ Lỗi tải mô hình: {e}")
|
| 90 |
model = None
|
| 91 |
|
|
|
|
| 93 |
if model is None:
|
| 94 |
return "❌ Lỗi: Không thể tải mô hình.", 0.0
|
| 95 |
|
| 96 |
+
if img_input is None:
|
| 97 |
+
return "❌ Vui lòng tải ảnh lên.", 0.0
|
| 98 |
+
|
| 99 |
+
try:
|
| 100 |
+
# Chuyển đổi ảnh
|
| 101 |
+
img_pil = Image.fromarray(img_input).convert("RGB")
|
| 102 |
+
img_resized = img_pil.resize(IMG_SIZE)
|
| 103 |
+
|
| 104 |
+
# Preprocess
|
| 105 |
+
arr = np.array(img_resized).astype("float32")
|
| 106 |
+
arr = preprocess_input(arr)
|
| 107 |
+
arr = np.expand_dims(arr, 0)
|
| 108 |
|
| 109 |
+
# Dự đoán
|
| 110 |
+
preds = model.predict(arr)[0]
|
| 111 |
+
idx = np.argmax(preds)
|
| 112 |
+
confidence = preds[idx]
|
| 113 |
+
label = CLASS_NAMES[idx]
|
| 114 |
|
| 115 |
+
return f"✅ Kết quả: {label}", float(confidence)
|
| 116 |
+
except Exception as e:
|
| 117 |
+
return f"❌ Lỗi xử lý ảnh: {str(e)}", 0.0
|
| 118 |
|
| 119 |
+
# 3. Giao diện Gradio
|
| 120 |
demo = gr.Interface(
|
| 121 |
fn=predict_guava_quality,
|
| 122 |
inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
|
|
|
|
| 128 |
description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
|
| 129 |
)
|
| 130 |
|
| 131 |
+
# 4. Chạy App
|
| 132 |
if __name__ == "__main__":
|
|
|
|
| 133 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|