Abs6187 commited on
Commit
7565b34
·
verified ·
1 Parent(s): 939b7fb

Create license_plate_ocr.py

Browse files
Files changed (1) hide show
  1. license_plate_ocr.py +280 -0
license_plate_ocr.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ from PIL import Image, ImageEnhance, ImageFilter
4
+ import cv2
5
+ import numpy as np
6
+ import re
7
+ import easyocr
8
+ import os
9
+ from typing import List, Dict, Optional, Union
10
+
11
+ class LicensePlateOCR:
12
+ def __init__(self):
13
+ self.trocr_processor = None
14
+ self.trocr_model = None
15
+ self.easyocr_reader = None
16
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+
18
+ def load_trocr_model(self):
19
+ try:
20
+ print("Loading TrOCR model...")
21
+ self.trocr_processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-printed')
22
+ self.trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-printed')
23
+ self.trocr_model.to(self.device)
24
+ print(f"TrOCR model loaded on {self.device}")
25
+ return True
26
+ except Exception as e:
27
+ print(f"Error loading TrOCR model: {e}")
28
+ return False
29
+
30
+ def load_easyocr_model(self):
31
+ try:
32
+ print("Loading EasyOCR model...")
33
+ self.easyocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
34
+ print("EasyOCR model loaded")
35
+ return True
36
+ except Exception as e:
37
+ print(f"Error loading EasyOCR model: {e}")
38
+ return False
39
+
40
+ def preprocess_license_plate(self, image: Image.Image) -> List[Image.Image]:
41
+ processed_images = []
42
+
43
+ try:
44
+ original = image.copy()
45
+ processed_images.append(original)
46
+
47
+ if image.mode != 'RGB':
48
+ image = image.convert('RGB')
49
+
50
+ enhancer = ImageEnhance.Contrast(image)
51
+ high_contrast = enhancer.enhance(2.0)
52
+ processed_images.append(high_contrast)
53
+
54
+ enhancer = ImageEnhance.Sharpness(high_contrast)
55
+ sharpened = enhancer.enhance(2.0)
56
+ processed_images.append(sharpened)
57
+
58
+ img_array = np.array(image)
59
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
60
+
61
+ clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
62
+ clahe_img = clahe.apply(gray)
63
+ clahe_pil = Image.fromarray(clahe_img).convert('RGB')
64
+ processed_images.append(clahe_pil)
65
+
66
+ _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
67
+ binary_pil = Image.fromarray(binary).convert('RGB')
68
+ processed_images.append(binary_pil)
69
+
70
+ denoised = cv2.bilateralFilter(gray, 9, 75, 75)
71
+ denoised_pil = Image.fromarray(denoised).convert('RGB')
72
+ processed_images.append(denoised_pil)
73
+
74
+ except Exception as e:
75
+ print(f"Error in preprocessing: {e}")
76
+ processed_images = [image]
77
+
78
+ return processed_images
79
+
80
+ def extract_text_trocr(self, image: Image.Image) -> str:
81
+ if self.trocr_processor is None or self.trocr_model is None:
82
+ if not self.load_trocr_model():
83
+ return ""
84
+
85
+ try:
86
+ pixel_values = self.trocr_processor(image, return_tensors="pt").pixel_values
87
+ pixel_values = pixel_values.to(self.device)
88
+
89
+ with torch.no_grad():
90
+ generated_ids = self.trocr_model.generate(pixel_values, max_length=50)
91
+
92
+ generated_text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
93
+ return generated_text.strip()
94
+
95
+ except Exception as e:
96
+ print(f"Error in TrOCR extraction: {e}")
97
+ return ""
98
+
99
+ def extract_text_easyocr(self, image: Image.Image) -> str:
100
+ if self.easyocr_reader is None:
101
+ if not self.load_easyocr_model():
102
+ return ""
103
+
104
+ try:
105
+ img_array = np.array(image)
106
+ results = self.easyocr_reader.readtext(img_array, detail=0, paragraph=False)
107
+
108
+ if results:
109
+ text = ' '.join(results)
110
+ return text.strip()
111
+ return ""
112
+
113
+ except Exception as e:
114
+ print(f"Error in EasyOCR extraction: {e}")
115
+ return ""
116
+
117
+ def clean_license_plate_text(self, text: str) -> str:
118
+ if not text:
119
+ return ""
120
+
121
+ text = text.upper().strip()
122
+ text = re.sub(r'[^A-Z0-9\s-]', '', text)
123
+ text = re.sub(r'\s+', ' ', text).strip()
124
+
125
+ common_mistakes = {
126
+ 'O': '0', 'I': '1', 'S': '5', 'B': '8',
127
+ 'G': '6', 'Z': '2', 'T': '7'
128
+ }
129
+
130
+ for mistake, correction in common_mistakes.items():
131
+ if len([c for c in text if c.isdigit()]) > len([c for c in text if c.isalpha()]):
132
+ text = text.replace(mistake, correction)
133
+
134
+ return text
135
+
136
+ def validate_license_plate_format(self, text: str) -> bool:
137
+ if not text or len(text) < 4:
138
+ return False
139
+
140
+ common_patterns = [
141
+ r'^[A-Z]{2}\d{2}[A-Z]{2}\d{4}$', # XX00XX0000
142
+ r'^[A-Z]{3}\d{4}$', # XXX0000
143
+ r'^[A-Z]{2}\d{4}$', # XX0000
144
+ r'^\d{3}[A-Z]{3}$', # 000XXX
145
+ r'^[A-Z]\d{3}[A-Z]{3}$', # X000XXX
146
+ r'^[A-Z]{2}\d{2}[A-Z]\d{3}$', # XX00X000
147
+ ]
148
+
149
+ text_clean = text.replace(' ', '').replace('-', '')
150
+
151
+ for pattern in common_patterns:
152
+ if re.match(pattern, text_clean):
153
+ return True
154
+
155
+ if 4 <= len(text_clean) <= 10:
156
+ alpha_count = sum(c.isalpha() for c in text_clean)
157
+ digit_count = sum(c.isdigit() for c in text_clean)
158
+ if alpha_count > 0 and digit_count > 0:
159
+ return True
160
+
161
+ return False
162
+
163
+ def extract_license_plate_text(self, image: Union[Image.Image, str, np.ndarray],
164
+ use_preprocessing: bool = True) -> Dict[str, any]:
165
+
166
+ try:
167
+ if isinstance(image, str):
168
+ if not os.path.exists(image):
169
+ return {"error": f"Image file not found: {image}"}
170
+ image = Image.open(image)
171
+ elif isinstance(image, np.ndarray):
172
+ image = Image.fromarray(image)
173
+ elif not isinstance(image, Image.Image):
174
+ return {"error": f"Unsupported image type: {type(image)}"}
175
+
176
+ if image.size[0] == 0 or image.size[1] == 0:
177
+ return {"error": "Image has zero dimensions"}
178
+
179
+ results = {
180
+ "original_image_size": image.size,
181
+ "preprocessing_used": use_preprocessing,
182
+ "extractions": [],
183
+ "best_result": "",
184
+ "confidence_score": 0.0,
185
+ "is_valid_format": False
186
+ }
187
+
188
+ images_to_process = self.preprocess_license_plate(image) if use_preprocessing else [image]
189
+
190
+ all_texts = []
191
+
192
+ for i, processed_img in enumerate(images_to_process):
193
+ try:
194
+ trocr_text = self.extract_text_trocr(processed_img)
195
+ easyocr_text = self.extract_text_easyocr(processed_img)
196
+
197
+ trocr_clean = self.clean_license_plate_text(trocr_text)
198
+ easyocr_clean = self.clean_license_plate_text(easyocr_text)
199
+
200
+ extraction_result = {
201
+ "preprocessing_step": i,
202
+ "trocr_raw": trocr_text,
203
+ "trocr_clean": trocr_clean,
204
+ "easyocr_raw": easyocr_text,
205
+ "easyocr_clean": easyocr_clean,
206
+ "trocr_valid": self.validate_license_plate_format(trocr_clean),
207
+ "easyocr_valid": self.validate_license_plate_format(easyocr_clean)
208
+ }
209
+
210
+ results["extractions"].append(extraction_result)
211
+
212
+ if trocr_clean:
213
+ all_texts.append((trocr_clean, extraction_result["trocr_valid"], "trocr"))
214
+ if easyocr_clean:
215
+ all_texts.append((easyocr_clean, extraction_result["easyocr_valid"], "easyocr"))
216
+
217
+ except Exception as e:
218
+ print(f"Error processing image variant {i}: {e}")
219
+ continue
220
+
221
+ if all_texts:
222
+ valid_texts = [t for t in all_texts if t[1]]
223
+ if valid_texts:
224
+ best_text = max(valid_texts, key=lambda x: len(x[0]))
225
+ results["best_result"] = best_text[0]
226
+ results["confidence_score"] = 0.9
227
+ results["is_valid_format"] = True
228
+ results["best_method"] = best_text[2]
229
+ else:
230
+ longest_text = max(all_texts, key=lambda x: len(x[0]))
231
+ results["best_result"] = longest_text[0]
232
+ results["confidence_score"] = 0.6
233
+ results["is_valid_format"] = False
234
+ results["best_method"] = longest_text[2]
235
+ else:
236
+ results["error"] = "No text could be extracted from the image"
237
+
238
+ return results
239
+
240
+ except Exception as e:
241
+ return {"error": f"Error in license plate extraction: {e}"}
242
+
243
+ def extract_license_plate_text(image_path_or_pil: Union[str, Image.Image]) -> str:
244
+ ocr = LicensePlateOCR()
245
+ result = ocr.extract_license_plate_text(image_path_or_pil)
246
+
247
+ if "error" in result:
248
+ return f"Error: {result['error']}"
249
+
250
+ return result.get("best_result", "No text found")
251
+
252
+ def get_detailed_license_plate_analysis(image_path_or_pil: Union[str, Image.Image]) -> Dict:
253
+ ocr = LicensePlateOCR()
254
+ return ocr.extract_license_plate_text(image_path_or_pil)
255
+
256
+ if __name__ == "__main__":
257
+ ocr_engine = LicensePlateOCR()
258
+
259
+ test_image_path = "license_plate_sample.jpg"
260
+
261
+ if os.path.exists(test_image_path):
262
+ print("Testing license plate OCR...")
263
+
264
+ result = ocr_engine.extract_license_plate_text(test_image_path)
265
+
266
+ print(f"Best Result: {result.get('best_result', 'No text found')}")
267
+ print(f"Valid Format: {result.get('is_valid_format', False)}")
268
+ print(f"Confidence: {result.get('confidence_score', 0):.2f}")
269
+
270
+ print("\nDetailed Results:")
271
+ for i, extraction in enumerate(result.get('extractions', [])):
272
+ print(f" Step {i}:")
273
+ print(f" TrOCR: {extraction['trocr_clean']} (Valid: {extraction['trocr_valid']})")
274
+ print(f" EasyOCR: {extraction['easyocr_clean']} (Valid: {extraction['easyocr_valid']})")
275
+ else:
276
+ print(f"Test image {test_image_path} not found.")
277
+ print("Usage example:")
278
+ print(" from license_plate_ocr import extract_license_plate_text")
279
+ print(" text = extract_license_plate_text('your_license_plate.jpg')")
280
+ print(" print(text)")