PandaLT commited on
Commit
9be12cd
·
verified ·
1 Parent(s): 1f55b66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -7
app.py CHANGED
@@ -1,19 +1,45 @@
1
  import os
2
  import numpy as np
3
  import tensorflow as tf
 
4
  from PIL import Image
5
  import gradio as gr
6
  import pickle
 
7
 
8
- # Thông số
9
- IMG_HEIGHT = 64
10
- IMG_WIDTH = 64
11
 
12
- # Load model và label encoder
13
- model = tf.keras.models.load_model('traffic_sign_model.keras')
 
 
 
14
  with open('label_encoder.pkl', 'rb') as f:
15
  le = pickle.load(f)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def predict_traffic_sign(image):
18
  """
19
  Hàm dự đoán biển báo giao thông từ ảnh đầu vào
@@ -58,10 +84,12 @@ demo = gr.Interface(
58
  """,
59
  examples=[
60
  # Thêm đường dẫn đến ảnh mẫu trong thư mục examples/
61
- # ["examples/stop_sign.jpg"],
62
- # ["examples/speed_limit.jpg"],
63
  ],
64
  theme=gr.themes.Soft(),
65
  allow_flagging="never",
66
  analytics_enabled=False
67
  )
 
 
 
 
 
1
  import os
2
  import numpy as np
3
  import tensorflow as tf
4
+ from tensorflow.keras import layers, models
5
  from PIL import Image
6
  import gradio as gr
7
  import pickle
8
+ import json
9
 
10
+ # Load config
11
+ with open('model_config.json', 'r') as f:
12
+ config = json.load(f)
13
 
14
+ IMG_HEIGHT = config['img_height']
15
+ IMG_WIDTH = config['img_width']
16
+ num_classes = config['num_classes']
17
+
18
+ # Load label encoder
19
  with open('label_encoder.pkl', 'rb') as f:
20
  le = pickle.load(f)
21
 
22
+ # Rebuild model architecture
23
+ model = models.Sequential([
24
+ layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
25
+ layers.MaxPooling2D((2, 2)),
26
+ layers.Flatten(),
27
+ layers.Dropout(0.5),
28
+ layers.Dense(num_classes, activation='softmax')
29
+ ])
30
+
31
+ # Load weights
32
+ model.load_weights('model_weights.h5')
33
+
34
+ # Compile (cần thiết cho predict)
35
+ model.compile(
36
+ optimizer='adam',
37
+ loss='sparse_categorical_crossentropy',
38
+ metrics=['accuracy']
39
+ )
40
+
41
+ print("✅ Model loaded successfully!")
42
+
43
  def predict_traffic_sign(image):
44
  """
45
  Hàm dự đoán biển báo giao thông từ ảnh đầu vào
 
84
  """,
85
  examples=[
86
  # Thêm đường dẫn đến ảnh mẫu trong thư mục examples/
 
 
87
  ],
88
  theme=gr.themes.Soft(),
89
  allow_flagging="never",
90
  analytics_enabled=False
91
  )
92
+
93
+ # Launch
94
+ if __name__ == "__main__":
95
+ demo.launch()