PandaLT commited on
Commit
dfea709
·
verified ·
1 Parent(s): 5f413e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -71
app.py CHANGED
@@ -1,71 +1,67 @@
1
- import os
2
- import numpy as np
3
- import tensorflow as tf
4
- from PIL import Image
5
- import gradio as gr
6
- import pickle
7
-
8
- # Thông số
9
- IMG_HEIGHT = 64
10
- IMG_WIDTH = 64
11
-
12
- # Load model và label encoder
13
- model = tf.keras.models.load_model('traffic_sign_model.h5')
14
- with open('label_encoder.pkl', 'rb') as f:
15
- le = pickle.load(f)
16
-
17
- def predict_traffic_sign(image):
18
- """
19
- Hàm dự đoán biển báo giao thông từ ảnh đầu vào
20
- """
21
- try:
22
- # Xử lý ảnh đầu vào
23
- img = Image.fromarray(image.astype('uint8')).convert('RGB')
24
- img = img.resize((IMG_HEIGHT, IMG_WIDTH))
25
- img_array = np.array(img).astype('float32') / 255.0
26
- img_array = np.expand_dims(img_array, axis=0)
27
-
28
- # Dự đoán
29
- predictions = model.predict(img_array, verbose=0)
30
- predicted_class_idx = np.argmax(predictions[0])
31
- confidence = predictions[0][predicted_class_idx]
32
-
33
- # Lấy tên lớp
34
- predicted_class_name = le.inverse_transform([predicted_class_idx])[0]
35
-
36
- # Tạo dictionary kết quả cho tất cả các lớp
37
- results = {}
38
- for idx, class_name in enumerate(le.classes_):
39
- results[class_name] = float(predictions[0][idx])
40
-
41
- return results
42
-
43
- except Exception as e:
44
- return {f"Error: {str(e)}": 0.0}
45
-
46
- # Tạo Gradio Interface
47
- demo = gr.Interface(
48
- fn=predict_traffic_sign,
49
- inputs=gr.Image(label="Tải ảnh biển báo giao thông"),
50
- outputs=gr.Label(num_top_classes=5, label="Kết quả dự đoán"),
51
- title="🚦 Nhận diện Biển báo Giao thông",
52
- description="""
53
- **Upload một ảnh biển báo giao thông để nhận diện.**
54
-
55
- Model CNN được huấn luyện để phân loại các loại biển báo giao thông Việt Nam.
56
-
57
- 📊 Kết quả hiển thị top 5 dự đoán có xác suất cao nhất.
58
- """,
59
- examples=[
60
- # Thêm đường dẫn đến ảnh mẫu trong thư mục examples/
61
- # ["examples/stop_sign.jpg"],
62
- # ["examples/speed_limit.jpg"],
63
- ],
64
- theme=gr.themes.Soft(),
65
- allow_flagging="never",
66
- analytics_enabled=False
67
- )
68
-
69
- # Launch
70
- if __name__ == "__main__":
71
- demo.launch()
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import pickle
7
+
8
+ # Thông số
9
+ IMG_HEIGHT = 64
10
+ IMG_WIDTH = 64
11
+
12
+ # Load model và label encoder
13
+ model = tf.keras.models.load_model('traffic_sign_model.keras')
14
+ with open('label_encoder.pkl', 'rb') as f:
15
+ le = pickle.load(f)
16
+
17
+ def predict_traffic_sign(image):
18
+ """
19
+ Hàm dự đoán biển báo giao thông từ ảnh đầu vào
20
+ """
21
+ try:
22
+ # Xử lý ảnh đầu vào
23
+ img = Image.fromarray(image.astype('uint8')).convert('RGB')
24
+ img = img.resize((IMG_HEIGHT, IMG_WIDTH))
25
+ img_array = np.array(img).astype('float32') / 255.0
26
+ img_array = np.expand_dims(img_array, axis=0)
27
+
28
+ # Dự đoán
29
+ predictions = model.predict(img_array, verbose=0)
30
+ predicted_class_idx = np.argmax(predictions[0])
31
+ confidence = predictions[0][predicted_class_idx]
32
+
33
+ # Lấy tên lớp
34
+ predicted_class_name = le.inverse_transform([predicted_class_idx])[0]
35
+
36
+ # Tạo dictionary kết quả cho tất cả các lớp
37
+ results = {}
38
+ for idx, class_name in enumerate(le.classes_):
39
+ results[class_name] = float(predictions[0][idx])
40
+
41
+ return results
42
+
43
+ except Exception as e:
44
+ return {f"Error: {str(e)}": 0.0}
45
+
46
+ # Tạo Gradio Interface
47
+ demo = gr.Interface(
48
+ fn=predict_traffic_sign,
49
+ inputs=gr.Image(label="Tải ảnh biển báo giao thông"),
50
+ outputs=gr.Label(num_top_classes=5, label="Kết quả dự đoán"),
51
+ title="🚦 Nhận diện Biển báo Giao thông",
52
+ description="""
53
+ **Upload một ảnh biển báo giao thông để nhận diện.**
54
+
55
+ Model CNN được huấn luyện để phân loại các loại biển báo giao thông Việt Nam.
56
+
57
+ 📊 Kết quả hiển thị top 5 dự đoán có xác suất cao nhất.
58
+ """,
59
+ examples=[
60
+ # Thêm đường dẫn đến ảnh mẫu trong thư mục examples/
61
+ # ["examples/stop_sign.jpg"],
62
+ # ["examples/speed_limit.jpg"],
63
+ ],
64
+ theme=gr.themes.Soft(),
65
+ allow_flagging="never",
66
+ analytics_enabled=False
67
+ )