PandaLT commited on
Commit
ccf20e1
·
verified ·
1 Parent(s): f5fc0c8

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +71 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import pickle
7
+
8
+ # Thông số
9
+ IMG_HEIGHT = 64
10
+ IMG_WIDTH = 64
11
+
12
+ # Load model và label encoder
13
+ model = tf.keras.models.load_model('traffic_sign_model.h5')
14
+ with open('label_encoder.pkl', 'rb') as f:
15
+ le = pickle.load(f)
16
+
17
+ def predict_traffic_sign(image):
18
+ """
19
+ Hàm dự đoán biển báo giao thông từ ảnh đầu vào
20
+ """
21
+ try:
22
+ # Xử lý ảnh đầu vào
23
+ img = Image.fromarray(image.astype('uint8')).convert('RGB')
24
+ img = img.resize((IMG_HEIGHT, IMG_WIDTH))
25
+ img_array = np.array(img).astype('float32') / 255.0
26
+ img_array = np.expand_dims(img_array, axis=0)
27
+
28
+ # Dự đoán
29
+ predictions = model.predict(img_array, verbose=0)
30
+ predicted_class_idx = np.argmax(predictions[0])
31
+ confidence = predictions[0][predicted_class_idx]
32
+
33
+ # Lấy tên lớp
34
+ predicted_class_name = le.inverse_transform([predicted_class_idx])[0]
35
+
36
+ # Tạo dictionary kết quả cho tất cả các lớp
37
+ results = {}
38
+ for idx, class_name in enumerate(le.classes_):
39
+ results[class_name] = float(predictions[0][idx])
40
+
41
+ return results
42
+
43
+ except Exception as e:
44
+ return {f"Error: {str(e)}": 0.0}
45
+
46
+ # Tạo Gradio Interface
47
+ demo = gr.Interface(
48
+ fn=predict_traffic_sign,
49
+ inputs=gr.Image(label="Tải ảnh biển báo giao thông"),
50
+ outputs=gr.Label(num_top_classes=5, label="Kết quả dự đoán"),
51
+ title="🚦 Nhận diện Biển báo Giao thông",
52
+ description="""
53
+ **Upload một ảnh biển báo giao thông để nhận diện.**
54
+
55
+ Model CNN được huấn luyện để phân loại các loại biển báo giao thông Việt Nam.
56
+
57
+ 📊 Kết quả hiển thị top 5 dự đoán có xác suất cao nhất.
58
+ """,
59
+ examples=[
60
+ # Thêm đường dẫn đến ảnh mẫu trong thư mục examples/
61
+ # ["examples/stop_sign.jpg"],
62
+ # ["examples/speed_limit.jpg"],
63
+ ],
64
+ theme=gr.themes.Soft(),
65
+ allow_flagging="never",
66
+ analytics_enabled=False
67
+ )
68
+
69
+ # Launch
70
+ if __name__ == "__main__":
71
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ tensorflow==2.15.0
2
+ numpy==1.24.3
3
+ scikit-learn==1.3.2
4
+ Pillow==10.1.0
5
+ gradio==4.10.0
6
+ matplotlib==3.8.2