# Scene-text inference script: YOLO text detection + CRNN text recognition.
import argparse
import json
import os
import pickle
import sys

import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import ultralytics
from torchvision import transforms
from torchvision.models import resnet101

sys.path.append(os.getcwd())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from ultralytics import YOLO
from src.Text_Recognization.text_recognization import *
from src.Text_Recognization.prepare_dataset import *
| # config | |
def load_json_config(config_path):
    """Load a JSON configuration file.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The parsed configuration as a dict.
    """
    # Explicit encoding avoids platform-dependent default codecs.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return config
# Runtime configuration shared by the models below (e.g. CRNN hyper-parameters).
config = load_json_config('src/config.json')
# char to idx: character -> index vocabulary used when encoding label text.
# NOTE(review): pickle.load assumes these vocab files are trusted local artifacts.
with open('src/encode.pkl', "rb") as f:
    char_to_idx = pickle.load(f)
# idx to char: index -> character vocabulary used when decoding CRNN output.
with open('src/decode.pkl', "rb") as f:
    idx_to_char = pickle.load(f)
# text detection model (YOLO checkpoint).
text_det_model_path = 'checkpoints/yolov11m.pt'
yolo = YOLO(text_det_model_path)
# text recognition model (CRNN checkpoint).
text_rec_model_path = 'checkpoints/crnn_extend_vocab.pt'
# rcnn model: weights are loaded onto CPU here; the model is moved to
# `device` lazily at inference time (see text_recognization).
rcnn_model = CRNN(vocab_size=74, hidden_size=config['CRNN']['hidden_size'], n_layers=config['CRNN']['n_layers'])
rcnn_model.load_state_dict(torch.load(text_rec_model_path, weights_only=True, map_location=torch.device('cpu')))
def text_detection(img_path, text_det_model):
    """Run the YOLO detector on one image and unpack its first result.

    Args:
        img_path: Image path (or array) accepted by the YOLO model.
        text_det_model: Callable YOLO model.

    Returns:
        (bboxes, classes, names, confs): xyxy boxes, class ids, the model's
        class-name mapping, and per-box confidences.
    """
    result = text_det_model(img_path, verbose=False)[0]
    detected = result.boxes
    return (
        detected.xyxy.tolist(),
        detected.cls.tolist(),
        result.names,
        detected.conf.tolist(),
    )
def visualize_gt_bboxes_yolo(image_path, gt_location_yolo):
    """Draw ground-truth boxes on an image and display it with matplotlib.

    Args:
        image_path: Path to the image on disk.
        gt_location_yolo: Iterable of (xmin, ymin, xmax, ymax) pixel boxes.
    """
    img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    for box in gt_location_yolo:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        img = cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
    plt.imshow(img)
    plt.axis('off')
    plt.show()
def text_recognization(image, data_transforms, text_reg_model, idx_to_char=idx_to_char, device=device):
    """Recognize the text in one cropped region with the CRNN model.

    Args:
        image: Cropped text-region image (array-like accepted by the transforms).
        data_transforms: Preprocessing pipeline applied before inference.
        text_reg_model: Trained CRNN recognizer.
        idx_to_char: Index -> character vocabulary for decoding.
        device: Torch device on which to run inference.

    Returns:
        (text, idx): the decoded string and the raw per-timestep index tensor.
    """
    batch = data_transforms(image).unsqueeze(0).to(device)
    text_reg_model.to(device)
    text_reg_model.eval()
    with torch.no_grad():
        logits = text_reg_model(batch)
        # Greedy decoding: argmax over the class dimension, flattened to 1-D.
        idx = torch.max(logits, dim=2)[1].view(-1)
        # NOTE(review): `decode` and `char_to_idx` come from the star imports above.
        text = decode(idx, idx_to_char, char_to_idx)
    return text, idx
def visualize_detection(image, detections):
    """Annotate an RGB image with predicted boxes and recognized text, then show it.

    Args:
        image: RGB image array to draw on (modified and returned).
        detections: Iterable of (bbox, cls, conf, text, encoded_text) tuples.

    Returns:
        The annotated image array.
    """
    plt.figure(figsize=(10, 8))
    for box, _cls, conf, text, _ in detections:
        x1, y1, x2, y2 = (int(coord) for coord in box)
        image = cv2.rectangle(image, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=2)
        label = f"{conf:.2f} {text}"
        # Label is placed just above the box's top-left corner.
        image = cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    plt.imshow(image)
    plt.axis('off')
    plt.show()
    return image
# Preprocessing for CRNN input: tensor conversion, resize to 100x400 (H x W),
# then ImageNet mean/std normalization.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((100, 400)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def prediction(image, text_det_model=yolo, text_reg_model=rcnn_model, idx_to_char=idx_to_char, char_to_idx=char_to_idx, data_transforms=data_transforms, device=device):
    """Full pipeline: detect text regions in an image, then recognize each crop.

    Args:
        image: RGB image array.
        text_det_model: YOLO detector (defaults to the module-level model).
        text_reg_model: CRNN recognizer (defaults to the module-level model).
        idx_to_char / char_to_idx: Decoding vocabularies.
        data_transforms: Preprocessing for the recognizer.
        device: Torch device for recognition.

    Returns:
        List of (bbox, cls, conf, text, encoded_text) tuples.
    """
    # detection
    bboxes, classes, names, confs = text_detection(image, text_det_model)
    predictions = []
    for bbox, cls, conf in zip(bboxes, classes, confs):
        x1, y1, x2, y2 = (int(coord) for coord in bbox)
        # Crop the detected region (rows index y, columns index x).
        crop = image[y1:y2, x1:x2]
        text, encoded_text = text_recognization(crop, data_transforms, text_reg_model, idx_to_char, device)
        predictions.append((bbox, cls, conf, text, encoded_text))
        print(bbox, cls, conf, text)
    return predictions
def main():
    """CLI entry point: run detection + recognition on one image, optionally save it."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, help='Path to the image')
    parser.add_argument('--save_path', type=str, default=None, help='Path to save the image')
    args = parser.parse_args()
    # OpenCV loads BGR; the pipeline and matplotlib work in RGB.
    rgb = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB)
    detections = prediction(rgb)
    annotated = visualize_detection(rgb, detections)
    if args.save_path:
        out_path = os.path.join(args.save_path, 'predicted_image.jpg')
        print(f"Saving the image to {out_path}")
        # Convert back to BGR for cv2.imwrite.
        cv2.imwrite(out_path, cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))


if __name__ == '__main__':
    main()