Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,139 +1,51 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import tensorflow as tf
|
| 3 |
import numpy as np
|
| 4 |
-
from tensorflow.keras.applications.efficientnet import preprocess_input
|
| 5 |
from PIL import Image
|
| 6 |
-
from tensorflow.keras.
|
| 7 |
-
from tensorflow.keras.layers import RandomWidth, RandomHeight
|
| 8 |
|
| 9 |
-
#
|
| 10 |
-
MODEL_PATH = "
|
| 11 |
IMG_SIZE = (224, 224)
|
| 12 |
-
CLASS_NAMES = ['bad', 'good', 'very_good']
|
| 13 |
-
|
| 14 |
-
# =================================================================
|
| 15 |
-
# KHẮC PHỤC LỖI TƯƠNG THÍCH PHIÊN BẢN (KERAS 3 -> KERAS 2)
|
| 16 |
-
# =================================================================
|
| 17 |
-
|
| 18 |
-
# 1. Mock DTypePolicy (Xử lý lỗi: Unknown dtype policy, Attribute name & compute_dtype)
|
| 19 |
-
class MockDTypePolicy:
|
| 20 |
-
"""
|
| 21 |
-
Lớp giả lập để thay thế DTypePolicy của Keras 3.
|
| 22 |
-
Giúp tránh lỗi deserialization khi chạy trên môi trường cũ.
|
| 23 |
-
"""
|
| 24 |
-
def __init__(self, **kwargs):
|
| 25 |
-
# SỬA LỖI: Thêm đầy đủ các thuộc tính mà Keras 3 yêu cầu
|
| 26 |
-
self.name = kwargs.get("name", "float32")
|
| 27 |
-
self.compute_dtype = kwargs.get("compute_dtype", "float32")
|
| 28 |
-
self.variable_dtype = kwargs.get("variable_dtype", "float32")
|
| 29 |
-
|
| 30 |
-
@classmethod
|
| 31 |
-
def from_config(cls, config):
|
| 32 |
-
return cls(**config)
|
| 33 |
-
|
| 34 |
-
def get_config(self):
|
| 35 |
-
return {
|
| 36 |
-
"name": self.name,
|
| 37 |
-
"compute_dtype": self.compute_dtype,
|
| 38 |
-
"variable_dtype": self.variable_dtype
|
| 39 |
-
}
|
| 40 |
-
|
| 41 |
-
# 2. Xử lý InputLayer (Xử lý lỗi: batch_shape)
|
| 42 |
-
class FixedInputLayer(InputLayer):
|
| 43 |
-
def __init__(self, **kwargs):
|
| 44 |
-
if 'batch_shape' in kwargs:
|
| 45 |
-
kwargs['input_shape'] = kwargs['batch_shape'][1:]
|
| 46 |
-
del kwargs['batch_shape']
|
| 47 |
-
# Xóa dtype nếu nó là dạng dictionary (config của Keras 3)
|
| 48 |
-
if 'dtype' in kwargs and isinstance(kwargs['dtype'], dict):
|
| 49 |
-
del kwargs['dtype']
|
| 50 |
-
super().__init__(**kwargs)
|
| 51 |
|
| 52 |
-
#
|
| 53 |
-
|
| 54 |
-
class FixedLayer(LayerClass):
|
| 55 |
-
def __init__(self, **kwargs):
|
| 56 |
-
# Danh sách các tham số gây lỗi tương thích giữa Keras 3 và 2
|
| 57 |
-
ignore_keys = ['data_format', 'dtype', 'value_range']
|
| 58 |
-
|
| 59 |
-
for key in ignore_keys:
|
| 60 |
-
if key in kwargs:
|
| 61 |
-
del kwargs[key]
|
| 62 |
-
|
| 63 |
-
super().__init__(**kwargs)
|
| 64 |
-
return FixedLayer
|
| 65 |
|
| 66 |
-
|
| 67 |
-
CUSTOM_OBJECTS = {
|
| 68 |
-
'InputLayer': FixedInputLayer,
|
| 69 |
-
# Đăng ký lớp giả lập DTypePolicy
|
| 70 |
-
'DTypePolicy': MockDTypePolicy,
|
| 71 |
-
# Augmentation Layers đã được vá lỗi
|
| 72 |
-
'RandomFlip': fix_augmentation_layer(RandomFlip),
|
| 73 |
-
'RandomRotation': fix_augmentation_layer(RandomRotation),
|
| 74 |
-
'RandomZoom': fix_augmentation_layer(RandomZoom),
|
| 75 |
-
'RandomContrast': fix_augmentation_layer(RandomContrast),
|
| 76 |
-
'RandomWidth': fix_augmentation_layer(RandomWidth),
|
| 77 |
-
'RandomHeight': fix_augmentation_layer(RandomHeight)
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
# =================================================================
|
| 81 |
-
|
| 82 |
-
# 2. Tải mô hình
|
| 83 |
-
model = None
|
| 84 |
-
try:
|
| 85 |
-
# Tắt log TensorFlow
|
| 86 |
-
tf.get_logger().setLevel('ERROR')
|
| 87 |
-
|
| 88 |
-
# Tải mô hình với danh sách custom objects đầy đủ
|
| 89 |
-
model = tf.keras.models.load_model(
|
| 90 |
-
MODEL_PATH,
|
| 91 |
-
custom_objects=CUSTOM_OBJECTS
|
| 92 |
-
)
|
| 93 |
-
print("✅ Mô hình đã được tải thành công.")
|
| 94 |
-
except Exception as e:
|
| 95 |
-
print(f"❌ Lỗi tải mô hình: {e}")
|
| 96 |
-
model = None
|
| 97 |
|
| 98 |
def predict_guava_quality(img_input):
|
| 99 |
-
if model is None:
|
| 100 |
-
return "❌ Lỗi: Không thể tải mô hình.", 0.0
|
| 101 |
-
|
| 102 |
if img_input is None:
|
| 103 |
-
return "❌ Vui lòng tải ảnh
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
img_resized = img_pil.resize(IMG_SIZE)
|
| 109 |
-
|
| 110 |
-
# Preprocess
|
| 111 |
-
arr = np.array(img_resized).astype("float32")
|
| 112 |
-
arr = preprocess_input(arr)
|
| 113 |
-
arr = np.expand_dims(arr, 0)
|
| 114 |
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
confidence = preds[idx]
|
| 119 |
-
label = CLASS_NAMES[idx]
|
| 120 |
|
| 121 |
-
|
| 122 |
-
except Exception as e:
|
| 123 |
-
return f"❌ Lỗi xử lý ảnh: {str(e)}", 0.0
|
| 124 |
|
| 125 |
-
# 3. Giao diện Gradio
|
| 126 |
demo = gr.Interface(
|
| 127 |
fn=predict_guava_quality,
|
| 128 |
inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
|
| 129 |
outputs=[
|
| 130 |
gr.Textbox(label="Dự đoán"),
|
| 131 |
-
gr.Number(label="Độ tin cậy (%)", precision=
|
| 132 |
],
|
| 133 |
-
title="Phân loại
|
| 134 |
-
description="
|
| 135 |
)
|
| 136 |
|
| 137 |
-
# 4. Chạy App
|
| 138 |
if __name__ == "__main__":
|
| 139 |
-
demo.launch(
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import tensorflow as tf
|
| 3 |
import numpy as np
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
+
from tensorflow.keras.applications.efficientnet import preprocess_input
|
|
|
|
| 6 |
|
| 7 |
+
# Đường dẫn model SavedModel
|
| 8 |
+
MODEL_PATH = "exported_model"
|
| 9 |
IMG_SIZE = (224, 224)
|
| 10 |
+
CLASS_NAMES = ['bad', 'good', 'very_good']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
# Load model
|
| 13 |
+
model = tf.saved_model.load(MODEL_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
infer = model.signatures["serving_default"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
def predict_guava_quality(img_input):
|
|
|
|
|
|
|
|
|
|
| 18 |
if img_input is None:
|
| 19 |
+
return "❌ Vui lòng tải ảnh", 0.0
|
| 20 |
+
|
| 21 |
+
# Convert image
|
| 22 |
+
img = Image.fromarray(img_input).convert("RGB")
|
| 23 |
+
img = img.resize(IMG_SIZE)
|
| 24 |
+
|
| 25 |
+
arr = np.array(img).astype("float32")
|
| 26 |
+
arr = preprocess_input(arr)
|
| 27 |
+
arr = np.expand_dims(arr, axis=0)
|
| 28 |
|
| 29 |
+
# TensorFlow serving
|
| 30 |
+
outputs = infer(tf.constant(arr))
|
| 31 |
+
preds = list(outputs.values())[0].numpy()[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
idx = np.argmax(preds)
|
| 34 |
+
confidence = preds[idx]
|
| 35 |
+
label = CLASS_NAMES[idx]
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
return f"✅ Kết quả: {label}", float(confidence)
|
|
|
|
|
|
|
| 38 |
|
|
|
|
| 39 |
demo = gr.Interface(
|
| 40 |
fn=predict_guava_quality,
|
| 41 |
inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
|
| 42 |
outputs=[
|
| 43 |
gr.Textbox(label="Dự đoán"),
|
| 44 |
+
gr.Number(label="Độ tin cậy (%)", precision=4)
|
| 45 |
],
|
| 46 |
+
title="Phân loại chất lượng Ổi",
|
| 47 |
+
description="Model EfficientNetB0 | very_good / good / bad"
|
| 48 |
)
|
| 49 |
|
|
|
|
| 50 |
if __name__ == "__main__":
|
| 51 |
+
demo.launch()
|