"""Invoice field detection and OCR.

Detects invoice fields with a YOLO model (7 classes) and reads each
detected region with TrOCR, returning a {class_name: text} mapping.
"""
import torch
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from ultralytics import YOLO
from Model_loading import processor_tr_ocr, trocr_model
import os
# Resolve the weights file relative to this script so the module works
# regardless of the process's current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(BASE_DIR, "Models", "invoice_yolo_100_7classes.pt")
# Load the YOLO detector once at import time; reused by info_det_and_ocr().
info_model = YOLO(model_path)
# Function for performing OCR with TrOCR
def ocr_with_transformer(image):
    """Run TrOCR on a cropped image region and return the recognized text.

    Args:
        image: numpy array in OpenCV BGR channel order (H, W, 3).

    Returns:
        str: the decoded text for the region.
    """
    # TrOCR expects RGB input; OpenCV images arrive as BGR.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    pixel_values = processor_tr_ocr(images=image_rgb, return_tensors="pt").pixel_values
    # Inference only: disable autograd to avoid building a gradient graph,
    # which saves memory and time during generate().
    with torch.no_grad():
        generated_ids = trocr_model.generate(pixel_values)
    # Decode the generated token ids back to a plain string.
    generated_text = processor_tr_ocr.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
# Main function for detection, cropping, and OCR
def info_det_and_ocr(image_arr, conf_threshold=0.5):
    """Detect invoice fields with YOLO and OCR each crop with TrOCR.

    Args:
        image_arr: full invoice image as a numpy array (OpenCV BGR order).
        conf_threshold: minimum detection confidence to keep (default 0.5).

    Returns:
        dict: {class_name: recognized_text} — when a class is detected more
        than once, the highest-confidence detection wins (previously the
        last detection silently overwrote earlier ones).
    """
    detected_data = {}
    best_scores = {}  # per-class best confidence seen so far

    # Perform inference using YOLO
    results = info_model.predict(image_arr)
    boxes = results[0].boxes.xyxy.cpu().numpy()          # Bounding boxes
    scores = results[0].boxes.conf.cpu().numpy()         # Confidence scores
    class_ids = results[0].boxes.cls.cpu().numpy().astype(int)  # Class IDs
    class_names = results[0].names                       # Class names dictionary

    img_h, img_w = image_arr.shape[:2]
    for box, score, class_id in zip(boxes, scores, class_ids):
        if score < conf_threshold:
            continue  # Skip low confidence detections
        class_name = class_names[class_id]
        # Only OCR this detection if it beats any earlier one for the class.
        if score <= best_scores.get(class_name, -1.0):
            continue
        x1, y1, x2, y2 = map(int, box)
        # Clamp to image bounds — YOLO boxes can slightly exceed the frame.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_w, x2), min(img_h, y2)
        if x2 <= x1 or y2 <= y1:
            continue  # Degenerate box: empty crop would crash cv2.cvtColor
        cropped_region = image_arr[y1:y2, x1:x2]
        # Perform OCR using TrOCR on the cropped region
        detected_data[class_name] = ocr_with_transformer(cropped_region)
        best_scores[class_name] = score
        # Optional: Display cropped image for debugging
        #plt.imshow(cv2.cvtColor(cropped_region, cv2.COLOR_BGR2RGB))
        #plt.title(f"Detected: {class_name}")
        #plt.show()
    return detected_data
# You can now call this function by passing the image as a numpy array.
# Example:
# extracted_data = info_det_and_ocr(image_arr)
# print(extracted_data)