cang1602004 commited on
Commit
700dd4e
·
verified ·
1 Parent(s): 1e6035c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -9
app.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  from tensorflow.keras.applications.efficientnet import preprocess_input
5
  from PIL import Image
6
  from tensorflow.keras.layers import InputLayer
 
 
7
 
8
  # 1. Định nghĩa hằng số
9
  MODEL_PATH = "best_model.h5" # Tên file mô hình đã upload (ĐÃ SỬA: Đồng bộ với tên file trên GitHub)
@@ -11,21 +13,32 @@ IMG_SIZE = (224, 224)
11
  CLASS_NAMES = ['bad', 'good', 'very_good']
12
 
13
  # =================================================================
14
- # KHẮC PHỤC LỖI DESERIALIZATION (batch_shape)
15
- # Nếu hình được lưu bằng Keras cũ, ta cần định nghĩa lại InputLayer
16
- # để chấp nhận tham số batch_shape khi tải.
17
  class FixedInputLayer(InputLayer):
18
- """
19
- Tạo InputLayer cố định để xử lý lỗi "Unrecognized keyword arguments: ['batch_shape']"
20
- khi tải mô hình đã lưu bằng TF/Keras cũ hơn.
21
- """
22
  def __init__(self, **kwargs):
23
  if 'batch_shape' in kwargs:
24
- # Chuyển batch_shape thành input_shape
25
  kwargs['input_shape'] = kwargs['batch_shape'][1:]
26
  del kwargs['batch_shape']
27
  super().__init__(**kwargs)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # =================================================================
30
 
31
  # 2. Tải mô hình
@@ -38,7 +51,12 @@ try:
38
  model = tf.keras.models.load_model(
39
  MODEL_PATH,
40
  custom_objects={
41
- 'InputLayer': FixedInputLayer # Sử dụng lớp FixedInputLayer để thay thế
 
 
 
 
 
42
  }
43
  )
44
  print("✅ Mô hình đã được tải thành công.")
 
4
  from tensorflow.keras.applications.efficientnet import preprocess_input
5
  from PIL import Image
6
  from tensorflow.keras.layers import InputLayer
7
+ # Import các lớp Augmentation mà mô hình có thể đã lưu
8
+ from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom, RandomContrast
9
 
10
  # 1. Định nghĩa hằng số
11
  MODEL_PATH = "best_model.h5" # Tên file mô hình đã upload (ĐÃ SỬA: Đồng bộ với tên file trên GitHub)
 
13
  CLASS_NAMES = ['bad', 'good', 'very_good']
14
 
15
  # =================================================================
16
+ # KHẮC PHỤC LỖI DESERIALIZATION CHO CÁC LỚP AUGMENTATION VÀ INPUT LAYER
17
+ # Mục đích: Bỏ qua tham số 'batch_shape' 'data_format' không tương thích.
18
+
19
  class FixedInputLayer(InputLayer):
20
+ """Xử lý lỗi batch_shape."""
 
 
 
21
  def __init__(self, **kwargs):
22
  if 'batch_shape' in kwargs:
 
23
  kwargs['input_shape'] = kwargs['batch_shape'][1:]
24
  del kwargs['batch_shape']
25
  super().__init__(**kwargs)
26
 
27
+ def fix_augmentation_layer(LayerClass):
28
+ """Tạo lớp cố định để loại bỏ tham số 'data_format'."""
29
+ class FixedLayer(LayerClass):
30
+ def __init__(self, **kwargs):
31
+ if 'data_format' in kwargs:
32
+ del kwargs['data_format']
33
+ super().__init__(**kwargs)
34
+ return FixedLayer
35
+
36
+ # Tạo các lớp cố định cho Augmentation
37
+ FixedRandomFlip = fix_augmentation_layer(RandomFlip)
38
+ FixedRandomRotation = fix_augmentation_layer(RandomRotation)
39
+ FixedRandomZoom = fix_augmentation_layer(RandomZoom)
40
+ FixedRandomContrast = fix_augmentation_layer(RandomContrast)
41
+
42
  # =================================================================
43
 
44
  # 2. Tải mô hình
 
51
  model = tf.keras.models.load_model(
52
  MODEL_PATH,
53
  custom_objects={
54
+ 'InputLayer': FixedInputLayer, # Fix InputLayer
55
+ 'RandomFlip': FixedRandomFlip, # Fix RandomFlip
56
+ 'RandomRotation': FixedRandomRotation, # Fix RandomRotation
57
+ 'RandomZoom': FixedRandomZoom, # Fix RandomZoom
58
+ 'RandomContrast': FixedRandomContrast # Fix RandomContrast
59
+ # Bổ sung các lớp Augmentation khác nếu cần
60
  }
61
  )
62
  print("✅ Mô hình đã được tải thành công.")