"""Streamlit app that detects Thai license plates and reads number and province."""

import io
import logging
import os
import tempfile

import cv2
import numpy as np
import streamlit as st
import yaml
from PIL import Image
from ultralytics import YOLO

from province import ThaiLicensePlateDetector

logger = logging.getLogger(__name__)

# Absolute path of the directory containing this file; model and config
# paths below are resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


class LicensePlateProcessor:
    """Locate license plates in an image and read their number and province.

    Three models are loaded at construction time: a YOLO plate detector, a
    YOLO character reader, and a ``ThaiLicensePlateDetector`` for the
    province. Character class names come from ``config/data.yaml`` and may
    be remapped through its optional ``char_mapping`` section.
    """

    # Minimum confidence passed to the character-reading model.
    CHAR_CONF_THRESHOLD = 0.3

    def __init__(self):
        """Load the models and the character-mapping configuration."""
        models_dir = os.path.join(BASE_DIR, 'models')
        config_dir = os.path.join(BASE_DIR, 'config')

        # Plate detection and character reading models.
        self.yolo_detector = YOLO(os.path.join(models_dir, 'detect1.pt'))
        self.char_reader = YOLO(os.path.join(models_dir, 'read_char.pt'))

        # Province detector.
        self.province_detector = ThaiLicensePlateDetector(
            os.path.join(models_dir, 'best_province.pt'),
            os.path.join(config_dir, 'data_province.yaml'),
        )

        # Character mapping and class names from yaml.
        data_yaml = os.path.join(config_dir, 'data.yaml')
        with open(data_yaml, 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)
        # ``or {}`` covers a ``char_mapping:`` key that is present but empty.
        self.char_mapping = data_config.get('char_mapping') or {}
        self.names = data_config['names']

        # Minimum confidence for a plate-detector box to be processed.
        self.CONF_THRESHOLD = 0.3

    def _map_class_to_char(self, class_name):
        """Return the character mapped to ``class_name`` in the yaml mapping.

        Falls back to ``str(class_name)`` when the mapping has no entry.
        """
        key = str(class_name)
        return self.char_mapping.get(key, key)

    def read_plate_characters(self, plate_image):
        """Read the characters on a cropped plate image.

        Args:
            plate_image: Image of the plate region, passed unchanged to the
                character-reading model.

        Returns:
            The detected characters joined into one string, ordered by the
            left edge of their bounding boxes (left to right). Empty string
            when nothing is detected.
        """
        results = self.char_reader.predict(
            plate_image, conf=self.CHAR_CONF_THRESHOLD
        )

        detections = []  # (left edge x, character) pairs
        for r in results:
            for box in r.boxes:
                left_x = int(box.xyxy[0][0])
                class_id = int(box.cls[0])
                char = self._map_class_to_char(self.names[class_id])
                detections.append((left_x, char))

        # Sort on x only so that ties keep detection order (stable sort).
        detections.sort(key=lambda det: det[0])

        # str() guards against yaml mapping values parsed as non-strings
        # (e.g. unquoted digits), which would make join() raise.
        return ''.join(str(char) for _, char in detections)

    def _detect_province(self, plate_crop):
        """Return the province text for ``plate_crop``, or '' if unavailable.

        Province detection is best-effort: any failure is logged and
        reported as "no province" so the plate number is still returned.
        """
        try:
            province_results = self.province_detector.detect_and_process(
                plate_crop
            )
            if province_results and province_results['plate_text']:
                return province_results['plate_text']
        except Exception:
            logger.warning("Province detection failed", exc_info=True)
        return ""

    def process_image(self, image):
        """Detect plates in ``image`` and read their number and province.

        Args:
            image: Image array as returned by ``cv2.imdecode`` (indexed as
                ``image[y, x]``), or ``None``.

        Returns:
            ``(data, output_image)`` where ``data`` is a dict with the keys
            ``"plate_number"``, ``"province"`` and ``"raw_province"`` and
            ``output_image`` is a copy of ``image`` with a rectangle drawn
            around each accepted detection. ``(None, None)`` when ``image``
            is ``None`` or processing fails; the failure is logged.
        """
        if image is None:
            return None, None

        try:
            data = {"plate_number": "", "province": "", "raw_province": ""}

            # Draw on a copy so the caller's image stays untouched.
            output_image = image.copy()

            for result in self.yolo_detector(image):
                for box in result.boxes:
                    if float(box.conf) < self.CONF_THRESHOLD:
                        continue

                    x1, y1, x2, y2 = map(int, box.xyxy.flatten())
                    is_plate = int(box.cls.item()) == 0  # class 0: plate number

                    # Colours are BGR: green for class 0, blue otherwise.
                    color = (0, 255, 0) if is_plate else (255, 0, 0)
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                    if not is_plate:
                        continue

                    cropped_image = image[y1:y2, x1:x2]
                    if cropped_image.size == 0:
                        # A degenerate box would make the models raise and
                        # abort the whole image; skip just this detection.
                        continue

                    data["plate_number"] = self.read_plate_characters(
                        cropped_image
                    )

                    # NOTE(review): province detection is assumed to run on
                    # the plate crop (class 0) only - confirm against the
                    # intended pipeline.
                    province = self._detect_province(cropped_image)
                    if province:
                        data["raw_province"] = province
                        data["province"] = province

            return data, output_image
        except Exception:
            # Top-level boundary for the UI: report failure via the return
            # value, but keep the traceback in the log.
            logger.exception("Failed to process image")
            return None, None


def main():
    """Run the Streamlit page: upload an image and show detection results."""
    st.title("Thai License Plate Detection")
    st.write("Upload an image to detect license plate and province")

    # NOTE(review): the processor (and its three models) is constructed on
    # every call to main(); if loading is slow, consider caching it (e.g.
    # st.cache_resource) after confirming the models can be shared safely.
    processor = LicensePlateProcessor()

    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"]
    )
    if uploaded_file is None:
        return

    # Decode the uploaded bytes into a BGR image. An empty buffer is not
    # passed to cv2.imdecode, and a failed decode yields None.
    file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
    image = None
    if file_bytes.size:
        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        st.error("Failed to process image")
        return

    with st.spinner('Processing image...'):
        results, output_image = processor.process_image(image)

    if results is None or output_image is None:
        st.error("Failed to process image")
        return

    st.subheader("Detection Results:")
    col1, col2 = st.columns(2)

    with col1:
        st.write("License Plate Number:", results['plate_number'])
        st.write("Province:", results['province'])

    with col2:
        # Convert BGR to RGB for display.
        output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
        st.image(
            output_image_rgb,
            caption='Detection Result',
            use_column_width=True,
        )


if __name__ == "__main__":
    main()