cang1602004 commited on
Commit
7fd0239
·
verified ·
1 Parent(s): 700dd4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -42
app.py CHANGED
@@ -3,65 +3,89 @@ import tensorflow as tf
3
  import numpy as np
4
  from tensorflow.keras.applications.efficientnet import preprocess_input
5
  from PIL import Image
6
- from tensorflow.keras.layers import InputLayer
7
- # Import các lớp Augmentation mà mô hình có thể đã lưu
8
- from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom, RandomContrast
9
 
10
  # 1. Định nghĩa hằng số
11
- MODEL_PATH = "best_model.h5" # Tên file mô hình đã upload (ĐÃ SỬA: Đồng bộ với tên file trên GitHub)
12
  IMG_SIZE = (224, 224)
13
  CLASS_NAMES = ['bad', 'good', 'very_good']
14
 
15
  # =================================================================
16
- # KHẮC PHỤC LỖI DESERIALIZATION CHO CÁC LỚP AUGMENTATION INPUT LAYER
17
- # Mục đích: Bỏ qua tham số 'batch_shape' và 'data_format' không tương thích.
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class FixedInputLayer(InputLayer):
20
- """Xử lý lỗi batch_shape."""
21
  def __init__(self, **kwargs):
22
  if 'batch_shape' in kwargs:
23
  kwargs['input_shape'] = kwargs['batch_shape'][1:]
24
  del kwargs['batch_shape']
 
 
 
25
  super().__init__(**kwargs)
26
 
 
27
  def fix_augmentation_layer(LayerClass):
28
- """Tạo lớp cố định để loại bỏ tham số 'data_format'."""
29
  class FixedLayer(LayerClass):
30
  def __init__(self, **kwargs):
 
31
  if 'data_format' in kwargs:
32
  del kwargs['data_format']
 
 
 
 
 
33
  super().__init__(**kwargs)
34
  return FixedLayer
35
 
36
- # Tạo các lớp cố định cho Augmentation
37
- FixedRandomFlip = fix_augmentation_layer(RandomFlip)
38
- FixedRandomRotation = fix_augmentation_layer(RandomRotation)
39
- FixedRandomZoom = fix_augmentation_layer(RandomZoom)
40
- FixedRandomContrast = fix_augmentation_layer(RandomContrast)
 
 
 
 
 
 
 
 
41
 
42
  # =================================================================
43
 
44
  # 2. Tải mô hình
45
  model = None
46
  try:
47
- # Tắt log TensorFlow khi chạy trên nền tảng
48
  tf.get_logger().setLevel('ERROR')
49
 
50
- # SỬ DỤNG custom_objects KHI TẢI HÌNH
51
  model = tf.keras.models.load_model(
52
  MODEL_PATH,
53
- custom_objects={
54
- 'InputLayer': FixedInputLayer, # Fix InputLayer
55
- 'RandomFlip': FixedRandomFlip, # Fix RandomFlip
56
- 'RandomRotation': FixedRandomRotation, # Fix RandomRotation
57
- 'RandomZoom': FixedRandomZoom, # Fix RandomZoom
58
- 'RandomContrast': FixedRandomContrast # Fix RandomContrast
59
- # Bổ sung các lớp Augmentation khác nếu cần
60
- }
61
  )
62
  print("✅ Mô hình đã được tải thành công.")
63
  except Exception as e:
64
- # Nếu mô hình không tải được (ví dụ: lỗi cấu trúc), in ra lỗi
65
  print(f"❌ Lỗi tải mô hình: {e}")
66
  model = None
67
 
@@ -69,25 +93,30 @@ def predict_guava_quality(img_input):
69
  if model is None:
70
  return "❌ Lỗi: Không thể tải mô hình.", 0.0
71
 
72
- # Chuyển đổi từ numpy array của Gradio sang PIL Image và resize
73
- img_pil = Image.fromarray(img_input).convert("RGB")
74
- img_resized = img_pil.resize(IMG_SIZE)
75
-
76
- # Preprocess
77
- arr = np.array(img_resized).astype("float32")
78
- arr = preprocess_input(arr)
79
- arr = np.expand_dims(arr, 0) # Thêm dimension batch size
 
 
 
 
80
 
81
- # Dự đoán
82
- preds = model.predict(arr)[0]
83
- idx = np.argmax(preds)
84
- confidence = preds[idx]
85
- label = CLASS_NAMES[idx]
86
 
87
- # Trả về kết quả
88
- return f"✅ Kết quả: {label}", float(confidence)
 
89
 
90
- # 3. Định nghĩa giao diện Gradio
91
  demo = gr.Interface(
92
  fn=predict_guava_quality,
93
  inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
@@ -99,7 +128,6 @@ demo = gr.Interface(
99
  description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
100
  )
101
 
102
- # 4. Chạy Gradio App (Hugging Face Spaces sẽ tự chạy file này)
103
  if __name__ == "__main__":
104
- # Server port và name cần thiết để Gradio chạy trong môi trường container
105
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
3
  import numpy as np
4
  from tensorflow.keras.applications.efficientnet import preprocess_input
5
  from PIL import Image
6
+ from tensorflow.keras.layers import InputLayer, RandomFlip, RandomRotation, RandomZoom, RandomContrast
7
+ from tensorflow.keras.layers import RandomWidth, RandomHeight
 
8
 
9
  # 1. Định nghĩa hằng số
10
+ MODEL_PATH = "best_model.h5"
11
  IMG_SIZE = (224, 224)
12
  CLASS_NAMES = ['bad', 'good', 'very_good']
13
 
14
  # =================================================================
15
+ # KHẮC PHỤC LỖI TƯƠNG THÍCH PHIÊN BẢN (KERAS 3 -> KERAS 2)
16
+ # =================================================================
17
 
18
+ # 1. Mock DTypePolicy (Xử lý lỗi: Unknown dtype policy: 'DTypePolicy')
19
+ class MockDTypePolicy:
20
+ """
21
+ Lớp giả lập để thay thế DTypePolicy của Keras 3.
22
+ Giúp tránh lỗi deserialization khi chạy trên môi trường cũ.
23
+ """
24
+ def __init__(self, **kwargs):
25
+ pass # Không làm gì cả, chỉ cần tồn tại để không báo lỗi
26
+
27
+ @classmethod
28
+ def from_config(cls, config):
29
+ return cls(**config)
30
+
31
+ def get_config(self):
32
+ return {}
33
+
34
+ # 2. Xử lý InputLayer (Xử lý lỗi: batch_shape)
35
  class FixedInputLayer(InputLayer):
 
36
  def __init__(self, **kwargs):
37
  if 'batch_shape' in kwargs:
38
  kwargs['input_shape'] = kwargs['batch_shape'][1:]
39
  del kwargs['batch_shape']
40
+ # Xóa dtype nếu nó là dạng dictionary (config của Keras 3)
41
+ if 'dtype' in kwargs and isinstance(kwargs['dtype'], dict):
42
+ del kwargs['dtype']
43
  super().__init__(**kwargs)
44
 
45
+ # 3. Xử lý Augmentation Layers (Xử lý lỗi: data_format & dtype policy)
46
  def fix_augmentation_layer(LayerClass):
 
47
  class FixedLayer(LayerClass):
48
  def __init__(self, **kwargs):
49
+ # Xóa tham số data_format gây lỗi
50
  if 'data_format' in kwargs:
51
  del kwargs['data_format']
52
+
53
+ # Xóa tham số dtype gây lỗi (DTypePolicy)
54
+ if 'dtype' in kwargs:
55
+ del kwargs['dtype']
56
+
57
  super().__init__(**kwargs)
58
  return FixedLayer
59
 
60
+ # 4. Đăng tất cả Custom Objects
61
+ CUSTOM_OBJECTS = {
62
+ 'InputLayer': FixedInputLayer,
63
+ # Đăng ký lớp giả lập DTypePolicy
64
+ 'DTypePolicy': MockDTypePolicy,
65
+ # Augmentation Layers đã được vá lỗi
66
+ 'RandomFlip': fix_augmentation_layer(RandomFlip),
67
+ 'RandomRotation': fix_augmentation_layer(RandomRotation),
68
+ 'RandomZoom': fix_augmentation_layer(RandomZoom),
69
+ 'RandomContrast': fix_augmentation_layer(RandomContrast),
70
+ 'RandomWidth': fix_augmentation_layer(RandomWidth),
71
+ 'RandomHeight': fix_augmentation_layer(RandomHeight)
72
+ }
73
 
74
  # =================================================================
75
 
76
  # 2. Tải mô hình
77
  model = None
78
  try:
79
+ # Tắt log TensorFlow
80
  tf.get_logger().setLevel('ERROR')
81
 
82
+ # Tải hình với danh sách custom objects đầy đủ
83
  model = tf.keras.models.load_model(
84
  MODEL_PATH,
85
+ custom_objects=CUSTOM_OBJECTS
 
 
 
 
 
 
 
86
  )
87
  print("✅ Mô hình đã được tải thành công.")
88
  except Exception as e:
 
89
  print(f"❌ Lỗi tải mô hình: {e}")
90
  model = None
91
 
 
93
  if model is None:
94
  return "❌ Lỗi: Không thể tải mô hình.", 0.0
95
 
96
+ if img_input is None:
97
+ return "❌ Vui lòng tải ảnh lên.", 0.0
98
+
99
+ try:
100
+ # Chuyển đổi ảnh
101
+ img_pil = Image.fromarray(img_input).convert("RGB")
102
+ img_resized = img_pil.resize(IMG_SIZE)
103
+
104
+ # Preprocess
105
+ arr = np.array(img_resized).astype("float32")
106
+ arr = preprocess_input(arr)
107
+ arr = np.expand_dims(arr, 0)
108
 
109
+ # Dự đoán
110
+ preds = model.predict(arr)[0]
111
+ idx = np.argmax(preds)
112
+ confidence = preds[idx]
113
+ label = CLASS_NAMES[idx]
114
 
115
+ return f"✅ Kết quả: {label}", float(confidence)
116
+ except Exception as e:
117
+ return f"❌ Lỗi xử lý ảnh: {str(e)}", 0.0
118
 
119
+ # 3. Giao diện Gradio
120
  demo = gr.Interface(
121
  fn=predict_guava_quality,
122
  inputs=gr.Image(type="numpy", label="Tải ảnh Quả Ổi"),
 
128
  description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
129
  )
130
 
131
+ # 4. Chạy App
132
  if __name__ == "__main__":
 
133
  demo.launch(server_name="0.0.0.0", server_port=7860)