# Text detection + recognition inference script (source: TungDuong, commit 06142a4)
import os
import sys
import json
import pickle
import argparse

import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import ultralytics
from torchvision import transforms
from torchvision.models import resnet101
from ultralytics import YOLO

sys.path.append(os.getcwd())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from src.Text_Recognization.text_recognization import *
from src.Text_Recognization.prepare_dataset import *
# config
def load_json_config(config_path):
    """Read a JSON configuration file and return its contents as a dict."""
    with open(config_path, "r") as cfg_file:
        return json.load(cfg_file)
# Project hyper-parameters (CRNN hidden size, layer count, ...).
config = load_json_config('src/config.json')
# Vocabulary mapping: character -> index (pickled dict).
with open('src/encode.pkl', "rb") as f:
    char_to_idx = pickle.load(f)
# Inverse mapping: index -> character, used when decoding model output.
with open('src/decode.pkl', "rb") as f:
    idx_to_char = pickle.load(f)
# Text detection model (fine-tuned YOLOv11 weights).
text_det_model_path = 'checkpoints/yolov11m.pt'
yolo = YOLO(text_det_model_path)
# Text recognition model weights (CRNN with extended vocabulary).
text_rec_model_path = 'checkpoints/crnn_extend_vocab.pt'
# NOTE(review): vocab_size=74 is hard-coded — presumably equals len(char_to_idx); verify against the pickles.
rcnn_model = CRNN(vocab_size=74, hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
# weights_only=True restricts unpickling to tensors; load on CPU, the model is moved to `device` at inference time.
rcnn_model.load_state_dict(torch.load(text_rec_model_path, weights_only=True, map_location=torch.device('cpu')))
def text_detection(img_path, text_det_model):
    """Run the detector on one image and unpack its first result.

    Args:
        img_path: image path (or array) accepted by the detector's __call__.
        text_det_model: a YOLO-style model; called with verbose=False.

    Returns:
        Tuple (bboxes, classes, names, confs) where bboxes are xyxy lists,
        classes/confs are plain Python lists, and names is the model's
        class-name mapping.
    """
    result = text_det_model(img_path, verbose=False)[0]
    boxes = result.boxes
    return (
        boxes.xyxy.tolist(),
        boxes.cls.tolist(),
        result.names,
        boxes.conf.tolist(),
    )
def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    """Draw ground-truth boxes (absolute xyxy pixel coords) on an image and display it."""
    img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    for box in gt_location_yolo:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
def text_recognization(image, data_transforms, text_reg_model, idx_to_char=idx_to_char, device=device, char_to_idx=char_to_idx):
    """Recognize the text contained in a single cropped word image.

    Args:
        image: cropped word image (RGB array) as produced by the detector.
        data_transforms: torchvision transform pipeline applied before inference.
        text_reg_model: the CRNN recognizer.
        idx_to_char: index -> character mapping used by `decode`.
        device: torch device to run inference on.
        char_to_idx: character -> index mapping used by `decode`.
            Added as an explicit trailing parameter (previously read from the
            module global) so the dependency is visible; default preserves
            existing behavior.

    Returns:
        (text, idx): the decoded string and the raw per-timestep class indices.
    """
    transformed_image = data_transforms(image).unsqueeze(0).to(device)
    text_reg_model.to(device)
    text_reg_model.eval()
    with torch.no_grad():
        preds = text_reg_model(transformed_image)
    # Greedy decoding: take the most likely class at every timestep.
    _, idx = torch.max(preds, dim=2)
    idx = idx.view(-1)
    text = decode(idx, idx_to_char, char_to_idx)
    return text, idx
def visualize_detection(image, detections):
    """Overlay predicted boxes, confidences and recognized text on an image.

    Args:
        image: RGB image array to annotate (modified via cv2 drawing calls).
        detections: iterable of (bbox, cls, conf, text, encoded_text) tuples.

    Returns:
        The annotated image, after also displaying it with matplotlib.
    """
    plt.figure(figsize=(10, 8))
    for bbox, detected_classes, conf, text, _ in detections:
        x1, y1, x2, y2 = (int(v) for v in bbox)
        image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
        label = f"{conf:.2f} {text}"
        image = cv2.putText(image, label, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    plt.imshow(image)
    plt.axis('off')
    plt.show()
    return image
# Preprocessing for word crops fed to the CRNN: HWC uint8 -> float tensor,
# resized to a fixed (100, 400) input, then ImageNet mean/std normalization.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((100, 400)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def prediction(image, text_det_model=yolo, text_reg_model=rcnn_model, idx_to_char=idx_to_char, char_to_idx=char_to_idx, data_transforms=data_transforms, device=device):
    """Full pipeline: detect text regions in an image, then recognize each crop.

    Args:
        image: RGB image array.
        Remaining parameters default to the module-level models/mappings.

    Returns:
        List of (bbox, cls, conf, text, encoded_text) tuples, one per detection.
    """
    bboxes, classes, names, confs = text_detection(image, text_det_model)
    results = []
    for box, cls, conf in zip(bboxes, classes, confs):
        x1, y1, x2, y2 = (int(v) for v in box)
        # Crop the detected word region and run the recognizer on it.
        word_crop = image[y1:y2, x1:x2]
        text, encoded_text = text_recognization(word_crop, data_transforms, text_reg_model, idx_to_char, device)
        results.append((box, cls, conf, text, encoded_text))
        print(box, cls, conf, text)
    return results
def main():
    """CLI entry point: detect and recognize text in an image, optionally saving the result.

    Raises:
        FileNotFoundError: if the image cannot be read from --image_path.
    """
    parser = argparse.ArgumentParser()
    # required=True: without it a missing path silently became None and
    # produced a cryptic cv2 error downstream.
    parser.add_argument('--image_path', type=str, required=True, help='Path to the image')
    parser.add_argument('--save_path', type=str, default=None, help='Path to save the image')
    args = parser.parse_args()
    image = cv2.imread(args.image_path)
    if image is None:
        # cv2.imread returns None (no exception) for missing/unreadable files;
        # fail early with a clear message instead of crashing in cvtColor.
        raise FileNotFoundError(f"Could not read image: {args.image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    detections = prediction(image)
    image = visualize_detection(image, detections)
    if args.save_path:
        print(f"Saving the image to {os.path.join(args.save_path, 'predicted_image.jpg')}")
        # Convert back to BGR because cv2.imwrite expects BGR channel order.
        cv2.imwrite(os.path.join(args.save_path, 'predicted_image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

if __name__ == '__main__':
    main()