PandaLT commited on
Commit
e890977
·
verified ·
1 Parent(s): 19348c5

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -95
app.py DELETED
@@ -1,95 +0,0 @@
1
- import os
2
- import numpy as np
3
- import tensorflow as tf
4
- from tensorflow.keras import layers, models
5
- from PIL import Image
6
- import gradio as gr
7
- import pickle
8
- import json
9
-
10
- # Load config
11
- with open('model_config.json', 'r') as f:
12
- config = json.load(f)
13
-
14
- IMG_HEIGHT = config['img_height']
15
- IMG_WIDTH = config['img_width']
16
- num_classes = config['num_classes']
17
-
18
- # Load label encoder
19
- with open('label_encoder.pkl', 'rb') as f:
20
- le = pickle.load(f)
21
-
22
- # Rebuild model architecture
23
- model = models.Sequential([
24
- layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
25
- layers.MaxPooling2D((2, 2)),
26
- layers.Flatten(),
27
- layers.Dropout(0.5),
28
- layers.Dense(num_classes, activation='softmax')
29
- ])
30
-
31
- # Load weights
32
- model.load_weights('model_weights.h5')
33
-
34
- # Compile (cần thiết cho predict)
35
- model.compile(
36
- optimizer='adam',
37
- loss='sparse_categorical_crossentropy',
38
- metrics=['accuracy']
39
- )
40
-
41
- print("✅ Model loaded successfully!")
42
-
43
- def predict_traffic_sign(image):
44
- """
45
- Hàm dự đoán biển báo giao thông từ ảnh đầu vào
46
- """
47
- try:
48
- # Xử lý ảnh đầu vào
49
- img = Image.fromarray(image.astype('uint8')).convert('RGB')
50
- img = img.resize((IMG_HEIGHT, IMG_WIDTH))
51
- img_array = np.array(img).astype('float32') / 255.0
52
- img_array = np.expand_dims(img_array, axis=0)
53
-
54
- # Dự đoán
55
- predictions = model.predict(img_array, verbose=0)
56
- predicted_class_idx = np.argmax(predictions[0])
57
- confidence = predictions[0][predicted_class_idx]
58
-
59
- # Lấy tên lớp
60
- predicted_class_name = le.inverse_transform([predicted_class_idx])[0]
61
-
62
- # Tạo dictionary kết quả cho tất cả các lớp
63
- results = {}
64
- for idx, class_name in enumerate(le.classes_):
65
- results[class_name] = float(predictions[0][idx])
66
-
67
- return results
68
-
69
- except Exception as e:
70
- return {f"Error: {str(e)}": 0.0}
71
-
72
- # Tạo Gradio Interface
73
- demo = gr.Interface(
74
- fn=predict_traffic_sign,
75
- inputs=gr.Image(label="Tải ảnh biển báo giao thông"),
76
- outputs=gr.Label(num_top_classes=5, label="Kết quả dự đoán"),
77
- title="🚦 Nhận diện Biển báo Giao thông",
78
- description="""
79
- **Upload một ảnh biển báo giao thông để nhận diện.**
80
-
81
- Model CNN được huấn luyện để phân loại các loại biển báo giao thông Việt Nam.
82
-
83
- 📊 Kết quả hiển thị top 5 dự đoán có xác suất cao nhất.
84
- """,
85
- examples=[
86
- # Thêm đường dẫn đến ảnh mẫu trong thư mục examples/
87
- ],
88
- theme=gr.themes.Soft(),
89
- allow_flagging="never",
90
- analytics_enabled=False
91
- )
92
-
93
- # Launch
94
- if __name__ == "__main__":
95
- demo.launch()