PandaLT commited on
Commit
c7801ef
·
verified ·
1 Parent(s): e890977

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +72 -0
  2. cnn_model.h5 +3 -0
  3. label_encoder.joblib +3 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import numpy as np
4
+ from PIL import Image
5
+ import tensorflow as tf
6
+ import joblib
7
+
8
+ # 1. Load Model CNN (Keras)
9
+ try:
10
+ model = tf.keras.models.load_model('cnn_model.h5')
11
+ print("Load model thành công")
12
+ except Exception as e:
13
+ print(f"Lỗi load model: {e}")
14
+ model = None
15
+
16
+ # 2. Load Label Encoder
17
+ try:
18
+ label_encoder = joblib.load('label_encoder.joblib')
19
+ print("Load encoder thành công")
20
+ except:
21
+ label_encoder = None
22
+
23
+ def preprocess_image(image):
24
+ # Chuyển sang RGB
25
+ image = image.convert("RGB")
26
+ # Resize đúng kích thước lúc train (64x64)
27
+ image = image.resize((64, 64))
28
+
29
+ # Chuyển thành mảng numpy và chuẩn hóa / 255.0
30
+ image_array = np.array(image) / 255.0
31
+
32
+ # QUAN TRỌNG: CNN cần input 4 chiều (Batch, Height, Width, Channel)
33
+ # Nên ta phải thêm 1 chiều ở đầu: (64,64,3) -> (1, 64, 64, 3)
34
+ image_array = np.expand_dims(image_array, axis=0)
35
+
36
+ return image_array
37
+
38
+ def predict(image):
39
+ if model is None or label_encoder is None:
40
+ return "Lỗi: Chưa load được model hoặc encoder."
41
+
42
+ try:
43
+ # Xử lý ảnh
44
+ processed_img = preprocess_image(image)
45
+
46
+ # Dự đoán
47
+ prediction = model.predict(processed_img)
48
+
49
+ # Lấy vị trí có xác suất cao nhất (argmax)
50
+ class_index = np.argmax(prediction)
51
+
52
+ # Chuyển từ số về tên nhãn
53
+ class_name = label_encoder.inverse_transform([class_index])[0]
54
+
55
+ # Lấy độ tin cậy (xác suất)
56
+ confidence = float(np.max(prediction))
57
+
58
+ return f"Kết quả: {class_name} ({confidence*100:.2f}%)"
59
+ except Exception as e:
60
+ return f"Lỗi dự đoán: {str(e)}"
61
+
62
+ # Giao diện Gradio
63
+ iface = gr.Interface(
64
+ fn=predict,
65
+ inputs=gr.Image(type="pil"),
66
+ outputs="text",
67
+ title="Nhận diện Biển báo Giao thông (CNN)",
68
+ description="Upload ảnh biển báo để model CNN dự đoán."
69
+ )
70
+
71
+ if __name__ == "__main__":
72
+ iface.launch()
cnn_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f7ee3a6ac7b9cd4ca216ac7325c0dda935b06ab487e5149cadaf3a5ef716615
3
+ size 3199248
label_encoder.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40dae1d4f8eb9eaa66026e1d3756038d6d74e955054e08588f3987c208040e2c
3
+ size 343
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ tensorflow
3
+ numpy
4
+ Pillow
5
+ joblib
6
+ scikit-learn