cang1602004 commited on
Commit
83fbb3a
·
verified ·
1 Parent(s): 738ca62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -6
app.py CHANGED
@@ -3,19 +3,48 @@ import tensorflow as tf
3
  import numpy as np
4
  from tensorflow.keras.applications.efficientnet import preprocess_input
5
  from PIL import Image
 
6
 
7
  # 1. Định nghĩa hằng số
8
  MODEL_PATH = "best_model.h5" # Tên file mô hình đã upload
9
  IMG_SIZE = (224, 224)
10
- CLASS_NAMES = ['bad', 'good', 'very_good'] # Dùng list cứng nếu không muốn dùng pickle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  # 2. Tải mô hình
13
- # Sử dụng try/except để xử lý lỗi nếu model.h5 quá lớn
14
  try:
15
- model = tf.keras.models.load_model(MODEL_PATH)
 
 
 
 
 
 
 
 
 
 
16
  except Exception as e:
17
  # Nếu mô hình không tải được (ví dụ: lỗi cấu trúc), in ra lỗi
18
- print(f"Lỗi tải mô hình: {e}")
19
  model = None
20
 
21
  def predict_guava_quality(img_input):
@@ -52,5 +81,7 @@ demo = gr.Interface(
52
  description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
53
  )
54
 
55
- # 4. Chạy Gradio App (Hugging Face Spaces sẽ tự gọi file này)
56
- # demo.launch()
 
 
 
3
  import numpy as np
4
  from tensorflow.keras.applications.efficientnet import preprocess_input
5
  from PIL import Image
6
+ from tensorflow.keras.layers import InputLayer
7
 
8
  # 1. Định nghĩa hằng số
9
  MODEL_PATH = "best_model.h5" # Tên file mô hình đã upload
10
  IMG_SIZE = (224, 224)
11
+ CLASS_NAMES = ['bad', 'good', 'very_good']
12
+
13
+ # =================================================================
14
+ # KHẮC PHỤC LỖI DESERIALIZATION (batch_shape)
15
+ # Nếu mô hình được lưu bằng Keras cũ, ta cần định nghĩa lại InputLayer
16
+ # để chấp nhận tham số batch_shape khi tải.
17
+ class FixedInputLayer(InputLayer):
18
+ """
19
+ Tạo InputLayer cố định để xử lý lỗi "Unrecognized keyword arguments: ['batch_shape']"
20
+ khi tải mô hình đã lưu bằng TF/Keras cũ hơn.
21
+ """
22
+ def __init__(self, **kwargs):
23
+ if 'batch_shape' in kwargs:
24
+ # Chuyển batch_shape thành input_shape
25
+ kwargs['input_shape'] = kwargs['batch_shape'][1:]
26
+ del kwargs['batch_shape']
27
+ super().__init__(**kwargs)
28
+
29
+ # =================================================================
30
 
31
  # 2. Tải mô hình
32
+ model = None
33
  try:
34
+ # Tắt log TensorFlow khi chạy trên nền tảng
35
+ tf.get_logger().setLevel('ERROR')
36
+
37
+ # SỬ DỤNG custom_objects KHI TẢI MÔ HÌNH
38
+ model = tf.keras.models.load_model(
39
+ MODEL_PATH,
40
+ custom_objects={
41
+ 'InputLayer': FixedInputLayer # Sử dụng lớp FixedInputLayer để thay thế
42
+ }
43
+ )
44
+ print("✅ Mô hình đã được tải thành công.")
45
  except Exception as e:
46
  # Nếu mô hình không tải được (ví dụ: lỗi cấu trúc), in ra lỗi
47
+ print(f"Lỗi tải mô hình: {e}")
48
  model = None
49
 
50
  def predict_guava_quality(img_input):
 
81
  description="Tải lên ảnh quả ổi để phân loại thành: Hàng xuất khẩu (very_good), Hàng nội địa (good), hoặc Loại bỏ (bad)."
82
  )
83
 
84
+ # 4. Chạy Gradio App (Hugging Face Spaces sẽ tự chạy file này)
85
+ if __name__ == "__main__":
86
+ # Server port và name cần thiết để Gradio chạy trong môi trường container
87
+ demo.launch(server_name="0.0.0.0", server_port=7860)