# File size: 5,567 Bytes
# 6393710
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import io
import logging
import os
import tempfile

import cv2
import numpy as np
import streamlit as st
import yaml
from PIL import Image
from ultralytics import YOLO

from province import ThaiLicensePlateDetector

# Absolute path of the directory that contains this file
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]

class LicensePlateProcessor:
    """Locate licence plates in an image and read their number and province.

    Three models are combined: a YOLO model that finds plate regions, a YOLO
    model that reads the characters inside a plate crop, and a
    ``ThaiLicensePlateDetector`` that reads the province.  Weights are loaded
    from ``<BASE_DIR>/models`` and configuration from ``<BASE_DIR>/config``.
    """

    # Used to record failures that are deliberately not propagated to callers.
    _log = logging.getLogger(__name__)

    def __init__(self):
        """Load all models and the character mapping from disk.

        Raises:
            KeyError: if ``config/data.yaml`` has no ``names`` entry.
            Exception: whatever the model loaders, ``open`` or
                ``yaml.safe_load`` raise for missing or invalid files.
        """
        models_dir = os.path.join(BASE_DIR, 'models')
        config_dir = os.path.join(BASE_DIR, 'config')

        # Minimum confidence for both plate detection and character reading.
        self.CONF_THRESHOLD = 0.3

        self.yolo_detector = YOLO(os.path.join(models_dir, 'detect1.pt'))  # plate detection
        self.char_reader = YOLO(os.path.join(models_dir, 'read_char.pt'))  # character reading

        self.province_detector = ThaiLicensePlateDetector(
            os.path.join(models_dir, 'best_province.pt'),
            os.path.join(config_dir, 'data_province.yaml'),
        )

        with open(os.path.join(config_dir, 'data.yaml'), 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)
        # 'char_mapping' is optional; 'names' is required and is indexed by class id.
        self.char_mapping = data_config.get('char_mapping', {})
        self.names = data_config['names']

    def _map_class_to_char(self, class_name):
        """Return the character configured for ``class_name``.

        The lookup key is ``str(class_name)``; when the mapping has no entry
        for it, that string itself is returned.
        """
        key = str(class_name)
        return self.char_mapping.get(key, key)

    def read_plate_characters(self, plate_image):
        """Read the characters on a cropped plate image.

        Args:
            plate_image: image of a single plate, passed to the character model.

        Returns:
            The detected characters joined into one string, ordered by the
            left edge of their bounding boxes; ``''`` when nothing is detected.
        """
        results = self.char_reader.predict(plate_image, conf=self.CONF_THRESHOLD)

        detections = []
        for result in results:
            for box in result.boxes:
                left_edge = int(box.xyxy[0][0])
                class_id = int(box.cls[0])
                detections.append(
                    (left_edge, self._map_class_to_char(self.names[class_id]))
                )

        # Sort on the x coordinate only; the stable sort keeps detection order for ties.
        detections.sort(key=lambda item: item[0])
        return ''.join(char for _, char in detections)

    def process_image(self, image):
        """Detect plates in ``image`` and extract plate number and province.

        Every detection at or above ``CONF_THRESHOLD`` is outlined on a copy of
        the image.  Detections of class 0 are additionally cropped and read;
        when several are present, the values of the last one read are kept.

        Args:
            image: image array supporting ``.copy()``, ``.shape`` and 2-D
                slicing (assumed to be a BGR array as produced by OpenCV).

        Returns:
            ``(data, output_image)`` where ``data`` has the keys
            ``plate_number``, ``province`` and ``raw_province`` (empty strings
            when nothing was read) and ``output_image`` is the annotated copy.
            ``(None, None)`` if processing failed; the error is logged.
        """
        try:
            data = {"plate_number": "", "province": "", "raw_province": ""}
            output_image = image.copy()
            height, width = image.shape[:2]

            for result in self.yolo_detector(image):
                for box in result.boxes:
                    if float(box.conf) < self.CONF_THRESHOLD:
                        continue

                    x1, y1, x2, y2 = map(int, box.xyxy.flatten())
                    is_plate = int(box.cls.item()) == 0

                    color = (0, 255, 0) if is_plate else (255, 0, 0)
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                    if not is_plate:
                        continue

                    # Clamp to the image so a negative coordinate cannot wrap
                    # around when slicing, and skip boxes with no area instead
                    # of feeding an empty crop to the models.
                    left, top = max(x1, 0), max(y1, 0)
                    right, bottom = min(x2, width), min(y2, height)
                    if right <= left or bottom <= top:
                        continue
                    cropped_image = image[top:bottom, left:right]

                    data["plate_number"] = self.read_plate_characters(cropped_image)

                    # Province reading is best-effort: a failure leaves the
                    # province fields as they are but is logged, not hidden.
                    try:
                        province_results = self.province_detector.detect_and_process(cropped_image)
                        if province_results and province_results['plate_text']:
                            data["raw_province"] = province_results['plate_text']
                            data["province"] = province_results['plate_text']
                    except Exception:
                        self._log.exception("Province detection failed")

            return data, output_image

        except Exception:
            # Callers rely on the (None, None) contract, so log instead of raising.
            self._log.exception("Failed to process image")
            return None, None

def main():
    """Render the Streamlit page: upload an image, run detection, show results.

    The page shows the plate number and province next to the annotated image,
    or an error message when the upload cannot be decoded or processed.
    """
    st.title("Thai License Plate Detection")
    st.write("Upload an image to detect license plate and province")

    # Streamlit re-runs this script on every interaction; keeping the
    # processor in session state avoids reloading the models on each rerun.
    if "_license_plate_processor" not in st.session_state:
        st.session_state["_license_plate_processor"] = LicensePlateProcessor()
    processor = st.session_state["_license_plate_processor"]

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the whole upload regardless of the read position.
    file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)

    # An empty buffer is not handed to OpenCV at all, and data that cannot be
    # decoded yields None; report both cases here instead of failing later.
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) if file_bytes.size else None
    if image is None:
        st.error("Could not read the uploaded file as an image")
        return

    with st.spinner('Processing image...'):
        results, output_image = processor.process_image(image)

    if not results or output_image is None:
        st.error("Failed to process image")
        return

    st.subheader("Detection Results:")
    col1, col2 = st.columns(2)

    with col1:
        st.write("License Plate Number:", results['plate_number'])
        st.write("Province:", results['province'])

    with col2:
        # OpenCV images are BGR; st.image expects RGB.
        output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
        st.image(output_image_rgb, caption='Detection Result', use_column_width=True)

# Entry point: build the page only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()