# App / demo_v3.py
# Uploaded by LinhKL2002 ("Upload folder using huggingface_hub"), commit 4dbe5d1 (verified).
import os
import cv2
import numpy as np
from pdf2image import convert_from_path
from main import RapidOCR
from image_enhancement import enhance_image
# Shared OCR engine: built a single time when this module is imported and
# passed explicitly to every OCR helper below.
ocr_engine = RapidOCR()
def adaptive_threshold_to_rgb(image_rgb, block_size=11, c=2, method=None):
    """
    Binarize the lightness of an RGB image while keeping its colour channels.

    The image is converted to LAB, adaptive thresholding is applied to the
    L channel only, and the thresholded L channel is merged back with the
    untouched A and B channels before converting back to RGB.

    Parameters:
        image_rgb (numpy.ndarray): Input RGB image. cv2.adaptiveThreshold
            only accepts an 8-bit single-channel source, so the LAB image
            must be uint8 -- i.e. pass a uint8 RGB image.
        block_size (int): Odd neighbourhood size (>= 3) used to compute each
            local threshold. Defaults to 11, the previously fixed value.
        c (float): Constant subtracted from the (weighted) local mean.
            Defaults to 2, the previously fixed value.
        method (int | None): cv2.ADAPTIVE_THRESH_GAUSSIAN_C or
            cv2.ADAPTIVE_THRESH_MEAN_C. None (default) selects the Gaussian
            method, as before.

    Returns:
        numpy.ndarray: RGB image whose L channel has been thresholded.

    Raises:
        ValueError: If block_size is not an odd integer >= 3.
    """
    if block_size < 3 or block_size % 2 == 0:
        raise ValueError(f"block_size must be an odd integer >= 3, got {block_size!r}")
    # Resolved at call time so the default does not touch cv2 at import time.
    if method is None:
        method = cv2.ADAPTIVE_THRESH_GAUSSIAN_C

    # Work in LAB so that only lightness is binarized; A/B (chroma) are kept.
    image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(image_lab)

    thresholded_l = cv2.adaptiveThreshold(
        l_channel,
        maxValue=255,
        adaptiveMethod=method,
        thresholdType=cv2.THRESH_BINARY,
        blockSize=block_size,
        C=c,
    )

    updated_lab = cv2.merge((thresholded_l, a_channel, b_channel))
    return cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB)
def ocr_detect(image, ocr_engine, marker="<", min_count=2):
    """
    Run OCR on *image* and report whether two consecutive recognized text
    lines each contain *marker* at least *min_count* times.

    With the defaults this means two consecutive entries of the OCR result
    that each contain two or more '<' characters (equivalent to the original
    ``count > 1`` test).
    NOTE(review): presumably this targets the '<'-padded lines of a
    machine-readable zone on identity documents -- confirm with the author.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: Callable OCR engine, invoked as
            ``ocr_engine(image, use_det=True, use_cls=False, use_rec=True)``.
            It must return a 2-tuple whose first element is either falsy
            (nothing found) or a sequence of entries whose index 1 is the
            recognized text.
        marker (str): Character (or substring) counted in each line.
        min_count (int): Minimum occurrences of *marker* required in each of
            the two consecutive lines.

    Returns:
        bool: True if the pattern is detected, False otherwise.
    """
    result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True)
    if not result:
        return False

    # Count each line once, then compare each line with the next one.
    counts = [entry[1].count(marker) for entry in result]
    return any(
        upper >= min_count and lower >= min_count
        for upper, lower in zip(counts, counts[1:])
    )
def rotate_until_detect(image, ocr_engine, max_attempts=4):
    """
    Try OCR on *image* in successive 90-degree clockwise orientations until
    ``ocr_detect`` reports the expected pattern.

    The first attempt uses the image as given. Each later attempt rotates
    the previous candidate another 90 degrees clockwise. No rotation is done
    after the final attempt, so a failed search performs at most
    ``max_attempts - 1`` rotations.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: The OCR engine instance (forwarded to ``ocr_detect``).
        max_attempts (int): Maximum number of orientations to try; values
            <= 0 skip OCR entirely.

    Returns:
        tuple[numpy.ndarray, bool]: ``(image, detected)``. On success, the
        orientation in which the pattern was found and True. On failure,
        the input image in its original orientation and False.
    """
    candidate = image
    for attempt in range(max_attempts):
        if attempt:
            # Rotate only between attempts; never after the last one.
            candidate = cv2.rotate(candidate, cv2.ROTATE_90_CLOCKWISE)
        if ocr_detect(candidate, ocr_engine):
            return candidate, True
    return image, False
def _float_image_to_uint8(img):
    """Scale a float image (nominally in [0, 1]) to uint8, clipping out-of-range values."""
    # Casting floats outside [0, 255] to uint8 is not well defined in NumPy
    # (values can wrap), so clip first. astype truncates like np.uint8(...)
    # did, so in-range pixels are unchanged.
    return np.clip(img * 255.0, 0, 255).astype(np.uint8)


def _ocr_with_fallback(img, ocr_engine, label):
    """Run rotated OCR on img, retrying on an adaptive-threshold copy; return the detected flag."""
    _, detected = rotate_until_detect(img, ocr_engine)
    if detected:
        print(f"OCR success on {label} with enhanced image rotation.")
        return True

    print(f"No OCR detection from enhanced image. Applying adaptive thresholding for {label}.")
    _, detected = rotate_until_detect(adaptive_threshold_to_rgb(img), ocr_engine)
    if detected:
        print(f"OCR success on {label} with adaptive thresholding and rotation.")
    else:
        print(f"OCR detection failed for {label} after fallback.")
    return detected


def process_pdf(pdf_f, ocr_engine, enhance_params, *, dpi=300, first_page=1,
                last_page=3, output_dir=None):
    """
    Process a single PDF file by converting pages to images, enhancing them,
    saving each enhanced page as a JPEG, and running OCR with rotation.
    Adaptive thresholding is used as a fallback when OCR finds nothing.

    Parameters:
        pdf_f (str): PDF file path.
        ocr_engine: The OCR engine instance.
        enhance_params (dict): Parameters for the image enhancement.
        dpi (int): Rasterization resolution (default 300).
        first_page (int): First page to convert, 1-based (default 1).
        last_page (int): Last page to convert, inclusive (default 3).
        output_dir (str | None): Directory for the enhanced JPEGs. None or ""
            writes to the current working directory, as before. The
            directory is created if needed.

    Returns:
        list[bool]: One OCR-detection flag per converted page, in page order.
    """
    images = convert_from_path(pdf_f, dpi=dpi, first_page=first_page, last_page=last_page)
    bs_name_0 = os.path.splitext(os.path.basename(pdf_f))[0]
    start_page = first_page or 1
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    results = []
    for i, pil_image in enumerate(images):
        img = np.array(pil_image)
        print("Original image shape:", img.shape)

        # The *255 scaling implies enhance_image returns floats in [0, 1]
        # -- NOTE(review): confirm against image_enhancement.
        img = enhance_image(img, enhance_params, verbose=False)
        img = _float_image_to_uint8(img)

        # Save the enhanced image as a reference (cv2 expects BGR order).
        output_filename = f'{bs_name_0}_{start_page + i}.jpg'
        output_path = os.path.join(output_dir, output_filename) if output_dir else output_filename
        # cv2.imwrite signals failure (e.g. unwritable path) by returning False.
        if cv2.imwrite(output_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
            print(f"Saved enhanced image: {output_path}")
        else:
            print(f"Failed to save enhanced image: {output_path}")

        results.append(_ocr_with_fallback(img, ocr_engine, output_filename))
    return results
def main(data_path=None):
    """
    Run the enhancement + OCR pipeline on every PDF found under *data_path*.

    Parameters:
        data_path (str | None): Root directory searched recursively for PDF
            files. None (default) uses the original hard-coded dataset path.
    """
    if data_path is None:
        data_path = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本'

    # The extension match is case-insensitive so "*.PDF" files are not
    # skipped. Results are sorted because os.walk order depends on the
    # filesystem.
    list_pdf = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(data_path)
        for file in files
        if file.lower().endswith('.pdf')
    )
    # os.walk yields nothing for a missing directory, so report it explicitly.
    if not list_pdf:
        print(f"No PDF files found under {data_path}")
        return

    # Image enhancement parameters (applied to every image).
    enhance_params = {
        'local_contrast': 1.2,  # 1.2x increase in details
        'mid_tones': 0.5,  # middle of range
        'tonal_width': 0.5,  # middle of range
        'areas_dark': 0.7,  # 70% improvement in dark areas
        'areas_bright': 0.5,  # 50% improvement in bright areas
        'brightness': 0.1,  # slight increase in overall brightness
        'saturation_degree': 1.2,  # 1.2x increase in color saturation
        'preserve_tones': True,
        'color_correction': True,
    }

    # Batch boundary: report a failing PDF and continue with the rest
    # instead of aborting the whole run.
    for pdf_f in list_pdf:
        try:
            process_pdf(pdf_f, ocr_engine, enhance_params)
        except Exception as exc:
            print(f"Failed to process {pdf_f}: {exc!r}")
if __name__ == "__main__":
    # Script entry point: only runs when executed directly, not on import.
    main()