# App / demo_v4.py
# Uploaded by LinhKL2002 using huggingface_hub (commit 4dbe5d1, verified).
import os
import cv2
import numpy as np
from pdf2image import convert_from_path
from main import RapidOCR
from image_enhancement import enhance_image
# Initialize the OCR engine once at import time; main() passes this same
# instance to process_pdf for every PDF instead of re-creating it.
ocr_engine = RapidOCR()
def adaptive_threshold_to_rgb(image_rgb, block_size=11, c=2):
    """
    Binarize the lightness of an RGB image while keeping its chroma.

    The image is converted to LAB, adaptive Gaussian thresholding is applied
    to the L channel only, and the binarized L channel is merged back with the
    original A and B channels before converting to RGB again.

    Parameters:
        image_rgb (numpy.ndarray): Input RGB image. It should be uint8:
            cv2.adaptiveThreshold only accepts an 8-bit single-channel source,
            and the L channel inherits the input dtype.
        block_size (int): Size of the pixel neighbourhood used to compute each
            local threshold. OpenCV requires an odd value > 1 (default 11).
        c (float): Constant subtracted from the weighted neighbourhood mean
            (default 2).

    Returns:
        thresholded_rgb (numpy.ndarray): RGB image produced from the LAB image
            whose L channel was set to 0 or 255 by the threshold.
    """
    # Work in LAB so only lightness is thresholded; colour (A/B) is kept.
    image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(image_lab)

    # Adaptive thresholding on the L channel (ADAPTIVE_THRESH_MEAN_C is an
    # alternative method).
    thresholded_l = cv2.adaptiveThreshold(
        l_channel,
        maxValue=255,
        adaptiveMethod=cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        thresholdType=cv2.THRESH_BINARY,
        blockSize=block_size,
        C=c,
    )

    # Recombine with the untouched chroma channels and return to RGB.
    updated_lab = cv2.merge((thresholded_l, a_channel, b_channel))
    thresholded_rgb = cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB)
    return thresholded_rgb
def ocr_detect(image, ocr_engine, min_count=2):
    """
    Run OCR on an image and look for two consecutive rows rich in '<'.

    A row qualifies when it contains at least ``min_count`` '<' characters.
    The default of 2 keeps the original "more than one '<'" rule. Rows are
    examined in the order the OCR engine returns them, and the first adjacent
    pair in which both rows qualify is returned.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: OCR engine instance. It is called as
            ``ocr_engine(image, use_det=True, use_cls=False, use_rec=True)``
            and must return a 2-tuple. The first item is falsy when nothing
            was recognized; otherwise it is a sequence of entries whose
            index 1 is the recognized text. The second item is ignored.
        min_count (int): Minimum number of '<' characters each of the two
            rows must contain (default 2).

    Returns:
        detected (bool): True if a qualifying pair of rows was found.
        row1 (str or None): First row of the pair, or None if not found.
        row2 (str or None): Second row of the pair, or None if not found.
    """
    result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True)
    if not result:
        return False, None, None

    texts = [entry[1] for entry in result]
    # Walk adjacent pairs; zero or one recognized row yields no pairs.
    for row1, row2 in zip(texts, texts[1:]):
        if row1.count("<") >= min_count and row2.count("<") >= min_count:
            return True, row1, row2
    return False, None, None
def rotate_until_detect(image, ocr_engine, max_attempts=4):
    """
    Try OCR detection in up to ``max_attempts`` orientations.

    The image is tested as given. After each failed attempt it is turned
    90 degrees clockwise and tested again, until ocr_detect succeeds or the
    attempt budget runs out.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: OCR engine instance forwarded to ocr_detect.
        max_attempts (int): Maximum number of orientations to test.

    Returns:
        image (numpy.ndarray): The orientation that succeeded. On failure,
            the image after ``max_attempts`` clockwise turns, which is the
            original orientation for the default of 4.
        detected (bool): True if OCR detection succeeded.
        row1, row2 (str, str): The two detected rows, or None if not found.
    """
    found, first_row, second_row = False, None, None
    for _ in range(max_attempts):
        found, first_row, second_row = ocr_detect(image, ocr_engine)
        if found:
            return image, found, first_row, second_row
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    return image, found, first_row, second_row
def _to_uint8(image):
    """
    Scale an image with values nominally in [0, 1] to 8-bit.

    Values are clipped to [0, 255] after scaling so that anything slightly
    outside the range saturates instead of wrapping around in the uint8 cast.
    """
    return np.clip(image * 255.0, 0, 255).astype(np.uint8)


def _detect_on_page(img, ocr_engine):
    """
    Run rotation-based OCR detection on a page, then on its thresholded copy.

    Returns:
        (detected, row1, row2) from the first attempt that succeeds, or
        (False, None, None) if both attempts fail.
    """
    _, detected, row1, row2 = rotate_until_detect(img, ocr_engine)
    if detected:
        return detected, row1, row2
    # Fallback: retry on the adaptive-threshold version of the page.
    _, detected, row1, row2 = rotate_until_detect(
        adaptive_threshold_to_rgb(img), ocr_engine
    )
    return detected, row1, row2


def process_pdf(pdf_f, ocr_engine, enhance_params, save_images=False, *,
                dpi=300, first_page=1, last_page=3):
    """
    Render a page range of a PDF, enhance each page, and try OCR detection.

    Each page is enhanced and then searched for two consecutive '<'-rich rows,
    first directly and then after adaptive thresholding, each time with up to
    four 90-degree rotations. Processing stops at the first page that yields a
    detection.

    Parameters:
        pdf_f (str): File path of the PDF.
        ocr_engine: The OCR engine instance.
        enhance_params (dict): Parameters for image enhancement.
        save_images (bool): If True, save each enhanced page as
            '<pdf name>_<page>.jpg' in the current directory (default: False).
        dpi (int): Rendering resolution passed to pdf2image (default 300).
        first_page (int): First page to render, 1-based (default 1).
        last_page (int): Last page to render, inclusive (default 3).

    Returns:
        (pdf_success, detected_rows):
            pdf_success (bool): True if detection succeeded on any page.
            detected_rows (tuple): (row1, row2) from the successful page, or
                (None, None) if no page succeeded.
    """
    images = convert_from_path(pdf_f, dpi=dpi, first_page=first_page,
                               last_page=last_page)
    bs_name = os.path.basename(pdf_f)
    bs_name_0 = os.path.splitext(bs_name)[0]
    pdf_success = False
    detected_rows = (None, None)

    for page_no, pil_image in enumerate(images, start=first_page):
        # NOTE(review): assumes enhance_image returns floats roughly in
        # [0, 1]; the scaling below relies on that.
        img = enhance_image(np.array(pil_image), enhance_params, verbose=False)
        img = _to_uint8(img)

        if save_images:
            # cv2.imwrite expects BGR channel order.
            cv2.imwrite(f'{bs_name_0}_{page_no}.jpg',
                        cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

        detected, row1, row2 = _detect_on_page(img, ocr_engine)
        if detected:
            pdf_success = True
            detected_rows = (row1, row2)
            break
        print(f"OCR detection failed on page {page_no} of {bs_name}.")

    if pdf_success:
        print(f"PDF file {bs_name_0} processed successfully.")
    else:
        print(f"PDF file {bs_name_0} did NOT yield a successful OCR detection.")
    return pdf_success, detected_rows
def main(data_path=None):
    """
    Run the OCR detection demo on every PDF found under a folder.

    Parameters:
        data_path (str, optional): Root folder searched recursively for PDF
            files. Defaults to the sample folder hard-coded in this demo.
    """
    if data_path is None:
        # Folder name: "OCR recognition failed - partial samples".
        data_path = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本'

    # Match the extension case-insensitively so "*.PDF" files are included.
    # Sorting makes the processing order independent of filesystem order.
    list_pdf = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(data_path)
        for file in files
        if file.lower().endswith('.pdf')
    )
    if not list_pdf:
        print(f"No PDF files found under {data_path}")
        return

    # Image enhancement parameters.
    enhance_params = {
        'local_contrast': 1.2,  # 1.2x increase in detail
        'mid_tones': 0.5,  # middle range
        'tonal_width': 0.5,  # middle range
        'areas_dark': 0.7,  # 70% improvement in dark areas
        'areas_bright': 0.5,  # 50% improvement in bright areas
        'brightness': 0.1,  # slight increase in overall brightness
        'saturation_degree': 1.2,  # 1.2x increase in color saturation
        'preserve_tones': True,
        'color_correction': True,
    }

    for pdf_f in list_pdf:
        print("")
        print(f"--- Processing PDF: {pdf_f} ---")
        try:
            success, detected_rows = process_pdf(
                pdf_f, ocr_engine, enhance_params, save_images=False
            )
        except Exception as exc:
            # Batch boundary: report a PDF that cannot be rendered or processed,
            # then continue with the rest instead of aborting the whole run.
            print(f"Error while processing {pdf_f}: {exc!r}")
            continue

        if success:
            print("PDF:", os.path.basename(pdf_f))
            print("Row 1:", detected_rows[0])
            print("Row 2:", detected_rows[1])
        else:
            print("No successful detection for this PDF.")


if __name__ == '__main__':
    main()