File size: 2,541 Bytes
36255e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from ultralytics import YOLO
from Model_loading import processor_tr_ocr, trocr_model

import os

# Define a path relative to the script's location
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(BASE_DIR, "Models", "invoice_yolo_100_7classes.pt")


# Load the YOLO model

info_model = YOLO(model_path)



# Function for performing OCR with TrOCR
def ocr_with_transformer(image):
    # Convert the image to RGB and process it for TrOCR
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    pixel_values = processor_tr_ocr(images=image_rgb, return_tensors="pt").pixel_values
    generated_ids = trocr_model.generate(pixel_values)
    
    # Decode the generated text from the model
    generated_text = processor_tr_ocr.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

# Main function for detection, cropping, and OCR
def info_det_and_ocr(image_arr, conf_threshold=0.5):
    # Initialize the dictionary to store detected class names and text
    detected_data = {}

    # Perform inference using YOLO
    results = info_model.predict(image_arr)
    
    boxes = results[0].boxes.xyxy.cpu().numpy()  # Bounding boxes
    scores = results[0].boxes.conf.cpu().numpy()  # Confidence scores
    class_ids = results[0].boxes.cls.cpu().numpy().astype(int)  # Class IDs
    class_names = results[0].names  # Class names dictionary
    
    # Iterate over the detected objects
    for i, box in enumerate(boxes):
        score = scores[i]
        if score < conf_threshold:
            continue  # Skip low confidence detections

        x1, y1, x2, y2 = map(int, box)  # Convert bounding box to integers
        class_id = class_ids[i]
        class_name = class_names[class_id]

        # Crop the detected region
        cropped_region = image_arr[y1:y2, x1:x2]

        # Perform OCR using TrOCR on the cropped region
        detected_text = ocr_with_transformer(cropped_region)

        # Save the detected text into the dictionary with class name as key
        detected_data[class_name] = detected_text

        # Optional: Display cropped image for debugging
        #plt.imshow(cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
        #plt.title(f"Detected: {class_name}")
        #plt.show()

    return detected_data

# You can now call this function by passing the image as a numpy array.
# Example:
# extracted_data = info_det_and_ocr(image_arr)
# print(extracted_data)