cang1602004 commited on
Commit
3f94404
·
verified ·
1 Parent(s): 77329fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -115
app.py CHANGED
@@ -1,139 +1,51 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
- from tensorflow.keras.applications.efficientnet import preprocess_input
5
  from PIL import Image
6
- from tensorflow.keras.layers import InputLayer, RandomFlip, RandomRotation, RandomZoom, RandomContrast
7
- from tensorflow.keras.layers import RandomWidth, RandomHeight
8
 
9
- # 1. Định nghĩa hằng số
10
- MODEL_PATH = "best_model.h5"
11
  IMG_SIZE = (224, 224)
12
- CLASS_NAMES = ['bad', 'good', 'very_good']
13
-
14
- # =================================================================
15
- # KHẮC PHỤC LỖI TƯƠNG THÍCH PHIÊN BẢN (KERAS 3 -> KERAS 2)
16
- # =================================================================
17
-
18
- # 1. Mock DTypePolicy (Xử lý lỗi: Unknown dtype policy, Attribute name & compute_dtype)
19
- class MockDTypePolicy:
20
- """
21
- Lớp giả lập để thay thế DTypePolicy của Keras 3.
22
- Giúp tránh lỗi deserialization khi chạy trên môi trường cũ.
23
- """
24
- def __init__(self, **kwargs):
25
- # SỬA LỖI: Thêm đầy đủ các thuộc tính mà Keras 3 yêu cầu
26
- self.name = kwargs.get("name", "float32")
27
- self.compute_dtype = kwargs.get("compute_dtype", "float32")
28
- self.variable_dtype = kwargs.get("variable_dtype", "float32")
29
-
30
- @classmethod
31
- def from_config(cls, config):
32
- return cls(**config)
33
-
34
- def get_config(self):
35
- return {
36
- "name": self.name,
37
- "compute_dtype": self.compute_dtype,
38
- "variable_dtype": self.variable_dtype
39
- }
40
-
41
- # 2. Xử lý InputLayer (Xử lý lỗi: batch_shape)
42
- class FixedInputLayer(InputLayer):
43
- def __init__(self, **kwargs):
44
- if 'batch_shape' in kwargs:
45
- kwargs['input_shape'] = kwargs['batch_shape'][1:]
46
- del kwargs['batch_shape']
47
- # Xóa dtype nếu nó là dạng dictionary (config của Keras 3)
48
- if 'dtype' in kwargs and isinstance(kwargs['dtype'], dict):
49
- del kwargs['dtype']
50
- super().__init__(**kwargs)
51
 
52
- # 3. Xử lý Augmentation Layers (Xử lý lỗi: data_format, dtype & value_range)
53
- def fix_augmentation_layer(LayerClass):
54
- class FixedLayer(LayerClass):
55
- def __init__(self, **kwargs):
56
- # Danh sách các tham số gây lỗi tương thích giữa Keras 3 và 2
57
- ignore_keys = ['data_format', 'dtype', 'value_range']
58
-
59
- for key in ignore_keys:
60
- if key in kwargs:
61
- del kwargs[key]
62
-
63
- super().__init__(**kwargs)
64
- return FixedLayer
65
 
66
- # 4. Đăng ký tất cả Custom Objects
67
- CUSTOM_OBJECTS = {
68
- 'InputLayer': FixedInputLayer,
69
- # Đăng ký lớp giả lập DTypePolicy
70
- 'DTypePolicy': MockDTypePolicy,
71
- # Augmentation Layers đã được vá lỗi
72
- 'RandomFlip': fix_augmentation_layer(RandomFlip),
73
- 'RandomRotation': fix_augmentation_layer(RandomRotation),
74
- 'RandomZoom': fix_augmentation_layer(RandomZoom),
75
- 'RandomContrast': fix_augmentation_layer(RandomContrast),
76
- 'RandomWidth': fix_augmentation_layer(RandomWidth),
77
- 'RandomHeight': fix_augmentation_layer(RandomHeight)
78
- }
79
-
80
- # =================================================================
81
-
82
- # 2. Tải mô hình
83
- model = None
84
- try:
85
- # Tắt log TensorFlow
86
- tf.get_logger().setLevel('ERROR')
87
-
88
- # Tải mô hình với danh sách custom objects đầy đủ
89
- model = tf.keras.models.load_model(
90
- MODEL_PATH,
91
- custom_objects=CUSTOM_OBJECTS
92
- )
93
- print("✅ Mô hình đã được tải thành công.")
94
- except Exception as e:
95
- print(f"❌ Lỗi tải mô hình: {e}")
96
- model = None
97
 
98
  def predict_guava_quality(img_input):
99
- if model is None:
100
- return "❌ Lỗi: Không thể tải mô hình.", 0.0
101
-
102
  if img_input is None:
103
- return "❌ Vui lòng tải ảnh lên.", 0.0
 
 
 
 
 
 
 
 
104
 
105
- try:
106
- # Chuyển đổi ảnh
107
- img_pil = Image.fromarray(img_input).convert("RGB")
108
- img_resized = img_pil.resize(IMG_SIZE)
109
-
110
- # Preprocess
111
- arr = np.array(img_resized).astype("float32")
112
- arr = preprocess_input(arr)
113
- arr = np.expand_dims(arr, 0)
114
 
115
- # Dự đoán
116
- preds = model.predict(arr)[0]
117
- idx = np.argmax(preds)
118
- confidence = preds[idx]
119
- label = CLASS_NAMES[idx]
120
 
121
- return f"✅ Kết quả: {label}", float(confidence)
122
- except Exception as e:
123
- return f"❌ Lỗi xử lý ảnh: {str(e)}", 0.0
124
 
125
- # 3. Giao diện Gradio
126
  demo = gr.Interface(
127
  fn=predict_guava_quality,
128
  inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
129
  outputs=[
130
  gr.Textbox(label="Dự đoán"),
131
- gr.Number(label="Độ tin cậy (%)", precision=2)
132
  ],
133
- title="Phân loại Chất lượng Quả Ổi (EfficientNetB0)",
134
- description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
135
  )
136
 
137
- # 4. Chạy App
138
  if __name__ == "__main__":
139
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
 
4
  from PIL import Image
5
+ from tensorflow.keras.applications.efficientnet import preprocess_input
 
6
 
7
+ # Đường dẫn model SavedModel
8
+ MODEL_PATH = "exported_model"
9
  IMG_SIZE = (224, 224)
10
+ CLASS_NAMES = ['bad', 'good', 'very_good']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Load model
13
+ model = tf.saved_model.load(MODEL_PATH)
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ infer = model.signatures["serving_default"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def predict_guava_quality(img_input):
 
 
 
18
  if img_input is None:
19
+ return "❌ Vui lòng tải ảnh", 0.0
20
+
21
+ # Convert image
22
+ img = Image.fromarray(img_input).convert("RGB")
23
+ img = img.resize(IMG_SIZE)
24
+
25
+ arr = np.array(img).astype("float32")
26
+ arr = preprocess_input(arr)
27
+ arr = np.expand_dims(arr, axis=0)
28
 
29
+ # TensorFlow serving
30
+ outputs = infer(tf.constant(arr))
31
+ preds = list(outputs.values())[0].numpy()[0]
 
 
 
 
 
 
32
 
33
+ idx = np.argmax(preds)
34
+ confidence = preds[idx]
35
+ label = CLASS_NAMES[idx]
 
 
36
 
37
+ return f"✅ Kết quả: {label}", float(confidence)
 
 
38
 
 
39
  demo = gr.Interface(
40
  fn=predict_guava_quality,
41
  inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
42
  outputs=[
43
  gr.Textbox(label="Dự đoán"),
44
+ gr.Number(label="Độ tin cậy (%)", precision=4)
45
  ],
46
+ title="Phân loại chất lượng Ổi",
47
+ description="Model EfficientNetB0 | very_good / good / bad"
48
  )
49
 
 
50
  if __name__ == "__main__":
51
+ demo.launch()