Spaces:
Runtime error
Runtime error
File size: 4,821 Bytes
06142a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import sys
import json
import cv2
import argparse
import matplotlib.pyplot as plt
import ultralytics
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet101
from torchvision import transforms
sys.path.append(os.getcwd())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from ultralytics import YOLO
from src.Text_Recognization.text_recognization import *
from src.Text_Recognization.prepare_dataset import *
# config loading
def load_json_config(config_path):
    """Load a JSON configuration file and return the parsed object.

    Args:
        config_path: Path to the JSON file on disk.

    Returns:
        The parsed configuration (typically a dict).
    """
    # Explicit encoding avoids platform-dependent default codecs.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
import pickle  # FIX: `pickle` is used below but was never imported explicitly
               # (it was only reachable through the `import *` lines above).

# Hyper-parameters for the models (CRNN hidden size / depth, etc.).
config = load_json_config('src/config.json')

# Vocabulary mappings used by the CRNN greedy decoder.
# NOTE(review): these are trusted local artifacts; pickle must never be
# used on untrusted input.
# char -> index
with open('src/encode.pkl', "rb") as f:
    char_to_idx = pickle.load(f)
# index -> char
with open('src/decode.pkl', "rb") as f:
    idx_to_char = pickle.load(f)

# Text detection model (YOLO).
text_det_model_path = 'checkpoints/yolov11m.pt'
yolo = YOLO(text_det_model_path)

# Text recognition model (CRNN); weights are loaded onto CPU here and the
# model is moved to `device` inside text_recognization().
text_rec_model_path = 'checkpoints/crnn_extend_vocab.pt'
# NOTE(review): vocab_size=74 is hard-coded — presumably len(char_to_idx); confirm.
rcnn_model = CRNN(vocab_size=74, hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
rcnn_model.load_state_dict(torch.load(text_rec_model_path, weights_only=True, map_location=torch.device('cpu')))
def text_detection(img_path, text_det_model):
    """Run the detector on a single image and unpack its first result.

    Args:
        img_path: Image path (or array) accepted by the detection model.
        text_det_model: Callable detector (e.g. an ultralytics YOLO model).

    Returns:
        Tuple (bboxes, classes, names, confs): xyxy pixel boxes, per-box
        class ids, the id -> label mapping, and per-box confidences.
    """
    result = text_det_model(img_path, verbose=False)[0]
    boxes = result.boxes
    return (
        boxes.xyxy.tolist(),
        boxes.cls.tolist(),
        result.names,
        boxes.conf.tolist(),
    )
def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    """Draw ground-truth boxes on an image and display it with matplotlib.

    Args:
        image_path: Path of the image to load.
        gt_location_yolo: Iterable of (xmin, ymin, xmax, ymax) boxes in
            absolute pixel coordinates.
    """
    # OpenCV loads BGR; convert to RGB for matplotlib display.
    img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    for box in gt_location_yolo:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
def text_recognization(image, data_transforms, text_reg_model, idx_to_char=idx_to_char, device=device):
    """Recognize the text in one cropped word image with the CRNN model.

    Args:
        image: Cropped RGB word image.
        data_transforms: Preprocessing pipeline producing a CHW tensor.
        text_reg_model: The CRNN recognizer.
        idx_to_char: Index -> character mapping used by decode().
        device: Torch device to run inference on.

    Returns:
        (text, idx): the decoded string and the flattened argmax index tensor.
    """
    batch = data_transforms(image).unsqueeze(0).to(device)
    model = text_reg_model.to(device)
    model.eval()
    with torch.no_grad():
        logits = model(batch)
    # Greedy decoding: most likely class at every time step.
    _, idx = torch.max(logits, dim=2)
    idx = idx.view(-1)
    # NOTE: decode() also receives the module-level char_to_idx mapping.
    text = decode(idx, idx_to_char, char_to_idx)
    return text, idx
def visualize_detection(image, detections):
    """Overlay detection boxes and recognized text on an image, then show it.

    Args:
        image: RGB image array to annotate.
        detections: Iterable of (bbox, cls, conf, text, encoded_text)
            tuples, as produced by prediction().

    Returns:
        The annotated image array.
    """
    plt.figure(figsize=(10, 8))
    for box, _cls, conf, text, _ in detections:
        x1, y1, x2, y2 = map(int, box)
        image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
        # Label each box with its confidence and the recognized text.
        image = cv2.putText(image, f"{conf:.2f} {text}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    plt.imshow(image)
    plt.axis('off')
    plt.show()
    return image
# Preprocessing for recognition crops: convert to tensor, resize to the
# CRNN's fixed input size (H=100, W=400), then normalize with the standard
# ImageNet statistics.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((100, 400)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def prediction(image, text_det_model=yolo, text_reg_model=rcnn_model, idx_to_char=idx_to_char, char_to_idx=char_to_idx, data_transforms=data_transforms, device=device):
    """Detect text regions in an image and recognize the text inside each.

    Args:
        image: RGB image array.
        text_det_model: Detector (YOLO) used to localize text regions.
        text_reg_model: Recognizer (CRNN) applied to each cropped region.
        idx_to_char: Index -> character mapping for decoding.
        char_to_idx: Character -> index mapping (accepted for symmetry;
            not used directly here).
        data_transforms: Preprocessing pipeline for recognition crops.
        device: Torch device for the recognizer.

    Returns:
        List of (bbox, cls, conf, text, encoded_text) tuples, one per
        detected region. bbox keeps the detector's float coordinates.
    """
    # Detection pass; `names` is returned by the detector but unused here.
    boxes, classes, names, confidences = text_detection(image, text_det_model)
    results = []
    for box, cls, conf in zip(boxes, classes, confidences):
        x1, y1, x2, y2 = map(int, box)
        crop = image[y1:y2, x1:x2]
        text, encoded = text_recognization(crop, data_transforms, text_reg_model, idx_to_char, device)
        results.append((box, cls, conf, text, encoded))
        print(box, cls, conf, text)
    return results
def main():
    """CLI entry point: detect and recognize text in one image, optionally save the result.

    Raises:
        FileNotFoundError: If the image cannot be read from --image_path.
    """
    parser = argparse.ArgumentParser()
    # FIX: image_path was optional, so a missing flag crashed deep inside
    # cv2.imread(None); make it required for a clear argparse error instead.
    parser.add_argument('--image_path', type=str, required=True, help='Path to the image')
    parser.add_argument('--save_path', type=str, default=None, help='Path to save the image')
    args = parser.parse_args()
    image = cv2.imread(args.image_path)
    # cv2.imread returns None (it does not raise) for missing or unreadable
    # files; fail fast with a clear message instead of a cryptic cvtColor error.
    if image is None:
        raise FileNotFoundError(f"Could not read image: {args.image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    detections = prediction(image)
    image = visualize_detection(image, detections)
    if args.save_path:
        out_path = os.path.join(args.save_path, 'predicted_image.jpg')
        print(f"Saving the image to {out_path}")
        # Convert back to BGR because cv2.imwrite expects BGR channel order.
        cv2.imwrite(out_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

if __name__ == '__main__':
    main()
|