MohammedHamdy32 committed
Commit 77f8d5f · 1 Parent(s): ade118b

Add Egyptian ID information Extraction

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. .gitignore +4 -0
  2. app.py +50 -0
  3. code/__init__.py +0 -0
  4. code/__pycache__/__init__.cpython-310.pyc +0 -0
  5. code/detection/__pycache__/detection.cpython-310.pyc +0 -0
  6. code/detection/detection.py +77 -0
  7. code/detection/recognize_id/__pycache__/detect_and_recognize_id.cpython-310.pyc +0 -0
  8. code/detection/recognize_id/data/id_1.png +0 -0
  9. code/detection/recognize_id/detect_and_recognize_id.py +36 -0
  10. code/recognization/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  11. code/recognization/__pycache__/augmentation.cpython-310.pyc +0 -0
  12. code/recognization/__pycache__/config.cpython-310.pyc +0 -0
  13. code/recognization/__pycache__/custom_test.cpython-310.pyc +0 -0
  14. code/recognization/__pycache__/dataset.cpython-310.pyc +0 -0
  15. code/recognization/__pycache__/densenet.cpython-310.pyc +0 -0
  16. code/recognization/__pycache__/dropout_layer.cpython-310.pyc +0 -0
  17. code/recognization/__pycache__/feature_extraction.cpython-310.pyc +0 -0
  18. code/recognization/__pycache__/hrnet.cpython-310.pyc +0 -0
  19. code/recognization/__pycache__/inception_unet.cpython-310.pyc +0 -0
  20. code/recognization/__pycache__/model.cpython-310.pyc +0 -0
  21. code/recognization/__pycache__/my_test.cpython-310.pyc +0 -0
  22. code/recognization/__pycache__/prediction.cpython-310.pyc +0 -0
  23. code/recognization/__pycache__/rcnn.cpython-310.pyc +0 -0
  24. code/recognization/__pycache__/recognization.cpython-310.pyc +0 -0
  25. code/recognization/__pycache__/resnet.cpython-310.pyc +0 -0
  26. code/recognization/__pycache__/resunet.cpython-310.pyc +0 -0
  27. code/recognization/__pycache__/sequence_modeling.cpython-310.pyc +0 -0
  28. code/recognization/__pycache__/unet.cpython-310.pyc +0 -0
  29. code/recognization/__pycache__/unet_attn.cpython-310.pyc +0 -0
  30. code/recognization/__pycache__/unet_plus_plus.cpython-310.pyc +0 -0
  31. code/recognization/__pycache__/utils.cpython-310.pyc +0 -0
  32. code/recognization/__pycache__/vgg.cpython-310.pyc +0 -0
  33. code/recognization/augmentation.py +134 -0
  34. code/recognization/config.py +30 -0
  35. code/recognization/custom_test.py +235 -0
  36. code/recognization/data/1.png +0 -0
  37. code/recognization/data/10.png +0 -0
  38. code/recognization/data/11.png +0 -0
  39. code/recognization/data/12.png +0 -0
  40. code/recognization/data/13.png +0 -0
  41. code/recognization/data/14.png +0 -0
  42. code/recognization/data/15.png +0 -0
  43. code/recognization/data/16.png +0 -0
  44. code/recognization/data/2.png +0 -0
  45. code/recognization/data/2_1.png +0 -0
  46. code/recognization/data/2_2.png +0 -0
  47. code/recognization/data/3.png +0 -0
  48. code/recognization/data/4.png +0 -0
  49. code/recognization/data/5.png +0 -0
  50. code/recognization/data/6.png +0 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ # Ignore model weights and the local models directory
+ *.pt
+ *.pth
+ models
app.py ADDED
@@ -0,0 +1,50 @@
+ import gradio as gr
+ from code.detection.recognize_id.detect_and_recognize_id import Recognize_ID
+ from code.detection.detection import detection
+ from code.recognization.recognization import TextRecognition
+ import os
+
+ # Prediction function: read the ID number, detect the text fields, then recognize each word
+ def predict_image(image):
+
+     # Recognize the ID number
+     rec_id = Recognize_ID()
+     id_number = rec_id.give_me_id_number(image)
+
+     # Detect the text fields on the card
+     det = detection()
+     detection_list = det.full_pipeline(image)
+
+     result = ''
+     # Loop over all detected word crops and recognize them
+     recognizer = TextRecognition()
+     for line in detection_list[2:6]:
+         for word in line:
+             recognized_word = recognizer.recognize_image(word)
+             result = result + recognized_word + ' '
+         result += '\n'
+
+     # Append the ID number
+     result = result + id_number
+
+     return result
+
+ # List of paths to the sample images
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ sample_images = [
+     os.path.join(current_dir, "samples/id_1.png")
+ ]
+
+ # Create the Gradio interface
+ interface = gr.Interface(
+     fn=predict_image,   # Function to run
+     inputs="image",     # Input type
+     outputs="text",     # Output type
+     title="Recognition",
+     description="Upload an image",
+     examples=sample_images
+ )
+
+ # Launch the app
+ interface.launch()
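Running python app.py starts a local Gradio server (http://127.0.0.1:7860 by default). To reach it from other machines, the stock Gradio launch options can be passed instead of the bare call; a minimal sketch (server_name and server_port are standard Gradio keyword arguments, not part of this commit):

    # Sketch: bind to all interfaces instead of localhost only
    interface.launch(server_name="0.0.0.0", server_port=7860)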
code/__init__.py ADDED
File without changes
code/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes)
code/detection/__pycache__/detection.cpython-310.pyc ADDED
Binary file (2.86 kB)
code/detection/detection.py ADDED
@@ -0,0 +1,77 @@
+ from ultralytics import YOLO
+ import cv2
+ import os
+ import numpy as np
+
+
+ class detection:
+
+     def __init__(self, model_path='detection.pt'):
+         current_dir = os.path.dirname(os.path.abspath(__file__))
+         model_path = os.path.join(current_dir, model_path)
+         self.model = YOLO(model_path)
+
+     def get_distance(self, res):
+         boxes = res[0].boxes.xywh.numpy()  # Convert to a NumPy array
+         # Sort primarily by Y (vertical), then X (horizontal) using lexsort
+         sorted_indices = np.lexsort((boxes[:, 0], boxes[:, 1]))
+         sorted_boxes = boxes[sorted_indices]
+         return sorted_boxes[:, 1], sorted_indices  # Return sorted Y values and indices
+
+     def handle_the_boxes(self, res, img, y_threshold=30):
+         distance_sorted, sorted_indices = self.get_distance(res)
+         PB = res[0].boxes.xyxy.numpy()[sorted_indices]  # Get boxes in sorted order
+         same_object = []
+         current_line = [PB[0]]
+
+         # Group boxes into lines using the Y threshold
+         for i in range(1, len(PB)):
+             prev_y = current_line[-1][1]  # Use ymin from the XYXY format
+             current_y = PB[i][1]
+             if abs(current_y - prev_y) > y_threshold:
+                 # Sort the line right-to-left (Arabic reading order) before adding it
+                 current_line = sorted(current_line, key=lambda x: x[0], reverse=True)
+                 same_object.append(current_line)
+                 current_line = [PB[i]]
+             else:
+                 current_line.append(PB[i])
+
+         # Add the last line, sorted the same way
+         if current_line:
+             current_line = sorted(current_line, key=lambda x: x[0], reverse=True)
+             same_object.append(current_line)
+
+         # Extract word images in final order
+         return [
+             [self.words_pixels(img, box) for box in line]
+             for line in same_object
+         ]
+
+     # Crop the pixels of a single word box out of the image
+     def words_pixels(self, img, xyxy):
+         xmin, ymin, xmax, ymax = xyxy.tolist()
+         return img[int(ymin):int(ymax) + 1, int(xmin):int(xmax) + 1]
+
+     def full_pipeline(self, image, show=False):
+         if isinstance(image, str):  # If the input is a file path
+             img = cv2.imread(image)
+         elif isinstance(image, np.ndarray):  # If the input is a NumPy array
+             img = image
+         else:
+             raise TypeError("image must be a file path or a NumPy array")
+
+         res = self.model(img)
+
+         if show:
+             res[0].show()
+
+         return self.handle_the_boxes(res, img)
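full_pipeline returns a list of text lines, each itself a list of word crops (NumPy arrays) in reading order. A minimal sketch of consuming that structure (assumes the detection.pt weights sit next to detection.py; the sample image is the one shipped in this commit):

    import cv2
    from code.detection.detection import detection

    img = cv2.imread("code/detection/recognize_id/data/id_1.png")
    lines = detection().full_pipeline(img)
    for i, line in enumerate(lines):
        for j, word in enumerate(line):
            cv2.imwrite(f"word_{i}_{j}.png", word)  # save each detected word crop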
code/detection/recognize_id/__pycache__/detect_and_recognize_id.cpython-310.pyc ADDED
Binary file (1.77 kB)
code/detection/recognize_id/data/id_1.png ADDED
code/detection/recognize_id/detect_and_recognize_id.py ADDED
@@ -0,0 +1,36 @@
+ from ultralytics import YOLO
+ import cv2
+ import os
+ import numpy as np
+
+ class Recognize_ID:
+
+     def __init__(self, model_path='recognization_id.pt'):
+         current_dir = os.path.dirname(os.path.abspath(__file__))
+         model_path = os.path.join(current_dir, model_path)
+         self.model = YOLO(model=model_path)
+
+     def give_me_id_number(self, image):
+         """
+         image : input image, either a path relative to this module or a NumPy array
+         returns : the national ID digits as a string
+         """
+         if isinstance(image, str):  # If the input is a file path
+             current_dir = os.path.dirname(os.path.abspath(__file__))
+             image_path = os.path.join(current_dir, image)
+             img = cv2.imread(image_path)
+         elif isinstance(image, np.ndarray):  # If the input is a NumPy array
+             img = image
+         else:
+             raise TypeError("image must be a file path or a NumPy array")
+
+         res = self.model(img)
+         # x-centers of the detected digit boxes, and the digit class of each box
+         boxes = res[0].boxes.xywh[:, 0].tolist()
+         classes = res[0].boxes.cls.tolist()
+         # Pair each digit with its x-position, then sort left to right
+         boxes_labels = [(int(x), int(cls)) for x, cls in zip(boxes, classes)]
+         boxes_labels.sort()
+         national_id = "".join([str(i[1]) for i in boxes_labels])
+
+         return national_id
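Because give_me_id_number joins a string argument onto the module's own directory, the bundled sample can be referenced with a relative path. A minimal usage sketch (assumes the recognization_id.pt weights are present next to the module):

    from code.detection.recognize_id.detect_and_recognize_id import Recognize_ID

    rec = Recognize_ID()                           # loads recognization_id.pt from the module folder
    print(rec.give_me_id_number("data/id_1.png"))  # resolved relative to the module; returns the digits as a string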
code/recognization/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
+ {
+  "cells": [],
+  "metadata": {},
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
code/recognization/__pycache__/augmentation.cpython-310.pyc ADDED
Binary file (5.33 kB)
code/recognization/__pycache__/config.cpython-310.pyc ADDED
Binary file (881 Bytes)
code/recognization/__pycache__/custom_test.cpython-310.pyc ADDED
Binary file (7.13 kB)
code/recognization/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (11 kB)
code/recognization/__pycache__/densenet.cpython-310.pyc ADDED
Binary file (4.01 kB)
code/recognization/__pycache__/dropout_layer.cpython-310.pyc ADDED
Binary file (1.54 kB)
code/recognization/__pycache__/feature_extraction.cpython-310.pyc ADDED
Binary file (5.18 kB)
code/recognization/__pycache__/hrnet.cpython-310.pyc ADDED
Binary file (7.44 kB)
code/recognization/__pycache__/inception_unet.cpython-310.pyc ADDED
Binary file (5.47 kB)
code/recognization/__pycache__/model.cpython-310.pyc ADDED
Binary file (3.84 kB)
code/recognization/__pycache__/my_test.cpython-310.pyc ADDED
Binary file (7.09 kB)
code/recognization/__pycache__/prediction.cpython-310.pyc ADDED
Binary file (3.66 kB)
code/recognization/__pycache__/rcnn.cpython-310.pyc ADDED
Binary file (3.66 kB)
code/recognization/__pycache__/recognization.cpython-310.pyc ADDED
Binary file (2.87 kB)
code/recognization/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (4.83 kB)
code/recognization/__pycache__/resunet.cpython-310.pyc ADDED
Binary file (3.38 kB)
code/recognization/__pycache__/sequence_modeling.cpython-310.pyc ADDED
Binary file (2.81 kB)
code/recognization/__pycache__/unet.cpython-310.pyc ADDED
Binary file (4.02 kB)
code/recognization/__pycache__/unet_attn.cpython-310.pyc ADDED
Binary file (4.94 kB)
code/recognization/__pycache__/unet_plus_plus.cpython-310.pyc ADDED
Binary file (3.43 kB)
code/recognization/__pycache__/utils.cpython-310.pyc ADDED
Binary file (15.7 kB)
code/recognization/__pycache__/vgg.cpython-310.pyc ADDED
Binary file (1.81 kB)
code/recognization/augmentation.py ADDED
@@ -0,0 +1,134 @@
+ """
+ Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023
+ Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora
+ GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
+ Project Website: https://abdur75648.github.io/UTRNet/
+ Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial
+ 4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/)
+ """
+
+ from functools import partial
+ import random as rnd
+ import imgaug.augmenters as iaa
+ import numpy as np
+ from PIL import ImageFilter, Image
+ from timm.data import auto_augment
+
+ _OP_CACHE = {}
+
+ def _get_op(key, factory):
+     try:
+         op = _OP_CACHE[key]
+     except KeyError:
+         op = factory()
+         _OP_CACHE[key] = op
+     return op
+
+
+ def _get_param(level, img, max_dim_factor, min_level=1):
+     max_level = max(min_level, max_dim_factor * max(img.size))
+     return round(min(level, max_level))
+
+ def gaussian_blur(img, radius, **__):
+     radius = _get_param(radius, img, 0.02)
+     key = 'gaussian_blur_' + str(radius)
+     op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius))
+     return img.filter(op)
+
+ def motion_blur(img, k, **__):
+     k = _get_param(k, img, 0.08, 3) | 1  # bin to odd values
+     key = 'motion_blur_' + str(k)
+     op = _get_op(key, lambda: iaa.MotionBlur(k))
+     return Image.fromarray(op(image=np.asarray(img)))
+
+ def gaussian_noise(img, scale, **_):
+     scale = _get_param(scale, img, 0.25) | 1  # bin to odd values
+     key = 'gaussian_noise_' + str(scale)
+     op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale))
+     return Image.fromarray(op(image=np.asarray(img)))
+
+ def poisson_noise(img, lam, **_):
+     lam = _get_param(lam, img, 0.2) | 1  # bin to odd values
+     key = 'poisson_noise_' + str(lam)
+     op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam))
+     return Image.fromarray(op(image=np.asarray(img)))
+
+ def salt_and_pepper_noise(image, prob=0.05):
+     if prob <= 0:
+         return image
+     arr = np.asarray(image)
+     original_dtype = arr.dtype
+     intensity_levels = 2 ** (arr[0, 0].nbytes * 8)
+     min_intensity = 0
+     max_intensity = intensity_levels - 1
+     random_image_arr = np.random.choice([min_intensity, 1, np.nan], p=[prob / 2, 1 - prob, prob / 2], size=arr.shape)
+     salt_and_peppered_arr = arr.astype(float) * random_image_arr  # np.float was removed in NumPy 1.24+; use the builtin float
+     salt_and_peppered_arr = np.nan_to_num(salt_and_peppered_arr, nan=max_intensity).astype(original_dtype)
+     return Image.fromarray(salt_and_peppered_arr)
+
+ def random_border_crop(image):
+     img_width, img_height = image.size
+     crop_left = int(img_width * rnd.uniform(0.0, 0.025))
+     crop_top = int(img_height * rnd.uniform(0.0, 0.075))
+     crop_right = int(img_width * rnd.uniform(0.975, 1.0))
+     crop_bottom = int(img_height * rnd.uniform(0.925, 1.0))
+     final_image = image.crop((crop_left, crop_top, crop_right, crop_bottom))
+     return final_image
+
+ def random_resize(image):
+     size = image.size
+     new_size = [rnd.randint(int(0.5 * size[0]), int(1.5 * size[0])), rnd.randint(int(0.5 * size[1]), int(1.5 * size[1]))]
+     reduce_factor = rnd.randint(1, 4)
+     new_size = tuple([int(x / reduce_factor) for x in new_size])
+     final_image = image.resize(new_size)
+     return final_image
+
+ def _level_to_arg(level, _hparams, max):
+     level = max * level / auto_augment._LEVEL_DENOM
+     return level,
+
+ _RAND_TRANSFORMS = [
+     'AutoContrast',
+     'Equalize',
+     'Invert',
+     # 'Rotate',
+     'Posterize',
+     'Solarize',
+     'SolarizeAdd',
+     'Color',
+     'Contrast',
+     'Brightness',
+     'Sharpness',
+     'ShearX',
+ ]
+ # _RAND_TRANSFORMS.remove('SharpnessIncreasing')  # removed: interferes with *blur ops
+ _RAND_TRANSFORMS.extend([
+     'GaussianBlur',
+     'GaussianNoise',
+     'PoissonNoise'
+ ])
+ auto_augment.LEVEL_TO_ARG.update({
+     'GaussianBlur': partial(_level_to_arg, max=4),
+     'MotionBlur': partial(_level_to_arg, max=20),
+     'GaussianNoise': partial(_level_to_arg, max=0.1 * 255),
+     'PoissonNoise': partial(_level_to_arg, max=40)
+ })
+ auto_augment.NAME_TO_OP.update({
+     'GaussianBlur': gaussian_blur,
+     'MotionBlur': motion_blur,
+     'GaussianNoise': gaussian_noise,
+     'PoissonNoise': poisson_noise
+ })
+
+ def rand_augment_transform(magnitude=5, num_layers=3):
+     # These are tuned for magnitude=5, so effective magnitudes are half of these values.
+     hparams = {
+         'img_mean': 128,
+         # 'rotate_deg': 5,
+         'shear_x_pct': 0.9,
+         'shear_y_pct': 0.0,
+     }
+     ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS)
+     # Uniform weights disable replacement in random selection (i.e. avoid applying the same op twice)
+     choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))]
+     return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
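rand_augment_transform builds a timm RandAugment policy extended with the blur and noise ops registered above, and the returned object is applied directly to a PIL image. A minimal sketch (the sample path is one of the images added in this commit; timm and imgaug must be installed):

    from PIL import Image
    from code.recognization.augmentation import rand_augment_transform

    aug = rand_augment_transform(magnitude=5, num_layers=3)
    img = Image.open("code/recognization/data/1.png").convert("RGB")
    aug(img).save("augmented_1.png")  # applies 3 randomly chosen ops at magnitude 5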
code/recognization/config.py ADDED
@@ -0,0 +1,30 @@
+ class Config:
+     FeatureExtraction = 'HRNet'    # or any other feature extraction method
+     SequenceModeling = 'DBiLSTM'   # or any other sequence model
+     Prediction = 'CTC'             # or 'Attn'
+     input_channel = 1              # 1 for grayscale input (an RGB image would use 3 channels)
+     output_channel = 32            # adjust based on the architecture
+     hidden_size = 256              # adjust based on the architecture
+     num_class = 182                # number of output classes
+     device = 'cpu'                 # or 'cuda' for GPU
+     batch_max_length = 8           # maximum sequence length for prediction
+     # Adam optimizer
+     adam = False
+     lr = 0.1
+     batch_size = 4
+     beta1 = 0.9
+     workers = 4
+     num_epochs = 5
+     rho = 0.95
+     eps = 1e-8
+
+     imgH = 32
+     imgW = 400
+     train_data = 'result/train/'      # path to training data
+     valid_data = 'result/validate/'   # path to validation data
+     saved_model = 'model/'
+
+     character = ''
+     rgb = False
+     grad_clip = 5
+
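Config mirrors the argparse options that the vendored UTRNet code expects, so an instance can stand in for the parsed opt namespace when building the recognizer. A hedged sketch (assumes it is run from code/recognization/ so the vendored modules' flat imports resolve):

    from config import Config
    from model import Model

    opt = Config()
    model = Model(opt)  # builds the HRNet + DBiLSTM + CTC recognizer from the settings above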
code/recognization/custom_test.py ADDED
@@ -0,0 +1,235 @@
+ """
+ Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023
+ Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora
+ GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
+ Project Website: https://abdur75648.github.io/UTRNet/
+ Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial
+ 4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/)
+ """
+
+ import os
+ import time
+ import argparse
+ import random
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from datetime import datetime
+ import pytz
+
+ import torch
+ import torch.utils.data
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from nltk.metrics.distance import edit_distance
+
+ from utils import CTCLabelConverter, AttnLabelConverter, Averager, Logger
+ from dataset import hierarchical_dataset, AlignCollate
+ from model import Model
+
+ def validation(model, criterion, evaluation_loader, converter, opt, device):
+     """ validation or evaluation """
+     eval_arr = []
+     sum_len_gt = 0
+     n_correct = 0
+     norm_ED = 0
+     length_of_data = 0
+     infer_time = 0
+     valid_loss_avg = Averager()
+
+     for i, (image_tensors, labels) in enumerate(tqdm(evaluation_loader)):
+         batch_size = image_tensors.size(0)
+         length_of_data = length_of_data + batch_size
+         image = image_tensors.to(device)
+         # For max-length prediction
+         length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
+         text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
+
+         text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)
+
+         start_time = time.time()
+         if 'CTC' in opt.Prediction:
+             preds = model(image)
+             forward_time = time.time() - start_time
+             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
+             cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)
+             _, preds_index = preds.max(2)
+             preds_str = converter.decode(preds_index.data, preds_size.data)
+         else:
+             preds = model(image, text=text_for_pred, is_train=False)
+             forward_time = time.time() - start_time
+             preds = preds[:, :text_for_loss.shape[1] - 1, :].to(device)
+             target = text_for_loss[:, 1:].to(device)  # without the [GO] symbol
+             cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))
+             _, preds_index = preds.max(2)
+             preds_str = converter.decode(preds_index, length_for_pred)
+             labels = converter.decode(text_for_loss[:, 1:], length_for_loss)
+
+         infer_time += forward_time
+         valid_loss_avg.add(cost)
+
+         # Calculate accuracy & confidence score
+         preds_prob = F.softmax(preds, dim=2)
+         preds_max_prob, _ = preds_prob.max(dim=2)
+         confidence_score_list = []
+         for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
+             if 'Attn' in opt.Prediction:
+                 gt = gt[:gt.find('[s]')]
+                 pred_EOS = pred.find('[s]')
+                 pred = pred[:pred_EOS]  # prune after the "end of sentence" token ([s])
+                 pred_max_prob = pred_max_prob[:pred_EOS]
+
+             if pred == gt:
+                 n_correct += 1
+
+             # ICDAR2019 Normalized Edit Distance
+             if len(gt) == 0 or len(pred) == 0:
+                 ED = 0
+             elif len(gt) > len(pred):
+                 ED = 1 - edit_distance(pred, gt) / len(gt)
+             else:
+                 ED = 1 - edit_distance(pred, gt) / len(pred)
+
+             eval_arr.append([gt, pred, ED])
+             sum_len_gt += len(gt)
+             norm_ED += (ED * len(gt))
+
+             # Calculate the confidence score (= product of pred_max_prob)
+             try:
+                 confidence_score = pred_max_prob.cumprod(dim=0)[-1]
+             except Exception:
+                 confidence_score = 0  # empty pred after pruning at the "end of sentence" token ([s])
+             confidence_score_list.append(confidence_score)
+
+     accuracy = n_correct / float(length_of_data) * 100
+     norm_ED = norm_ED / float(sum_len_gt)
+
+     return valid_loss_avg.val(), accuracy, norm_ED, eval_arr
+
+
+ def test(opt, device):
+     opt.device = device
+     os.makedirs("test_outputs", exist_ok=True)
+     datetime_now = str(datetime.now(pytz.timezone('Asia/Kolkata')).strftime("%Y-%m-%d_%H-%M-%S"))
+     logger = Logger(f'test_outputs/{datetime_now}.txt')
+     """ model configuration """
+     if 'CTC' in opt.Prediction:
+         converter = CTCLabelConverter(opt.character)
+     else:
+         converter = AttnLabelConverter(opt.character)
+     opt.num_class = len(converter.character)
+
+     if opt.rgb:
+         opt.input_channel = 3
+     model = Model(opt)
+     logger.log('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
+                opt.hidden_size, opt.num_class, opt.batch_max_length, opt.FeatureExtraction,
+                opt.SequenceModeling, opt.Prediction)
+     model = model.to(device)
+
+     # Load the model weights
+     model.load_state_dict(torch.load(opt.saved_model, map_location=device))
+     logger.log('Loaded pretrained model from %s' % opt.saved_model)
+
+     """ setup loss """
+     if 'CTC' in opt.Prediction:
+         criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
+     else:
+         criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore the [GO] token (index 0)
+
+     """ evaluation """
+     model.eval()
+     with torch.no_grad():
+         AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)  # , keep_ratio_with_pad=opt.PAD)
+         eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt, rand_aug=False)
+         logger.log(eval_data_log)
+         evaluation_loader = torch.utils.data.DataLoader(
+             eval_data, batch_size=opt.batch_size,
+             shuffle=False,
+             num_workers=int(opt.workers),
+             collate_fn=AlignCollate_evaluation, pin_memory=True)
+         _, accuracy, norm_ED, eval_arr = validation(model, criterion, evaluation_loader, converter, opt, device)
+         logger.log("=" * 20)
+         logger.log(f'Accuracy : {accuracy:0.4f}\n')
+         logger.log(f'Norm_ED : {norm_ED:0.4f}\n')
+         logger.log("=" * 20)
+
+         if opt.visualize:
+             logger.log("Threshold - ", opt.threshold)
+             logger.log("ED", "\t", "gt", "\t", "pred")
+             arr = []
+             for gt, pred, ED in eval_arr:
+                 ED = ED * 100.0
+                 arr.append(ED)
+                 if ED <= opt.threshold:
+                     logger.log(ED, "\t", gt, "\t", pred)
+             plt.hist(arr, edgecolor="red")
+             plt.savefig('test_outputs/' + str(datetime_now) + ".png")
+             plt.close()
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--visualize', action='store_true', help='visualize bad samples')
+     parser.add_argument('--threshold', type=float, help='save samples below this threshold in the txt file', default=50.0)
+     parser.add_argument('--eval_data', required=True, help='path to the evaluation dataset')
+     parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
+     parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
+     parser.add_argument('--saved_model', required=True, help='path to the saved model to evaluate')
+     """ Data processing """
+     parser.add_argument('--batch_max_length', type=int, default=100, help='maximum label length')
+     parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
+     parser.add_argument('--imgW', type=int, default=400, help='the width of the input image')
+     parser.add_argument('--rgb', action='store_true', help='use rgb input')
+     """ Model Architecture """
+     parser.add_argument('--FeatureExtraction', type=str, default="HRNet",
+                         help='FeatureExtraction stage VGG|RCNN|ResNet|UNet|HRNet|Densenet|InceptionUnet|ResUnet|AttnUNet')
+     parser.add_argument('--SequenceModeling', type=str, default="DBiLSTM",
+                         help='SequenceModeling stage LSTM|GRU|MDLSTM|BiLSTM|DBiLSTM')
+     parser.add_argument('--Prediction', type=str, default="CTC",
+                         help='Prediction stage CTC|Attn')
+     parser.add_argument('--input_channel', type=int, default=1, help='the number of input channels of the feature extractor')
+     parser.add_argument('--output_channel', type=int, default=512, help='the number of output channels of the feature extractor')
+     parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
+     """ GPU Selection """
+     parser.add_argument('--device_id', type=str, default=None, help='cuda device ID')
+
+     opt = parser.parse_args()
+     if opt.FeatureExtraction == "HRNet":
+         opt.output_channel = 32
+
+     # Fix random seeds for both numpy and pytorch
+     seed = 1111
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     """ vocab / character number configuration """
+     with open("UrduGlyphs.txt", "r", encoding="utf-8") as file:
+         content = file.readlines()
+     content = ''.join([str(elem).strip('\n') for elem in content])
+     opt.character = content + " "
+
+     cuda_str = 'cuda'
+     if opt.device_id is not None:
+         cuda_str = f'cuda:{opt.device_id}'
+     device = torch.device(cuda_str if torch.cuda.is_available() else 'cpu')
+     print("Device : ", device)
+
+     test(opt, device)
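custom_test.py is a command-line entry point: --eval_data and --saved_model are required, and the vocabulary is read from UrduGlyphs.txt in the working directory. A representative invocation (the paths are placeholders, not files in this commit):

    python custom_test.py --eval_data <path/to/lmdb_eval_data> --saved_model <path/to/checkpoint.pth> --visualize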
code/recognization/data/1.png ADDED
code/recognization/data/10.png ADDED
code/recognization/data/11.png ADDED
code/recognization/data/12.png ADDED
code/recognization/data/13.png ADDED
code/recognization/data/14.png ADDED
code/recognization/data/15.png ADDED
code/recognization/data/16.png ADDED
code/recognization/data/2.png ADDED
code/recognization/data/2_1.png ADDED
code/recognization/data/2_2.png ADDED
code/recognization/data/3.png ADDED
code/recognization/data/4.png ADDED
code/recognization/data/5.png ADDED
code/recognization/data/6.png ADDED