PandaLT's picture
Update app.py
9be12cd verified
raw
history blame
2.82 kB
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from PIL import Image
import gradio as gr
import pickle
import json
# Load config
with open('model_config.json', 'r') as f:
config = json.load(f)
IMG_HEIGHT = config['img_height']
IMG_WIDTH = config['img_width']
num_classes = config['num_classes']
# Load label encoder
with open('label_encoder.pkl', 'rb') as f:
le = pickle.load(f)
# Rebuild model architecture
model = models.Sequential([
layers.Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
layers.MaxPooling2D((2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(num_classes, activation='softmax')
])
# Load weights
model.load_weights('model_weights.h5')
# Compile (cần thiết cho predict)
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
print("✅ Model loaded successfully!")
def predict_traffic_sign(image):
"""
Hàm dự đoán biển báo giao thông từ ảnh đầu vào
"""
try:
# Xử lý ảnh đầu vào
img = Image.fromarray(image.astype('uint8')).convert('RGB')
img = img.resize((IMG_HEIGHT, IMG_WIDTH))
img_array = np.array(img).astype('float32') / 255.0
img_array = np.expand_dims(img_array, axis=0)
# Dự đoán
predictions = model.predict(img_array, verbose=0)
predicted_class_idx = np.argmax(predictions[0])
confidence = predictions[0][predicted_class_idx]
# Lấy tên lớp
predicted_class_name = le.inverse_transform([predicted_class_idx])[0]
# Tạo dictionary kết quả cho tất cả các lớp
results = {}
for idx, class_name in enumerate(le.classes_):
results[class_name] = float(predictions[0][idx])
return results
except Exception as e:
return {f"Error: {str(e)}": 0.0}
# Tạo Gradio Interface
demo = gr.Interface(
fn=predict_traffic_sign,
inputs=gr.Image(label="Tải ảnh biển báo giao thông"),
outputs=gr.Label(num_top_classes=5, label="Kết quả dự đoán"),
title="🚦 Nhận diện Biển báo Giao thông",
description="""
**Upload một ảnh biển báo giao thông để nhận diện.**
Model CNN được huấn luyện để phân loại các loại biển báo giao thông Việt Nam.
📊 Kết quả hiển thị top 5 dự đoán có xác suất cao nhất.
""",
examples=[
# Thêm đường dẫn đến ảnh mẫu trong thư mục examples/
],
theme=gr.themes.Soft(),
allow_flagging="never",
analytics_enabled=False
)
# Launch
if __name__ == "__main__":
demo.launch()