PavelLekomtsev committed on
Commit
1d5bc4c
·
verified ·
1 Parent(s): b771aa0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -174
app.py CHANGED
@@ -1,174 +1,174 @@
1
- from PIL import Image
2
- from ultralytics import YOLO
3
-
4
- import os
5
- import torch
6
- import re
7
- import cv2
8
-
9
- import gradio as gr
10
- import torchvision.transforms as T
11
- import albumentations as A
12
- import numpy as np
13
- import matplotlib.pyplot as plt
14
- import matplotlib.patches as patches
15
-
16
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17
-
18
- # Folders
19
- input_folder = "./target"
20
- output_folder = "./target_output"
21
-
22
- os.makedirs(output_folder, exist_ok=True)
23
-
24
- # Detector model
25
- license_plate_detector = YOLO("./models/yolo11x.pt")
26
-
27
- # SuperResolution model
28
- sr = cv2.dnn_superres.DnnSuperResImpl_create()
29
- sr.readModel("./models/FSRCNN_x3.pb")
30
- sr.setModel("fsrcnn", 3)
31
-
32
-
33
- class App:
34
- models = ['parseq', 'parseq_tiny', 'abinet', 'crnn', 'trba', 'vitstr']
35
-
36
- def __init__(self):
37
- self._model_cache = {}
38
- self._preprocess = T.Compose([
39
- T.Resize((32, 128), T.InterpolationMode.BICUBIC),
40
- T.ToTensor(),
41
- T.Normalize(0.5, 0.5)
42
- ])
43
-
44
- def _get_model(self, name):
45
- if name in self._model_cache:
46
- return self._model_cache[name]
47
- model = torch.hub.load('baudm/parseq', name, pretrained=True).eval().to(device)
48
- self._model_cache[name] = model
49
- return model
50
-
51
- @torch.inference_mode()
52
- def __call__(self, model_name, image):
53
- if image is None:
54
- return '', []
55
- model = self._get_model(model_name)
56
- image = self._preprocess(image.convert('RGB')).unsqueeze(0).to(device)
57
- pred = model(image).softmax(-1)
58
- label, _ = model.tokenizer.decode(pred)
59
- raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
60
- max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
61
- conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
62
- return label[0], [raw_label[0][:max_len], conf]
63
-
64
-
65
- p = App()
66
-
67
-
68
- def detect_license_plates(model, image):
69
- plate_image_np = pil_to_np(image)
70
-
71
- transform = A.Compose([
72
- A.ToGray(p=1.0),
73
- A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),
74
- ])
75
-
76
- transformed = transform(image=plate_image_np)['image']
77
-
78
- if len(transformed.shape) == 2:
79
- transformed = cv2.cvtColor(transformed, cv2.COLOR_GRAY2RGB)
80
-
81
- image = np_to_pil(transformed)
82
-
83
- results = model(image)
84
- plates = []
85
- for result in results:
86
- for box in result.boxes.xyxy.cpu().numpy():
87
- x1, y1, x2, y2 = map(int, box)
88
- plate = image.crop((x1, y1, x2, y2))
89
- plates.append((plate, (x1, y1, x2, y2)))
90
-
91
- return plates
92
-
93
-
94
- def pil_to_np(image):
95
- return np.array(image)
96
-
97
-
98
- def np_to_pil(image_np):
99
- return Image.fromarray(image_np)
100
-
101
-
102
- def preprocess_license_plate(plate_image: Image):
103
- plate_image_np = pil_to_np(plate_image)
104
- if not(plate_image_np.ndim == 2 or plate_image_np.shape[-1] == 1):
105
- plate_image_np = A.ToGray(p=1.0, num_output_channels=1)(image=plate_image_np)['image']
106
- super_resolved = sr.upsample(plate_image_np)
107
- augmented = A.Compose([
108
- A.CLAHE(clip_limit=2, tile_grid_size=(1, 1), p=1.0),
109
- A.Morphological(p=1.0, scale=(4, 4), operation="erosion"),
110
- ])(image=super_resolved)['image']
111
-
112
- super_resolved_pil = np_to_pil(augmented)
113
- return super_resolved_pil
114
-
115
-
116
- def process_image(image_path: Image):
117
- image_np = np.array(image_path)
118
-
119
- fig, ax = plt.subplots(1, figsize=(10, 6))
120
- ax.imshow(image_np)
121
-
122
- plates = detect_license_plates(license_plate_detector, image_path)
123
- recognized_text = ""
124
-
125
- for i, (plate, bbox) in enumerate(plates):
126
- preprocessed_plate = preprocess_license_plate(plate)
127
- recognized_text, raw_output = p.__call__("parseq", preprocessed_plate)
128
-
129
- if recognized_text and len(recognized_text) > 5:
130
- recognized_text = re.sub(r"[^A-Za-z0-9]", "", recognized_text).upper()
131
- recognized_text = recognized_text.replace('V', 'Y').replace('I', '')
132
- recognized_text = recognized_text.replace('8', 'В', 1) if recognized_text[0] == "8" else recognized_text
133
- recognized_text = recognized_text.replace('7', 'T', 1) if recognized_text[0] == "7" else recognized_text
134
- recognized_text = recognized_text.replace('0', 'O', 1) if recognized_text[0] == "0" else recognized_text
135
- recognized_text = recognized_text[:9] if len(recognized_text) >= 9 else recognized_text
136
-
137
- x1, y1, x2, y2 = bbox
138
- rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='r', facecolor='none')
139
- ax.add_patch(rect)
140
- ax.text(x1, y1 - 10, recognized_text, color='red', fontsize=12, bbox=dict(facecolor='white', alpha=0.5))
141
-
142
- plt.axis('off')
143
-
144
- # Saving image to buffer
145
- output_buffer = "processed_image.png"
146
- plt.savefig(output_buffer, bbox_inches='tight')
147
- plt.close()
148
-
149
- return Image.open(output_buffer), recognized_text.strip()
150
-
151
-
152
- # Gradio UI
153
-
154
- target_folder = "./target"
155
- example_images = [
156
- os.path.join(target_folder, file) for file in os.listdir(target_folder) if file.lower().endswith(("jpg", "png", "bmp"))
157
- ]
158
-
159
- interface = gr.Interface(
160
- fn=process_image,
161
- inputs=gr.Image(type="pil", label="Загрузите фото машины с номером 📤"),
162
- outputs=[
163
- gr.Image(type="pil", label="📸 Выход 0 - Обработанное изображение"),
164
- gr.Text(label="🔍 Выход 1 - Распознанный номер"),
165
- ],
166
- title="Распознавание российских номеров",
167
- description="🔎 **Загрузите изображение с автомобильным номером** и модель автоматически **определит госномер!** 🔥\n\n📸 **Форматы:** JPG, PNG, BMP",
168
- examples=example_images,
169
- allow_flagging="never",
170
- theme="compact",
171
- )
172
-
173
- if __name__ == "__main__":
174
- interface.launch(share=True)
 
1
import io
import os
import re

import albumentations as A
import cv2
import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from ultralytics import YOLO
15
+
16
# Run all torch models on GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Folders
input_folder = "./target"          # sample/input images
output_folder = "./target_output"  # processed results (created below)

os.makedirs(output_folder, exist_ok=True)

# Detector model
# NOTE(review): the filename suggests a stock YOLO11x checkpoint — confirm it
# was fine-tuned for license plates, otherwise it detects generic classes.
license_plate_detector = YOLO("./models/yolo11x.pt")

# SuperResolution model: FSRCNN with a 3x upscale factor, applied to plate
# crops before OCR (see preprocess_license_plate).
sr = cv2.dnn_superres.DnnSuperResImpl_create()
sr.readModel("./models/FSRCNN_x3.pb")
sr.setModel("fsrcnn", 3)
31
+
32
+
33
class App:
    """Thin wrapper around the parseq-family scene-text recognizers on torch.hub."""

    # Checkpoint names that can be requested via __call__.
    models = ['parseq', 'parseq_tiny', 'abinet', 'crnn', 'trba', 'vitstr']

    def __init__(self):
        # name -> loaded model, so each checkpoint is downloaded at most once.
        self._model_cache = {}
        # Input pipeline expected by these recognizers: 32x128, normalized to [-1, 1].
        self._preprocess = T.Compose([
            T.Resize((32, 128), T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def _get_model(self, name):
        """Return the cached model for *name*, loading it from torch.hub on first use."""
        try:
            return self._model_cache[name]
        except KeyError:
            loaded = torch.hub.load('baudm/parseq', name, pretrained=True).eval().to(device)
            self._model_cache[name] = loaded
            return loaded

    @torch.inference_mode()
    def __call__(self, model_name, image):
        """Recognize text in *image* with *model_name*.

        Returns (label, [raw_tokens, per-token confidences]); ('', []) when
        *image* is None.
        """
        if image is None:
            return '', []
        recognizer = self._get_model(model_name)
        batch = self._preprocess(image.convert('RGB')).unsqueeze(0).to(device)
        probs = recognizer(batch).softmax(-1)
        label, _ = recognizer.tokenizer.decode(probs)
        raw_label, raw_confidence = recognizer.tokenizer.decode(probs, raw=True)
        # CRNN has no EOS token, so cap its raw output at a fixed length.
        limit = 25 if model_name == 'crnn' else len(label[0]) + 1
        conf = ['{:0.1f}'.format(c) for c in raw_confidence[0][:limit].tolist()]
        return label[0], [raw_label[0][:limit], conf]
63
+
64
+
65
# Shared recognizer instance used by process_image().
p = App()
66
+
67
+
68
def detect_license_plates(model, image):
    """Run the plate detector on a contrast-enhanced copy of *image*.

    Returns a list of (cropped_plate_PIL, (x1, y1, x2, y2)) tuples; crops are
    taken from the enhanced image, not the original.
    """
    # Grayscale + CLAHE makes plate edges more prominent for the detector.
    enhance = A.Compose([
        A.ToGray(p=1.0),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=1.0),
    ])
    enhanced = enhance(image=pil_to_np(image))['image']

    # The detector expects 3 channels; restore them if ToGray collapsed to 2D.
    if enhanced.ndim == 2:
        enhanced = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)

    detector_input = np_to_pil(enhanced)

    plates = []
    for result in model(detector_input):
        for box in result.boxes.xyxy.cpu().numpy():
            x1, y1, x2, y2 = (int(v) for v in box)
            crop = detector_input.crop((x1, y1, x2, y2))
            plates.append((crop, (x1, y1, x2, y2)))
    return plates
92
+
93
+
94
def pil_to_np(image):
    """Convert a PIL image (or any array-like) to a numpy array."""
    as_array = np.array(image)
    return as_array
96
+
97
+
98
def np_to_pil(image_np):
    """Convert a numpy array back into a PIL image."""
    as_pil = Image.fromarray(image_np)
    return as_pil
100
+
101
+
102
def preprocess_license_plate(plate_image: Image):
    """Prepare a cropped plate for OCR: grayscale, 3x super-resolve, CLAHE, erode.

    Returns the processed crop as a PIL image.
    """
    plate_np = pil_to_np(plate_image)

    # Collapse to one channel unless the crop is already grayscale.
    already_gray = plate_np.ndim == 2 or plate_np.shape[-1] == 1
    if not already_gray:
        plate_np = A.ToGray(p=1.0, num_output_channels=1)(image=plate_np)['image']

    # FSRCNN x3 upscale (module-level `sr`), then a contrast boost and an
    # erosion pass to thicken glyph strokes before recognition.
    upscaled = sr.upsample(plate_np)
    polish = A.Compose([
        A.CLAHE(clip_limit=2, tile_grid_size=(1, 1), p=1.0),
        A.Morphological(p=1.0, scale=(4, 4), operation="erosion"),
    ])
    polished = polish(image=upscaled)['image']

    return np_to_pil(polished)
114
+
115
+
116
def _normalize_plate_text(text):
    """Clean raw OCR output into a plausible Russian plate string.

    Keeps alphanumerics, uppercases, drops 'I', maps 'V'->'Y', and fixes
    common first-character OCR confusions (8->В, 7->T, 0->O), then caps the
    result at 9 characters.
    """
    text = re.sub(r"[^A-Za-z0-9]", "", text).upper()
    text = text.replace('V', 'Y').replace('I', '')
    # Guard against an empty string: the replace chain above can remove every
    # character, and the original code indexed text[0] unconditionally.
    # NOTE(review): 'В' below is Cyrillic while 'T'/'O' are Latin — preserved
    # from the original; confirm which alphabet downstream consumers expect.
    if text and text[0] == "8":
        text = text.replace('8', 'В', 1)
    if text and text[0] == "7":
        text = text.replace('7', 'T', 1)
    if text and text[0] == "0":
        text = text.replace('0', 'O', 1)
    return text[:9]


def process_image(image_path: Image):
    """Detect license plates on the uploaded photo, recognize each one, and
    return an annotated image plus the recognized text.

    Parameters
    ----------
    image_path : PIL.Image
        The uploaded photo (despite the name, this is an image, not a path).

    Returns
    -------
    (PIL.Image, str)
        Rendered figure with red boxes and labels, and the recognized text of
        the last detected plate ('' when nothing was detected).
    """
    image_np = np.array(image_path)

    fig, ax = plt.subplots(1, figsize=(10, 6))
    try:
        ax.imshow(image_np)

        recognized_text = ""
        for plate, bbox in detect_license_plates(license_plate_detector, image_path):
            preprocessed_plate = preprocess_license_plate(plate)
            recognized_text, _raw_output = p("parseq", preprocessed_plate)

            # Only normalize outputs long enough to be a real plate number.
            if recognized_text and len(recognized_text) > 5:
                recognized_text = _normalize_plate_text(recognized_text)

            x1, y1, x2, y2 = bbox
            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                     linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(x1, y1 - 10, recognized_text, color='red', fontsize=12,
                    bbox=dict(facecolor='white', alpha=0.5))

        ax.axis('off')

        # Render to an in-memory buffer instead of a fixed filename so that
        # concurrent Gradio requests cannot clobber each other's output file.
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png', bbox_inches='tight')
        buffer.seek(0)
        annotated = Image.open(buffer)
        annotated.load()  # force-decode before the buffer goes out of scope
        return annotated, recognized_text.strip()
    finally:
        # Always release the figure, even if detection/recognition raises.
        plt.close(fig)
150
+
151
+
152
+ # Gradio UI
153
+
154
# Gradio UI

target_folder = "./target"
# Guard against a missing examples folder: os.listdir would raise
# FileNotFoundError and prevent the app from starting in a fresh checkout.
if os.path.isdir(target_folder):
    example_images = [
        os.path.join(target_folder, file)
        for file in os.listdir(target_folder)
        if file.lower().endswith(("jpg", "png", "bmp"))
    ]
else:
    example_images = []

interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Загрузите фото машины с номером 📤"),
    outputs=[
        gr.Image(type="pil", label="📸 Выход 0 - Обработанное изображение"),
        gr.Text(label="🔍 Выход 1 - Распознанный номер"),
    ],
    title="Распознавание российских номеров",
    description="🔎 **Загрузите изображение с автомобильным номером** и модель автоматически **определит госномер!** 🔥\n\n📸 **Форматы:** JPG, PNG, BMP",
    # Pass None rather than an empty list when no examples were found.
    examples=example_images or None,
    flagging_mode="never",  # Gradio 4+ name for the old allow_flagging arg
    theme="compact",
)

if __name__ == "__main__":
    # share=True exposes a public tunnel URL in addition to the local server.
    interface.launch(share=True)