# ocr_paligemma / utils_ocr.py
# Author: AnkitShrestha — commit 5fcefd5 ("Add all data to db")
from pdf2image import convert_from_path
import numpy as np
import cv2
from PIL import Image
import json
import sqlite3
from datetime import datetime
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from dotenv import load_dotenv
import os
from huggingface_hub import login
import torch
# from main import predict as predict_main
# # # Load environment variables
# # load_dotenv()
# # # Set the cache directory to a writable path
# # os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torch_inductor_cache"
# # token = os.getenv("huggingface_ankit")
# # # Login to the Hugging Face Hub
# # login(token)
# with open("ocr/VGG Image Annotator_files/mach_labeler.json", "r") as f:
# data = json.load(f)
# def center_pad_image(image, target_size=448):
# # Get original dimensions
# original_h, original_w = image.shape[:2]
# # If image is larger, resize while maintaining aspect ratio
# if original_h > target_size or original_w > target_size:
# scale = target_size / max(original_h, original_w)
# new_h, new_w = int(original_h * scale), int(original_w * scale)
# image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
# else:
# new_h, new_w = original_h, original_w
# # Calculate padding
# pad_h = (target_size - new_h) // 2
# pad_w = (target_size - new_w) // 2
# # Create black background
# new_image = np.ones((target_size, target_size, 3), dtype=np.uint8) * 255
# # Place the resized image at the center
# new_image[pad_h:pad_h+new_h, pad_w:pad_w+new_w] = image
# return new_image
# def predict_check(cropped_image, threshold=0.0870):
# gray_image = np.mean(cropped_image, axis=2) # Convert to grayscale
# # remove noise
# gray_image = cv2.GaussianBlur(gray_image, (5, 5), 0)
# pixel_density = np.count_nonzero(gray_image < 128) / gray_image.size # Count dark pixels
# # print("pixel threshold ",pixel_density)
# if pixel_density > threshold:
# return "Ticked"
# else:
# return "NotTicked"
# def make_batch(ocr_regions, batch_size = 6):
# for i in range(0, len(ocr_regions), batch_size):
# yield ocr_regions[i:i + batch_size] # Yield a batch of size `batch_size`
# import requests
# def save_images(images,save_dir):
# os.makedirs(save_dir, exist_ok=True) # Ensure directory exists
# saved_paths = []
# for i, img in enumerate(images):
# file_path = os.path.join(save_dir, f"image_{i}.png") # Save as PNG
# img.save(file_path)
# saved_paths.append(file_path) # Store the file path
# return saved_paths
# import shutil
# def delete_saved_images(save_dir):
# if os.path.exists(save_dir):
# shutil.rmtree(save_dir) # Deletes the entire folder and its contents
# print(f"Deleted all images in {save_dir}")
# else:
# print(f"Directory {save_dir} does not exist")
# def batch_predict_ext(image_batch,save_path):
# file_paths = save_images(image_batch,save_path)
# # files = [("files", (img, open(img, "rb"), "image/jpeg" if img.endswith(".jpg") else "image/png")) for img in file_paths]
# files = []
# for img in file_paths:
# with open(img, "rb") as f:
# file_content = f.read() # Read the file into memory
# file_type = "image/jpeg" if img.endswith(".jpg") else "image/png"
# files.append(("files", (img, file_content, file_type))) # Pass the file content
# url = "https://aioverlords-amnil-ocr-test-pali.hf.space/batch_extract_text"
# headers = {"accept": "application/json"}
# response = requests.post(url, files=files, headers=headers)
# delete_saved_images(save_path)
# if response.status_code == 200:
# return response.json() # Returns extracted text as JSON
# else:
# return {"error": f"Request failed with status code {response.status_code}"}
# import uuid
# def batch_ocr_ext(file_name,task_id,batch_size):
# try:
# with open("ocr/VGG Image Annotator_files/mach_labeler.json", "r") as f:
# data = json.load(f)
# start_time = datetime.now()
# check_regions = []
# ocr_regions = []
# blank_regions=[]
# # final = []
# j = 0
# for k,v in data['_via_img_metadata'].items():
# # k is the pages in the form
# print(k)
# # regions is the list of regions in a single page.
# # it is a list of dictionary with each dictionary having shape_attributes and region_attributes
# regions = data['_via_img_metadata'][k]['regions']
# file = file_name
# # Check if the file is pdf
# if file.endswith("pdf"):
# # Extracts the j-th page from a PDF as an image.
# # .convert("L") converts the image to grayscale
# # then convert the image to numpy array to process it with opencv
# targ_img = np.array(convert_from_path(file)[j].convert("L"))
# else:
# targ_img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
# # Used for feature detection and image matching
# # Possible to optimize?
# MAX_NUM_FEATURES = 10000
# orb = cv2.ORB_create(MAX_NUM_FEATURES)
# # Load the blank form of j-th page
# orig_img = np.array(Image.open(f"ocr/VGG Image Annotator_files/mach_bank_form_page{j}.jpg").convert("L"))
# # Detects keypoints (corner-like features) in orig_img and targ_img.
# # and computes descriptors, which are binary feature representations for each keypoint.
# keypoints1, descriptors1 = orb.detectAndCompute(orig_img, None)
# keypoints2, descriptors2 = orb.detectAndCompute(targ_img, None)
# # ORB typically works on grayscale images.
# # Converts images back to BGR for displaying colored keypoints.
# # just for visualization or any other use-case?
# img1 = cv2.cvtColor(orig_img, cv2.COLOR_GRAY2BGR)
# img2 = cv2.cvtColor(targ_img, cv2.COLOR_GRAY2BGR)
# # Match features.
# # ORB uses binary descriptors, and Hamming distance counts the number of differing bits.
# # Faster than Euclidean distance for binary descriptors.
# matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING)
# # match() finds the best match for each descriptor in descriptor1 and descriptor2
# # matches stores a list of cv2.DMatch objects where
# # .queryIdx --> index of the keypoint in orig_img
# # .trainIdx --> index of matching keypoint in trag_img
# # .distance --> Hamming distance
# # Converting to list for sorting as tuples are immutable objects.
# matches = list(matcher.match(descriptors1, descriptors2, None))
# # Sort matches by score
# # Sorting the matches based on hamming distance
# matches.sort(key = lambda x: x.distance, reverse = False)
# # Remove not so good matches
# numGoodMatches = int(0.1*len(matches))
# matches = matches[:numGoodMatches]
# # matches = matches[:len(matches)//10]
# # Initialize arrays to store Keypoint locations
# # float32 used for compatibility with cv2.findHomography()
# points1 = np.zeros((len(matches), 2), dtype = np.float32)
# points2 = np.zeros((len(matches), 2), dtype = np.float32)
# # Extract location of good matches
# for i, match in enumerate(matches):
# points1[i, :] = keypoints1[match.queryIdx].pt
# points2[i, :] = keypoints2[match.trainIdx].pt
# # Find homography
# h, mask = cv2.findHomography(points2, points1, cv2.RANSAC)
# height, width, channels = img1.shape
# # Warp img2 to align with img1
# img2_reg = cv2.warpPerspective(img2, h, (width, height))
# region_data = []
# for region in regions:
# x, y, width, height = (
# region['shape_attributes']['x'],
# region['shape_attributes']['y'],
# region['shape_attributes']['width'],
# region['shape_attributes']['height']
# )
# name = (
# f"{region['region_attributes']['parent']}_"
# f"{region['region_attributes']['key']}_"
# f"{region['region_attributes'].get('group', '')}"
# )
# name_type = region['region_attributes']['type']
# region_data.append({
# "x": x,
# "y": y,
# "width": width,
# "height": height,
# "name": name,
# "type": name_type
# })
# # iterate through the region_data and crop the images portion and if type is check call predict_check function else call predict function
# for region in region_data:
# x, y, width, height = region["x"], region["y"], region["width"], region["height"]
# cropped_image = img2_reg[y:y+height, x:x+width] # Assuming 'image' is defined
# # plt.imshow(cropped_image, cmap='gray')
# # plt.axis("off")
# # plt.show()
# # IF Checkbox, then run checkbox function
# # else Check if the cropped image contains any significant edges suggesting there is text and send it to OCR
# # If no significant edges are found then not found is returned
# # if region["type"] == "check":
# # pred = predict_check(cropped_image,threshold=0.0850)
# # print(check_status)
# if region["type"] == "check":
# region["page"] = f"page_{j}"
# check_regions.append((region, cropped_image))
# else:
# cedge = cv2.Canny(cropped_image[7:-7, 7:-7], 100, 200)
# cex_ = cedge.astype(float).sum(axis=0)/255
# cey_ = cedge.astype(float).sum(axis=1)/255
# cex_ = np.count_nonzero(cex_>5)
# cey_ = np.count_nonzero(cey_>5)
# colr = (0,0,255)
# if cex_ > 7 and cey_ > 7:
# # Image.fromarray(im).convert('RGB')
# im = Image.fromarray(center_pad_image(cropped_image))
# region["page"] = f"page_{j}"
# ocr_regions.append((region, im))
# else:
# pred = "not found"
# region["status"] = pred
# region["page"] = f"page_{j}"
# blank_regions.append(region)
# # if len(check_regions) >= BATCH_SIZE:
# # batch_checkpoint(check_regions)
# # check_regions = []
# # if len(ocr_regions) >= BATCH_SIZE:
# # batch_ocr(ocr_regions)
# # ocr_regions =[]
# j += 1
# print("Check Regions Started")
# # return check_regions,ocr_regions,blank_regions
# check_region_data = []
# for check_region in check_regions:
# check_region[0]["status"] = predict_check(check_region[1])
# check_region_data.append(check_region[0])
# print("Check Regions End")
# print("OCR Regions Started")
# region_data = []
# count = 0
# for batch in make_batch(ocr_regions,batch_size):
# images = []
# for data in batch:
# images.append(data[1])
# print(f"-----Batch {count}------")
# save_path = f"{str(uuid.uuid4())}"
# response = batch_predict_ext(images,save_path)
# extracted_texts = response["extracted_texts"]
# print(f"-----Batch {count} Completed------")
# for text,region in zip(extracted_texts,batch):
# region[0]["status"] = text
# region_data.append(region[0])
# count = count + 1
# # Combine all region data
# region_data.extend(check_region_data)
# region_data.extend(blank_regions)
# string_data = json.dumps(region_data)
# print(type(string_data))
# # Store the time take for the process to complete
# end_time = datetime.now()
# time_elapsed = end_time-start_time
# time_elapsed_str = str(time_elapsed) # Convert seconds to string
# os.remove(file_name)
# # Update database
# conn = sqlite3.connect('/mnt/data//mnt/data/translations.db')
# cursor = conn.cursor()
# cursor.execute('UPDATE OCR SET region = ?, time_elapsed = ?, status=?, updated_at = ? WHERE task_id = ? ',
# (string_data,time_elapsed_str,"completed",datetime.now(),task_id))
# conn.commit()
# conn.close()
# print("SUCESSFUL")
# except Exception as e:
# print(f"OCR Failed : {e}")
# try:
# conn = sqlite3.connect('/mnt/data//mnt/data/translations.db')
# cursor = conn.cursor()
# cursor.execute('UPDATE OCR SET status = ? WHERE task_id = ?', ("failed", task_id))
# conn.commit()
# conn.close()
# except Exception as exec:
# print(f"Updating status to database failed: {exec}")
from io import BytesIO
import requests
import cv2
import numpy as np
from PIL import Image
import json
import sqlite3
from datetime import datetime
import uuid
import aiohttp
# Module-level side effect: the VIA (VGG Image Annotator) project file is
# loaded once at import time; every OCR entry point below reads the page and
# region layout from this global.
with open("ocr/VGG Image Annotator_files/mach_labeler.json", "r") as f:
    data = json.load(f)  # dict; page metadata lives under '_via_img_metadata'
async def check_health():
    """Ping the remote OCR service and report its availability.

    Returns:
        str: "healthy" when the /health endpoint answers HTTP 200,
        otherwise "unhealthy".
    """
    url = "https://aioverlords-amnil-ocr-test-pali.hf.space/health"
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            return "healthy" if resp.status == 200 else "unhealthy"
from PIL import Image, ImageDraw
def create_at_image(image_height):
    """Render the separator text "AAA" centred on a white background.

    The canvas width equals the font-size heuristic (80% of the height).
    Used as a visual divider strip between vertically stacked OCR crops.

    Args:
        image_height: Height of the generated image in pixels.

    Returns:
        numpy.ndarray: RGB array of the rendered separator image.
    """
    height = image_height
    # Canvas width; also the nominal font size (default PIL font is used).
    font_size = int(height * 0.8)
    canvas = Image.new("RGB", (font_size, height), color="white")
    draw = ImageDraw.Draw(canvas)
    text = "AAA"
    # textbbox() replaces the deprecated textsize() API.
    left, top, right, bottom = draw.textbbox((0, 0), text)
    text_w = right - left
    text_h = bottom - top
    # Centre the text on the canvas.
    position = ((font_size - text_w) // 2, (height - text_h) // 2)
    draw.text(position, text, fill="black")
    return np.array(canvas)
import cv2
import numpy as np
def center_pad_image(image, target_size=448):
    """Fit *image* inside a square white canvas of side *target_size*.

    Images whose longer side exceeds the target are shrunk with INTER_AREA
    (aspect ratio preserved); smaller images keep their native size. The
    result is centred by adding white borders on all four sides.

    Args:
        image: Input image array (H x W [x C]).
        target_size: Side length of the square output canvas.

    Returns:
        numpy.ndarray: target_size x target_size white-padded image.
    """
    height, width = image.shape[:2]
    if height > target_size or width > target_size:
        # Downscale so the longer side equals target_size.
        scale = target_size / max(height, width)
        height, width = int(height * scale), int(width * scale)
        image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
    top = (target_size - height) // 2
    left = (target_size - width) // 2
    # Any odd remainder goes on the bottom/right edges.
    return cv2.copyMakeBorder(
        image,
        top, target_size - height - top,
        left, target_size - width - left,
        cv2.BORDER_CONSTANT, value=(255, 255, 255),
    )
def resize_or_pad(image, target_height, target_width=448):
    """Force *image* to exactly (target_height, target_width) pixels.

    Height is handled first: taller images are shrunk keeping the aspect
    ratio, shorter ones are padded with white above and below. The width is
    then independently shrunk or padded left/right to match.

    Args:
        image: Input image array.
        target_height: Required output height in pixels.
        target_width: Required output width in pixels.

    Returns:
        numpy.ndarray: Image with the requested height and width.
    """
    white = (255, 255, 255)
    h, w = image.shape[:2]
    if h > target_height:
        # Shrink, preserving aspect ratio (height drives the scale).
        scale = target_height / h
        image = cv2.resize(image, (int(w * scale), target_height),
                           interpolation=cv2.INTER_AREA)
    else:
        top = (target_height - h) // 2
        image = cv2.copyMakeBorder(image, top, target_height - h - top, 0, 0,
                                   cv2.BORDER_CONSTANT, value=white)
    # Width pass: resize down if too wide, centre-pad if too narrow.
    h, w = image.shape[:2]
    if w > target_width:
        image = cv2.resize(image, (target_width, h),
                           interpolation=cv2.INTER_AREA)
    elif w < target_width:
        left = (target_width - w) // 2
        image = cv2.copyMakeBorder(image, 0, 0, left, target_width - w - left,
                                   cv2.BORDER_CONSTANT, value=white)
    return image
def stack_images_vertically(stack_size, ocr_buffer, target_size=448):
    """Compose the first *stack_size* crops into one square OCR tile.

    Each crop is resized/padded to an equal share of the vertical budget,
    with a 30px "AAA" separator strip between neighbours; the stack is then
    centred on a target_size x target_size white canvas.

    Args:
        stack_size: Number of images taken from ocr_buffer.
        ocr_buffer: Sequence of image arrays to draw from.
        target_size: Side length of the final square tile.

    Returns:
        numpy.ndarray: The composed tile.
    """
    # Vertical space left for crops once the separator strips are counted.
    usable_height = target_size - 30 * (stack_size - 1)
    slot_height = usable_height // stack_size
    rows = []
    for idx in range(stack_size):
        rows.append(resize_or_pad(ocr_buffer[idx], slot_height))
        if idx != stack_size - 1:
            # 30px separator between consecutive crops.
            rows.append(resize_or_pad(create_at_image(30), 30))
    return center_pad_image(np.vstack(rows), target_size)
def predict_check(cropped_image, threshold=0.0870):
    """Classify a checkbox crop as ticked or not via dark-pixel density.

    The crop is averaged to grayscale, Gaussian-blurred to suppress noise,
    and the fraction of dark (<128) pixels is compared against *threshold*.

    Args:
        cropped_image: Checkbox region as an H x W x C array.
        threshold: Dark-pixel fraction above which the box counts as ticked.

    Returns:
        str: "Ticked" or "NotTicked".
    """
    gray = cv2.GaussianBlur(np.mean(cropped_image, axis=2), (5, 5), 0)
    dark_fraction = np.count_nonzero(gray < 128) / gray.size
    if dark_fraction > threshold:
        return "Ticked"
    return "NotTicked"
def make_batch(ocr_regions, batch_size=6):
    """Yield consecutive slices of *ocr_regions* with at most *batch_size* items.

    Args:
        ocr_regions: Sequence to split into batches.
        batch_size: Maximum number of items per yielded slice.

    Yields:
        Slices of the input sequence; the final slice may be shorter.
    """
    start = 0
    while start < len(ocr_regions):
        yield ocr_regions[start:start + batch_size]
        start += batch_size
# def batch_predict_ext(image_batch):
# files = []
# for i, img in enumerate(image_batch):
# buffer = BytesIO()
# img.save(buffer, format="PNG")
# file_content = buffer.getvalue()
# files.append(("files", (f"image_{i}.png", file_content, "image/png")))
# url = "https://aioverlords-amnil-ocr-test-pali.hf.space/batch_extract_text"
# headers = {"accept": "application/json"}
# response = requests.post(url, files=files, headers=headers)
# if response.status_code == 200:
# return response.json()
# else:
# return {"error": f"Request failed with status code {response.status_code}"}
async def batch_predict_ext_async(image_batch, batch_size):
    """Send a batch of PIL images to the remote OCR endpoint.

    Each image is serialized to PNG in memory and posted as multipart form
    data to the HF-space /batch_extract_text endpoint.

    Args:
        image_batch: Iterable of PIL.Image objects to OCR.
        batch_size: Server-side batch size, passed as a query parameter.

    Returns:
        dict: Parsed JSON response on success (contains "extracted_texts"),
        otherwise a dict with an "error" key describing the failure.
    """
    form = aiohttp.FormData()
    for i, img in enumerate(image_batch):
        buffer = BytesIO()
        img.save(buffer, format="PNG")
        form.add_field(
            "files",
            buffer.getvalue(),
            filename=f"image_{i}.png",
            content_type="image/png",
        )
    print("Files added to form data")
    url = f"https://aioverlords-amnil-ocr-test-pali.hf.space/batch_extract_text?batch_size={batch_size}"
    headers = {"accept": "application/json"}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=form, headers=headers) as response:
                if response.status == 200:
                    print("OCR Success")
                    return await response.json()
                return {"error": f"Request failed with status code {response.status}"}
    except Exception as e:
        # Bug fix: this path previously fell through and returned None,
        # making callers crash on response["extracted_texts"]. Surface the
        # failure in the same shape as the HTTP-error branch instead.
        print("Error: ", e)
        return {"error": f"Request failed with exception: {e}"}
async def batch_predict_ext_async_vllm(image_batch):
    """Send a batch of PIL images to the vLLM-backed OCR endpoint.

    Like batch_predict_ext_async, but targets /batch_extract_text_vllm and
    takes no batch-size parameter.

    Args:
        image_batch: Iterable of PIL.Image objects to OCR.

    Returns:
        dict: Parsed JSON response on success (contains "extracted_texts"),
        otherwise a dict with an "error" key describing the failure.
    """
    form = aiohttp.FormData()
    for i, img in enumerate(image_batch):
        buffer = BytesIO()
        # "PNG" for consistency with batch_predict_ext_async (PIL's save()
        # normalizes the case, but the canonical format name is upper-case).
        img.save(buffer, format="PNG")
        form.add_field(
            "files",
            buffer.getvalue(),
            filename=f"image_{i}.png",
            content_type="image/png",
        )
    print("Files added to form data")
    url = "https://aioverlords-amnil-ocr-test-pali.hf.space/batch_extract_text_vllm"
    headers = {"accept": "application/json"}
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=form, headers=headers) as response:
                if response.status == 200:
                    print("OCR Success")
                    # Parse the body once and reuse it; the original awaited
                    # response.json() twice (once to print, once to return).
                    payload = await response.json()
                    print(payload)
                    return payload
                return {"error": f"Request failed with status code {response.status}"}
    except Exception as e:
        # Bug fix: previously returned None on exception, crashing callers
        # that index into the result.
        print("Error: ", e)
        return {"error": f"Request failed with exception: {e}"}
async def batch_ocr_ext_async_vllm(file_name, task_id):
    """Run the full form-OCR pipeline against the vLLM endpoint.

    For every page template listed in the module-level VIA annotation
    ``data``, the uploaded document page is aligned to the blank form via
    ORB feature matching + homography, each annotated region is cropped,
    checkboxes are classified locally with predict_check(), text regions are
    sent in one request to the remote vLLM OCR service, and the combined
    results are stored as JSON in the SQLite ``OCR`` table row for *task_id*.

    Args:
        file_name: Path to the uploaded PDF or image; deleted on success.
        task_id: Identifier of the OCR table row to update.

    Side effects:
        Deletes *file_name*, writes to /mnt/data/translations.db, prints
        progress messages. On any failure the row's status is set to
        "failed" instead of raising.
    """
    try:
        start_time = datetime.now()
        check_regions = []   # (region-dict, crop ndarray) pairs for checkboxes
        ocr_regions = []     # (region-dict, PIL image) pairs for text OCR
        blank_regions = []   # regions with no detectable ink/content
        j = 0  # zero-based page index into the PDF and the blank templates
        for k, v in data['_via_img_metadata'].items():
            # k identifies one annotated page; its regions carry shape and
            # attribute metadata from the VIA labeler.
            regions = data['_via_img_metadata'][k]['regions']
            file = file_name
            if file.endswith("pdf"):
                # Page j of the PDF, rendered and converted to grayscale.
                targ_img = np.array(convert_from_path(file)[j].convert("L"))
            else:
                targ_img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            # ORB and alignment code remains unchanged for brevity
            # Used for feature detection and image matching
            # Possible to optimize?
            MAX_NUM_FEATURES = 10000
            orb = cv2.ORB_create(MAX_NUM_FEATURES)
            # Load the blank form of j-th page
            orig_img = np.array(Image.open(f"ocr/VGG Image Annotator_files/mach_bank_form_page{j}.jpg").convert("L"))
            # Detects keypoints (corner-like features) in orig_img and targ_img.
            # and computes descriptors, which are binary feature representations for each keypoint.
            keypoints1, descriptors1 = orb.detectAndCompute(orig_img, None)
            keypoints2, descriptors2 = orb.detectAndCompute(targ_img, None)
            # ORB typically works on grayscale images.
            # Converts images back to BGR; the 3-channel versions are what the
            # warp below operates on.
            img1 = cv2.cvtColor(orig_img, cv2.COLOR_GRAY2BGR)
            img2 = cv2.cvtColor(targ_img, cv2.COLOR_GRAY2BGR)
            # Match features.
            # ORB uses binary descriptors, and Hamming distance counts the number of differing bits.
            # Faster than Euclidean distance for binary descriptors.
            matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING)
            # match() finds the best match for each descriptor; each cv2.DMatch has
            # .queryIdx --> index of the keypoint in orig_img
            # .trainIdx --> index of matching keypoint in targ_img
            # .distance --> Hamming distance
            # Converting to list for sorting, as match() returns a tuple.
            matches = list(matcher.match(descriptors1, descriptors2, None))
            # Sort matches by ascending Hamming distance (best first).
            matches.sort(key = lambda x: x.distance, reverse = False)
            # Keep only the best 10% of matches for a robust homography.
            numGoodMatches = int(0.1*len(matches))
            matches = matches[:numGoodMatches]
            # matches = matches[:len(matches)//10]
            # Keypoint coordinate arrays; float32 for cv2.findHomography().
            points1 = np.zeros((len(matches), 2), dtype = np.float32)
            points2 = np.zeros((len(matches), 2), dtype = np.float32)
            # Extract location of good matches
            for i, match in enumerate(matches):
                points1[i, :] = keypoints1[match.queryIdx].pt
                points2[i, :] = keypoints2[match.trainIdx].pt
            # Homography mapping the scanned page onto the blank template.
            h, mask = cv2.findHomography(points2, points1, cv2.RANSAC)
            height, width, channels = img1.shape
            # Warp img2 to align with img1 so VIA coordinates apply directly.
            img2_reg = cv2.warpPerspective(img2, h, (width, height))
            # For brevity, assume img2_reg is computed as in original
            # img2_reg = targ_img # Placeholder; replace with actual aligned image
            region_data = []
            # Flatten the VIA region annotations into simple dicts.
            for region in regions:
                x, y, width, height = (
                    region['shape_attributes']['x'],
                    region['shape_attributes']['y'],
                    region['shape_attributes']['width'],
                    region['shape_attributes']['height']
                )
                name = (
                    f"{region['region_attributes']['parent']}_"
                    f"{region['region_attributes']['key']}_"
                    f"{region['region_attributes'].get('group', '')}"
                )
                name_type = region['region_attributes']['type']
                region_data.append({"x": x, "y": y, "width": width, "height": height, "name": name, "type": name_type})
            # Crop each region from the aligned page and route it: checkboxes
            # go to the local classifier, regions with visible ink go to OCR,
            # the rest are marked "not found".
            for region in region_data:
                x, y, width, height = region["x"], region["y"], region["width"], region["height"]
                cropped_image = img2_reg[y:y+height, x:x+width]
                if region["type"] == "check":
                    region["page"] = f"page_{j}"
                    check_regions.append((region, cropped_image))
                else:
                    # Canny edge counts (with a 7px inset to ignore borders)
                    # decide whether the field contains anything worth OCRing.
                    cedge = cv2.Canny(cropped_image[7:-7, 7:-7], 100, 200)
                    cex_ = cedge.astype(float).sum(axis=0) / 255
                    cey_ = cedge.astype(float).sum(axis=1) / 255
                    cex_ = np.count_nonzero(cex_ > 5)
                    cey_ = np.count_nonzero(cey_ > 5)
                    if cex_ > 7 and cey_ > 7:
                        im = Image.fromarray(center_pad_image(cropped_image))
                        region["page"] = f"page_{j}"
                        ocr_regions.append((region, im))
                    else:
                        region["status"] = "not found"
                        region["page"] = f"page_{j}"
                        blank_regions.append(region)
            j += 1
        # Process check regions
        check_region_data = []
        for check_region in check_regions:
            check_region[0]["status"] = predict_check(check_region[1])
            check_region_data.append(check_region[0])
        # Process OCR regions — all images are sent in a single request here
        # (no client-side batching), unlike batch_ocr_ext_async.
        region_data = []
        # for batch in make_batch(ocr_regions, batch_size):
        # i = 0
        # print(task_id,"_s_",i)
        print("Retrieving images")
        ocr_images = [data[1] for data in ocr_regions]
        print("Images retrieved")
        print("Sending request vllm")
        response = await batch_predict_ext_async_vllm(ocr_images)
        print("Request completed")
        #print(response)
        extracted_texts = response["extracted_texts"]
        print("Text Extracted")
        for text, region in zip(extracted_texts, ocr_regions):
            region[0]["status"] = text
            region_data.append(region[0])
            # print(task_id,"_c_",i)
            # i += 1
        print("text appended")
        # Combine and store results
        region_data.extend(check_region_data)
        print("Check region data appended")
        region_data.extend(blank_regions)
        print("Blank region data appended")
        string_data = json.dumps(region_data)
        end_time = datetime.now()
        time_elapsed_str = str(end_time - start_time)
        print(time_elapsed_str)
        # The uploaded file is no longer needed once results are computed.
        os.remove(file_name)
        conn = sqlite3.connect('/mnt/data/translations.db')
        cursor = conn.cursor()
        cursor.execute(
            'UPDATE OCR SET region = ?, time_elapsed = ?, status=?, updated_at = ? WHERE task_id = ?',
            (string_data, time_elapsed_str, "completed", datetime.now(), task_id)
        )
        conn.commit()
        conn.close()
        print("SUCCESSFUL vllm")
    except Exception as e:
        print(f"OCR Failed vllm: {e}")
        try:
            conn = sqlite3.connect('/mnt/data/translations.db')
            cursor = conn.cursor()
            cursor.execute('UPDATE OCR SET status = ? WHERE task_id = ?', ("failed", task_id))
            conn.commit()
            conn.close()
        # NOTE(review): "exec" shadows the builtin of the same name within
        # this handler; consider renaming.
        except Exception as exec:
            print(f"Updating vllm status to database failed: {exec}")
async def batch_ocr_ext_async(file_name, task_id, batch_size):
    """Run the full form-OCR pipeline against the standard OCR endpoint.

    For every page template listed in the module-level VIA annotation
    ``data``, the uploaded document page is aligned to the blank form via
    ORB feature matching + homography, each annotated region is cropped,
    checkboxes are classified locally with predict_check(), text regions are
    sent to the remote OCR service (server-side batching controlled by
    *batch_size*), and the combined results are stored as JSON in the SQLite
    ``OCR`` table row for *task_id*.

    Args:
        file_name: Path to the uploaded PDF or image; deleted on success.
        task_id: Identifier of the OCR table row to update.
        batch_size: Batch size forwarded to the remote OCR endpoint.

    Side effects:
        Deletes *file_name*, writes to /mnt/data/translations.db, prints
        progress messages. On any failure the row's status is set to
        "failed" instead of raising.
    """
    try:
        start_time = datetime.now()
        check_regions = []   # (region-dict, crop ndarray) pairs for checkboxes
        ocr_regions = []     # (region-dict, PIL image) pairs for text OCR
        blank_regions = []   # regions with no detectable ink/content
        j = 0  # zero-based page index into the PDF and the blank templates
        for k, v in data['_via_img_metadata'].items():
            regions = data['_via_img_metadata'][k]['regions']
            file = file_name
            if file.endswith("pdf"):
                # Page j of the PDF, rendered and converted to grayscale.
                targ_img = np.array(convert_from_path(file)[j].convert("L"))
            else:
                targ_img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            # ORB feature detection for template alignment.
            MAX_NUM_FEATURES = 10000
            orb = cv2.ORB_create(MAX_NUM_FEATURES)
            # Load the blank form of the j-th page.
            orig_img = np.array(Image.open(f"ocr/VGG Image Annotator_files/mach_bank_form_page{j}.jpg").convert("L"))
            # Keypoints + binary descriptors for both images.
            keypoints1, descriptors1 = orb.detectAndCompute(orig_img, None)
            keypoints2, descriptors2 = orb.detectAndCompute(targ_img, None)
            # 3-channel versions; the warp below operates on these.
            img1 = cv2.cvtColor(orig_img, cv2.COLOR_GRAY2BGR)
            img2 = cv2.cvtColor(targ_img, cv2.COLOR_GRAY2BGR)
            # Brute-force Hamming matcher — the right metric for ORB's
            # binary descriptors.
            matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING)
            # match() returns a tuple of cv2.DMatch; convert to list to sort.
            matches = list(matcher.match(descriptors1, descriptors2, None))
            # Best (smallest Hamming distance) first.
            matches.sort(key=lambda x: x.distance, reverse=False)
            # Keep only the best 10% of matches for a robust homography.
            numGoodMatches = int(0.1 * len(matches))
            matches = matches[:numGoodMatches]
            # Keypoint coordinate arrays; float32 for cv2.findHomography().
            points1 = np.zeros((len(matches), 2), dtype=np.float32)
            points2 = np.zeros((len(matches), 2), dtype=np.float32)
            for i, match in enumerate(matches):
                points1[i, :] = keypoints1[match.queryIdx].pt
                points2[i, :] = keypoints2[match.trainIdx].pt
            # Homography mapping the scanned page onto the blank template.
            h, mask = cv2.findHomography(points2, points1, cv2.RANSAC)
            height, width, channels = img1.shape
            # Warp the scan so VIA template coordinates apply directly.
            img2_reg = cv2.warpPerspective(img2, h, (width, height))
            # Flatten the VIA region annotations into simple dicts.
            region_data = []
            for region in regions:
                x, y, width, height = (
                    region['shape_attributes']['x'],
                    region['shape_attributes']['y'],
                    region['shape_attributes']['width'],
                    region['shape_attributes']['height']
                )
                name = (
                    f"{region['region_attributes']['parent']}_"
                    f"{region['region_attributes']['key']}_"
                    f"{region['region_attributes'].get('group', '')}"
                )
                name_type = region['region_attributes']['type']
                region_data.append({"x": x, "y": y, "width": width, "height": height, "name": name, "type": name_type})
            # Crop each region and route it: checkboxes to the local
            # classifier, regions with visible ink to OCR, rest -> "not found".
            for region in region_data:
                x, y, width, height = region["x"], region["y"], region["width"], region["height"]
                cropped_image = img2_reg[y:y+height, x:x+width]
                if region["type"] == "check":
                    region["page"] = f"page_{j}"
                    check_regions.append((region, cropped_image))
                else:
                    # Canny edge counts (7px inset ignores field borders)
                    # decide whether the field contains anything worth OCRing.
                    cedge = cv2.Canny(cropped_image[7:-7, 7:-7], 100, 200)
                    cex_ = cedge.astype(float).sum(axis=0) / 255
                    cey_ = cedge.astype(float).sum(axis=1) / 255
                    cex_ = np.count_nonzero(cex_ > 5)
                    cey_ = np.count_nonzero(cey_ > 5)
                    if cex_ > 7 and cey_ > 7:
                        im = Image.fromarray(center_pad_image(cropped_image))
                        region["page"] = f"page_{j}"
                        ocr_regions.append((region, im))
                    else:
                        region["status"] = "not found"
                        region["page"] = f"page_{j}"
                        blank_regions.append(region)
            j += 1
        # Classify all checkbox crops locally.
        check_region_data = []
        for check_region in check_regions:
            check_region[0]["status"] = predict_check(check_region[1])
            check_region_data.append(check_region[0])
        # Send every text crop to the remote OCR service in one request.
        region_data = []
        print("Retrieving images")
        ocr_images = [pair[1] for pair in ocr_regions]
        print("Images retrieved")
        print("Sending request")
        response = await batch_predict_ext_async(ocr_images, batch_size)
        print("Request completed")
        print(response)
        extracted_texts = response["extracted_texts"]
        print("Text Extracted")
        for text, region in zip(extracted_texts, ocr_regions):
            region[0]["status"] = text
            region_data.append(region[0])
        print("text appended")
        # Combine and store results
        region_data.extend(check_region_data)
        print("Check region data appended")
        region_data.extend(blank_regions)
        print("Blank region data appended")
        # Bug fix: previously dumped only check_region_data, silently
        # dropping the OCR and blank regions from the stored result. The
        # vllm twin of this function dumps the combined region_data.
        string_data = json.dumps(region_data)
        end_time = datetime.now()
        time_elapsed_str = str(end_time - start_time)
        print(time_elapsed_str)
        # The uploaded file is no longer needed once results are computed.
        os.remove(file_name)
        conn = sqlite3.connect('/mnt/data/translations.db')
        cursor = conn.cursor()
        cursor.execute(
            'UPDATE OCR SET region = ?, time_elapsed = ?, status=?, updated_at = ? WHERE task_id = ?',
            (string_data, time_elapsed_str, "completed", datetime.now(), task_id)
        )
        conn.commit()
        conn.close()
        print("SUCCESSFUL")
    except Exception as e:
        print(f"OCR Failed: {e}")
        try:
            # Best-effort: flag the task as failed so the UI doesn't hang.
            conn = sqlite3.connect('/mnt/data/translations.db')
            cursor = conn.cursor()
            cursor.execute('UPDATE OCR SET status = ? WHERE task_id = ?', ("failed", task_id))
            conn.commit()
            conn.close()
        except Exception as db_err:  # renamed: "exec" shadowed the builtin
            print(f"Updating status to database failed: {db_err}")
async def batch_ocr_ext_async_stack(file_name, task_id,batch_size,stack_size):
    """Async form-OCR pipeline that stacks several text crops per request.

    For each page described in the module-level VIA annotation dict ``data``:
      1. Align the scanned page to the blank reference form
         (``mach_bank_form_page{j}.jpg``) via ORB features + RANSAC homography.
      2. Crop every annotated region from the aligned page.
      3. Route crops: "check" regions are classified locally by
         ``predict_check``; text regions that look non-empty (Canny edge
         density) are stacked vertically in groups of ``stack_size`` and sent
         to the external OCR service; visually blank regions get status
         "not found".
      4. JSON-encode all region results into the ``OCR`` row for ``task_id``
         in the SQLite DB. On any failure, best-effort set status "failed".

    Args:
        file_name: Path to the scanned PDF (one image per page) or an image
            file. Deleted (``os.remove``) after successful processing.
        task_id: Key of the ``OCR`` row to update.
        batch_size: Forwarded to ``batch_predict_ext_async``.
        stack_size: Number of region crops stacked into one tall image.

    Returns:
        None; all output goes to the database and stdout.
    """
    try:
        start_time = datetime.now()
        check_regions = []   # (region-dict, crop ndarray) for checkbox fields
        ocr_regions = []     # (list-of-region-dicts, stacked PIL image)
        blank_regions = []   # regions judged visually empty
        j = 0  # 0-based page index, advanced once per VIA metadata entry
        for k, v in data['_via_img_metadata'].items():
            regions = data['_via_img_metadata'][k]['regions']
            file = file_name
            # Load the j-th scanned page as grayscale.
            if file.endswith("pdf"):
                targ_img = np.array(convert_from_path(file)[j].convert("L"))
            else:
                targ_img = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
            # --- Align the scan to the blank reference form ---------------
            MAX_NUM_FEATURES = 10000
            orb = cv2.ORB_create(MAX_NUM_FEATURES)
            # Blank reference form for page j.
            orig_img = np.array(Image.open(f"ocr/VGG Image Annotator_files/mach_bank_form_page{j}.jpg").convert("L"))
            # Detect keypoints and compute binary descriptors on both images.
            keypoints1, descriptors1 = orb.detectAndCompute(orig_img, None)
            keypoints2, descriptors2 = orb.detectAndCompute(targ_img, None)
            # Back to BGR: img1 is only used for its shape below; img2 is warped.
            img1 = cv2.cvtColor(orig_img, cv2.COLOR_GRAY2BGR)
            img2 = cv2.cvtColor(targ_img, cv2.COLOR_GRAY2BGR)
            # ORB descriptors are binary, so Hamming distance is the right metric.
            matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING)
            # list() because match() returns an immutable sequence we need to sort.
            matches = list(matcher.match(descriptors1, descriptors2, None))
            # Best (smallest Hamming distance) matches first.
            matches.sort(key = lambda x: x.distance, reverse = False)
            # Keep only the top 10% of matches for a robust homography.
            numGoodMatches = int(0.1*len(matches))
            matches = matches[:numGoodMatches]
            # float32 for cv2.findHomography compatibility.
            points1 = np.zeros((len(matches), 2), dtype = np.float32)
            points2 = np.zeros((len(matches), 2), dtype = np.float32)
            for i, match in enumerate(matches):
                points1[i, :] = keypoints1[match.queryIdx].pt
                points2[i, :] = keypoints2[match.trainIdx].pt
            # Map scan coordinates onto the reference frame; RANSAC drops outliers.
            h, mask = cv2.findHomography(points2, points1, cv2.RANSAC)
            height, width, channels = img1.shape
            img2_reg = cv2.warpPerspective(img2, h, (width, height))
            # --- Collect annotated regions for this page ------------------
            region_data = []
            for region in regions:
                x, y, width, height = (
                    region['shape_attributes']['x'],
                    region['shape_attributes']['y'],
                    region['shape_attributes']['width'],
                    region['shape_attributes']['height']
                )
                # Stable field name: parent_key_group (group may be absent).
                name = (
                    f"{region['region_attributes']['parent']}_"
                    f"{region['region_attributes']['key']}_"
                    f"{region['region_attributes'].get('group', '')}"
                )
                name_type = region['region_attributes']['type']
                region_data.append({"x": x, "y": y, "width": width, "height": height, "name": name, "type": name_type})
            # Pending crops not yet flushed into a stacked image.
            ocr_buffer=[]
            buffer_metadata = []
            for region in region_data:
                x, y, width, height = region["x"], region["y"], region["width"], region["height"]
                cropped_image = img2_reg[y:y+height, x:x+width]
                if region["type"] == "check":
                    region["page"] = f"page_{j}"
                    check_regions.append((region, cropped_image))
                else:
                    # Edge density inside the crop (7 px border ignored)
                    # decides whether the field was filled in at all.
                    cedge = cv2.Canny(cropped_image[7:-7, 7:-7], 100, 200)
                    cex_ = cedge.astype(float).sum(axis=0) / 255
                    cey_ = cedge.astype(float).sum(axis=1) / 255
                    cex_ = np.count_nonzero(cex_ > 5)
                    cey_ = np.count_nonzero(cey_ > 5)
                    if cex_ > 7 and cey_ > 7:
                        ocr_buffer.append(cropped_image)
                        region["page"] = f"page_{j}"
                        buffer_metadata.append(region)
                        # Flush a full stack as one tall image.
                        if len(ocr_buffer) == stack_size:
                            img = Image.fromarray(stack_images_vertically(stack_size,ocr_buffer))
                            ocr_regions.append((buffer_metadata, img))
                            ocr_buffer = []
                            buffer_metadata = []
                    else:
                        region["status"] = "not found"
                        region["page"] = f"page_{j}"
                        blank_regions.append(region)
            j += 1
        # Flush any partial stack left over after the last page.
        # NOTE(review): ocr_buffer is re-initialized inside the per-page loop
        # while this flush runs only once after it — if that indentation is
        # intended, partial buffers from every page except the last are
        # silently dropped and those regions never receive a status. Confirm.
        if len(ocr_buffer) > 0:
            img = Image.fromarray(stack_images_vertically(len(ocr_buffer),ocr_buffer))
            ocr_regions.append((buffer_metadata, img))
        # Checkboxes are classified locally by pixel density.
        check_region_data = []
        for check_region in check_regions:
            check_region[0]["status"] = predict_check(check_region[1])
            check_region_data.append(check_region[0])
        # Text regions go to the external OCR service in one async call.
        region_data = []
        print("Retrieving images")
        ocr_images = [data[1] for data in ocr_regions]
        print("Images retrieved")
        print("Sending request")
        response = await batch_predict_ext_async(ocr_images,batch_size)
        print("Request completed")
        print(response)
        extracted_texts = response["extracted_texts"]
        print("Text Extracted")
        for text, region in zip(extracted_texts, ocr_regions):
            # Presumably the OCR model emits "AAA" as a separator between the
            # stacked regions of one image — TODO confirm against the service.
            splitted_text = text.split("AAA")
            for single_region,single_text in zip(region[0],splitted_text):
                single_region["status"] = single_text
                region_data.append(single_region)
        print("text appended")
        # Combine all results and persist them.
        region_data.extend(check_region_data)
        print("Check region data appended")
        region_data.extend(blank_regions)
        print("Blank region data appended")
        string_data = json.dumps(region_data)
        end_time = datetime.now()
        time_elapsed_str = str(end_time - start_time)
        print(time_elapsed_str)
        # Input file is no longer needed once results are computed.
        os.remove(file_name)
        conn = sqlite3.connect('/mnt/data/translations.db')
        cursor = conn.cursor()
        cursor.execute(
            'UPDATE OCR SET region = ?, time_elapsed = ?, status=?, updated_at = ? WHERE task_id = ?',
            (string_data, time_elapsed_str, "completed", datetime.now(), task_id)
        )
        conn.commit()
        conn.close()
        print("SUCCESSFUL")
    except Exception as e:
        print(f"OCR Failed: {e}")
        # Best-effort: mark the task failed so callers don't wait forever.
        try:
            conn = sqlite3.connect('/mnt/data/translations.db') # For local
            cursor = conn.cursor()
            cursor.execute('UPDATE OCR SET status = ? WHERE task_id = ?', ("failed", task_id))
            conn.commit()
            conn.close()
        # NOTE(review): the name `exec` shadows the builtin of the same name.
        except Exception as exec:
            print(f"Updating status to database failed: {exec}")
def batch_ocr_ext(file_name, task_id, batch_size):
    """Synchronous form-OCR pipeline.

    For each page described in the module-level VIA annotation dict ``data``:
    align the scanned page to the blank reference form (ORB features + RANSAC
    homography), crop every annotated region, classify "check" regions
    locally via ``predict_check``, send non-blank text regions to the
    external OCR service ``batch_predict_ext`` in batches, and mark visually
    empty regions "not found". Results are JSON-encoded into the ``OCR`` row
    for ``task_id`` in the SQLite DB; on any failure the row's status is
    best-effort set to "failed".

    Args:
        file_name: Path to the scanned PDF (one image per page) or an image
            file. Deleted (``os.remove``) after successful processing.
        task_id: Key of the ``OCR`` row to update.
        batch_size: Number of region images per ``batch_predict_ext`` call.

    Returns:
        None; all output goes to the database and stdout.
    """
    try:
        start_time = datetime.now()
        check_regions = []   # (region-dict, crop ndarray) for checkbox fields
        ocr_regions = []     # (region-dict, padded PIL image) for text fields
        blank_regions = []   # regions judged visually empty
        # One VIA metadata entry per form page; j is the 0-based page index.
        for j, k in enumerate(data['_via_img_metadata']):
            regions = data['_via_img_metadata'][k]['regions']
            # Load the j-th scanned page as grayscale.
            if file_name.endswith("pdf"):
                targ_img = np.array(convert_from_path(file_name)[j].convert("L"))
            else:
                targ_img = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
            # --- Align the scan to the blank reference form ---------------
            MAX_NUM_FEATURES = 10000
            orb = cv2.ORB_create(MAX_NUM_FEATURES)
            # Blank reference form for page j.
            orig_img = np.array(Image.open(f"ocr/VGG Image Annotator_files/mach_bank_form_page{j}.jpg").convert("L"))
            keypoints1, descriptors1 = orb.detectAndCompute(orig_img, None)
            keypoints2, descriptors2 = orb.detectAndCompute(targ_img, None)
            # img1 is used only for its shape; img2 is the image we warp.
            img1 = cv2.cvtColor(orig_img, cv2.COLOR_GRAY2BGR)
            img2 = cv2.cvtColor(targ_img, cv2.COLOR_GRAY2BGR)
            # ORB descriptors are binary, so Hamming distance is the right metric.
            matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING)
            # Best (smallest Hamming distance) matches first; keep top 10%
            # of matches for a robust homography estimate.
            matches = sorted(matcher.match(descriptors1, descriptors2, None),
                             key=lambda m: m.distance)
            matches = matches[:int(0.1 * len(matches))]
            # float32 for cv2.findHomography compatibility.
            points1 = np.zeros((len(matches), 2), dtype=np.float32)
            points2 = np.zeros((len(matches), 2), dtype=np.float32)
            for i, match in enumerate(matches):
                points1[i, :] = keypoints1[match.queryIdx].pt
                points2[i, :] = keypoints2[match.trainIdx].pt
            # Map scan coordinates onto the reference frame; RANSAC drops outliers.
            h, mask = cv2.findHomography(points2, points1, cv2.RANSAC)
            height, width, channels = img1.shape
            img2_reg = cv2.warpPerspective(img2, h, (width, height))
            # --- Collect annotated regions for this page ------------------
            region_data = []
            for region in regions:
                shape = region['shape_attributes']
                # Stable field name: parent_key_group (group may be absent).
                name = (
                    f"{region['region_attributes']['parent']}_"
                    f"{region['region_attributes']['key']}_"
                    f"{region['region_attributes'].get('group', '')}"
                )
                region_data.append({
                    "x": shape['x'], "y": shape['y'],
                    "width": shape['width'], "height": shape['height'],
                    "name": name,
                    "type": region['region_attributes']['type'],
                })
            # --- Route each crop: checkbox / OCR / blank ------------------
            for region in region_data:
                x, y, width, height = region["x"], region["y"], region["width"], region["height"]
                cropped_image = img2_reg[y:y+height, x:x+width]
                region["page"] = f"page_{j}"
                if region["type"] == "check":
                    check_regions.append((region, cropped_image))
                    continue
                # Edge density inside the crop (7 px border ignored) decides
                # whether the field was filled in at all.
                cedge = cv2.Canny(cropped_image[7:-7, 7:-7], 100, 200)
                cols = np.count_nonzero(cedge.astype(float).sum(axis=0) / 255 > 5)
                rows = np.count_nonzero(cedge.astype(float).sum(axis=1) / 255 > 5)
                if cols > 7 and rows > 7:
                    im = Image.fromarray(center_pad_image(cropped_image))
                    ocr_regions.append((region, im))
                else:
                    region["status"] = "not found"
                    blank_regions.append(region)
        # Checkboxes are classified locally by pixel density.
        check_region_data = []
        for check_region in check_regions:
            check_region[0]["status"] = predict_check(check_region[1])
            check_region_data.append(check_region[0])
        # Text regions go to the external OCR service, batch_size at a time.
        # (Bug fix: the batch counter was previously reset to 0 inside the
        # loop, so the progress prints always showed 0.)
        region_data = []
        for i, batch in enumerate(make_batch(ocr_regions, batch_size)):
            print(task_id, "_s_", i)
            images = [item[1] for item in batch]
            response = batch_predict_ext(images)
            for text, region in zip(response["extracted_texts"], batch):
                region[0]["status"] = text
                region_data.append(region[0])
            print(task_id, "_c_", i)
        # Combine all results and persist them.
        region_data.extend(check_region_data)
        region_data.extend(blank_regions)
        string_data = json.dumps(region_data)
        time_elapsed_str = str(datetime.now() - start_time)
        print(time_elapsed_str)
        # Input file is no longer needed once results are computed.
        os.remove(file_name)
        conn = sqlite3.connect('/mnt/data/translations.db')  # For local
        cursor = conn.cursor()
        cursor.execute(
            'UPDATE OCR SET region = ?, time_elapsed = ?, status=?, updated_at = ? WHERE task_id = ?',
            (string_data, time_elapsed_str, "completed", datetime.now(), task_id)
        )
        conn.commit()
        conn.close()
        print("SUCCESSFUL")
    except Exception as e:
        print(f"OCR Failed: {e}")
        # Best-effort: mark the task failed so callers don't wait forever.
        try:
            conn = sqlite3.connect('/mnt/data/translations.db')
            cursor = conn.cursor()
            cursor.execute('UPDATE OCR SET status = ? WHERE task_id = ?', ("failed", task_id))
            conn.commit()
            conn.close()
        except Exception as db_err:  # renamed from `exec` (shadowed builtin)
            print(f"Updating status to database failed: {db_err}")
# Example call
# batch_ocr_ext("example.pdf", "task123", 10)