ocr_plate_indonesia / src /apps /char_detection.py
Alimustoofaa's picture
first commit
e43f2e6
'''
@Author : Ali Mustofa HALOTEC
@Module : Character Detection Faster RCNN
@Created on : 19 Jul 2022
'''
#!/usr/bin/env python3
# Path: src/apps/char_detection.py
import os
import cv2
import numpy as np
from PIL import Image
from src.utils.utils import download_and_unzip_model
import torch
import torchvision
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
class CharDetection:
def __init__(self, root_path:str, model_config:dict) -> None:
'''
Load model
@params:
- root_path:str -> root of path model
- model_config:dict -> config of model {filename, classes, url, file_size}
'''
self.root_path = root_path
self.model_config = model_config
self.model_name = f'{root_path}/{model_config["filename"]}'
self.classes = model_config['classes']
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model = self.__load_model()
@staticmethod
def __check_model(root_path:str, model_config:dict) -> None:
if not os.path.isfile(f'{root_path}/{model_config["filename"]}'):
download_and_unzip_model(
root_dir = root_path,
name = model_config['filename'],
url = model_config['url'],
file_size = model_config['file_size'],
unzip = False
)
else: print('Load model char detection')
@staticmethod
def __image_transform(image) -> torch.Tensor:
return transforms.Compose([transforms.ToTensor()])(image)
def __load_model(self) -> torch.nn.Module:
self.__check_model(self.root_path, self.model_config)
model = self.__fasterrcnn_resnet50_fpn()
model.load_state_dict(torch.load(self.model_name, map_location=self.device), False)
model.to(self.device)
return model.eval()
def __fasterrcnn_resnet50_fpn(self)-> torch.nn.Module:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(self.classes)+1)
return model
@staticmethod
def __filter_threshold(probs:dict, threshold:float) -> dict:
num_filtered = (probs['scores']>threshold).float()
keep = (num_filtered == torch.tensor(1)).nonzero().flatten()
final_probs = probs
final_probs['boxes'] = final_probs['boxes'][keep]
final_probs['scores'] = final_probs['scores'][keep]
final_probs['labels'] = final_probs['labels'][keep]
return final_probs
@staticmethod
def __original_boxes(boxes:torch.Tensor, img_size:tuple,resized:int) -> torch.Tensor:
image_width, image_height = img_size[1], img_size[0]
boxes = torch.tensor([[
(x_min/resized)*image_width, (y_min/resized)*image_height, \
(x_max/resized)*image_width, (y_max/resized)*image_height] \
for (x_min, y_min, x_max, y_max) in boxes.cpu().numpy()])
return boxes
@staticmethod
def __sort_by_boxes(probs:dict) -> dict:
x_min_list = [i[0] for i in probs['boxes']]
idx = [x_min_list.index(x) for x in sorted(x_min_list)]
probs['boxes'] = probs['boxes'][idx]
probs['scores'] = probs['scores'][idx]
probs['labels'] = probs['labels'][idx]
return probs
def detect(self, image:np.array, size:int = None,
boxes_ori:bool = False, threshold:float = 0.5, sorted:bool = True) -> dict:
'''
@params:
- image: numpy array of image
- size: int of image resize
- boxes_ori: bool of original boxes
- threshold: float of threshold
- sorted: bool of sorted by boxes
@return:
probs: dict of probs -> {
'boxes' : [x_min, y_min, x_max, y_max],
'scores': [float],
'labels': [int]
}
'''
im_shape = (image.shape[0], image.shape[1])
image = cv2.resize(image, (size,size)) if size else image
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
image = self.__image_transform(image)
with torch.no_grad():
probs = self.model([image])[0]
probs = self.__filter_threshold(probs, threshold)
if boxes_ori and size:
probs['boxes'] = self.__original_boxes(probs['boxes'],im_shape, size)
if sorted:
probs = self.__sort_by_boxes(probs)
return {k: v.cpu().numpy() for k, v in probs.items()}
if __name__ == '__main__':
char_detection = CharDetection('./models/text_detection.ali', ['text'])
image = cv2.imread('./images/1.jpg')
results = char_detection.detect(image, size=244, boxes_ori=True, threshold=0.01)
print(results)