Spaces:
Sleeping
Sleeping
import io
import logging
import os
import tempfile
import threading

import cv2
import numpy as np
import streamlit as st
import yaml
from PIL import Image
from ultralytics import YOLO

from province import ThaiLicensePlateDetector
# Directory that contains this script. Model weights (models/) and YAML
# configs (config/) are resolved relative to it, so the app does not depend
# on the current working directory.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
class LicensePlateProcessor:
    """Detect a Thai license plate in an image and read its number and province.

    Uses three models loaded from ``BASE_DIR/models``: a plate detector
    (``detect1.pt``), a character reader (``read_char.pt``) and a province
    detector (``best_province.pt`` wrapped by ``ThaiLicensePlateDetector``).
    """

    # Minimum confidence for plate boxes and for character detections.
    CONF_THRESHOLD = 0.3

    # Box colours drawn on the output copy: class 0 (plate number) vs. any
    # other class. Tuples are in the image's channel order (BGR, since the
    # caller converts BGR -> RGB for display).
    _PLATE_COLOR = (0, 255, 0)
    _OTHER_COLOR = (255, 0, 0)

    _log = logging.getLogger(__name__)

    def __init__(self):
        """Load the detection models and the character mapping from disk.

        Raises:
            OSError: if ``config/data.yaml`` cannot be opened.
            KeyError: if ``config/data.yaml`` has no ``names`` entry.
        """
        models_dir = os.path.join(BASE_DIR, 'models')
        config_dir = os.path.join(BASE_DIR, 'config')

        self.yolo_detector = YOLO(os.path.join(models_dir, 'detect1.pt'))  # plate detection
        self.char_reader = YOLO(os.path.join(models_dir, 'read_char.pt'))  # character reading

        self.province_detector = ThaiLicensePlateDetector(
            os.path.join(models_dir, 'best_province.pt'),
            os.path.join(config_dir, 'data_province.yaml'),
        )

        with open(os.path.join(config_dir, 'data.yaml'), 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file; treat it as an empty
            # mapping so the missing 'names' key surfaces as a clear KeyError.
            data_config = yaml.safe_load(f) or {}
        self.char_mapping = data_config.get('char_mapping', {})
        self.names = data_config['names']

    def _map_class_to_char(self, class_name):
        """Return the character for a model class name.

        ``str(class_name)`` is looked up in ``char_mapping``; names without an
        entry are returned unchanged as a string.
        """
        key = str(class_name)
        return self.char_mapping.get(key, key)

    def read_plate_characters(self, plate_image):
        """Read the characters on a cropped plate image.

        Runs the character model at ``CONF_THRESHOLD``, maps each detected class
        to its character, and concatenates them ordered by the left edge of
        their boxes.

        Returns:
            The plate text, or ``""`` when no characters are detected.
        """
        results = self.char_reader.predict(plate_image, conf=self.CONF_THRESHOLD)
        chars = []
        for result in results:
            for box in result.boxes:
                left = int(box.xyxy[0][0])
                class_id = int(box.cls[0])
                chars.append((left, self._map_class_to_char(self.names[class_id])))
        # Stable sort: characters sharing a left edge keep detection order.
        chars.sort(key=lambda item: item[0])
        return ''.join(char for _, char in chars)

    def _read_province(self, plate_crop):
        """Return the province text for a plate crop, or ``""``.

        Best-effort: any error from the province detector is logged and
        ignored so the plate number can still be reported.
        """
        try:
            province_results = self.province_detector.detect_and_process(plate_crop)
            plate_text = province_results['plate_text'] if province_results else ""
        except Exception:
            self._log.warning("Province detection failed", exc_info=True)
            return ""
        return plate_text or ""

    def process_image(self, image):
        """Detect plates in ``image``, read the best one, and annotate a copy.

        Every detection with confidence >= ``CONF_THRESHOLD`` is outlined on
        a copy of ``image``. Only the highest-confidence class-0 (plate number)
        box is read, so a weaker duplicate detection cannot overwrite a
        stronger one. The plate number and province both come from that plate.

        Returns:
            ``(data, output_image)``, where ``data`` has the keys
            ``plate_number``, ``province`` and ``raw_province`` (empty strings
            when nothing is found). Returns ``(None, None)`` if processing
            raised; the exception is logged.
        """
        try:
            data = {"plate_number": "", "province": "", "raw_province": ""}
            output_image = image.copy()

            best_conf = -1.0
            best_crop = None
            for result in self.yolo_detector(image):
                for box in result.boxes:
                    confidence = float(box.conf)
                    if confidence < self.CONF_THRESHOLD:
                        continue
                    x1, y1, x2, y2 = map(int, box.xyxy.flatten())
                    class_id = int(box.cls.item())
                    color = self._PLATE_COLOR if class_id == 0 else self._OTHER_COLOR
                    cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                    # Crop from the un-annotated image; skip degenerate boxes.
                    crop = image[y1:y2, x1:x2]
                    if class_id == 0 and crop.size and confidence > best_conf:
                        best_conf = confidence
                        best_crop = crop

            if best_crop is not None:
                data["plate_number"] = self.read_plate_characters(best_crop)
                province = self._read_province(best_crop)
                if province:
                    data["raw_province"] = province
                    data["province"] = province
            return data, output_image
        except Exception:
            # UI boundary: report failure to the caller instead of crashing.
            self._log.exception("License plate processing failed")
            return None, None
@st.cache_resource
def _load_processor():
    """Create the processor and a lock guarding it, once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching all three YOLO models would be reloaded each time. The
    cached objects are shared by all sessions. The lock serialises inference
    because the Ultralytics docs advise against sharing one model instance
    across threads.
    """
    return LicensePlateProcessor(), threading.Lock()


def main():
    """Render the upload form and show plate/province results for one image."""
    st.title("Thai License Plate Detection")
    st.write("Upload an image to detect license plate and province")

    processor, processor_lock = _load_processor()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the full buffer regardless of the current read position.
    file_bytes = np.asarray(bytearray(uploaded_file.getvalue()), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals undecodable data by returning None, not by raising.
        st.error("Could not decode the uploaded file as an image")
        return

    with st.spinner('Processing image...'):
        with processor_lock:
            results, output_image = processor.process_image(image)

    if results is None or output_image is None:
        st.error("Failed to process image")
        return

    st.subheader("Detection Results:")
    col1, col2 = st.columns(2)
    with col1:
        st.write("License Plate Number:", results['plate_number'])
        st.write("Province:", results['province'])
    with col2:
        # The annotated image is BGR (OpenCV); Streamlit displays RGB.
        output_image_rgb = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
        st.image(output_image_rgb, caption='Detection Result', use_column_width=True)
# Launch the app only when this file is executed as the main script
# (e.g. `streamlit run <file>`); importing it just defines the names above.
if __name__ == "__main__":
    main()