# App / demo_v2.py
# Uploaded by LinhKL2002 via huggingface_hub ("Upload folder using huggingface_hub", commit 4dbe5d1, verified).
import os
import cv2
import numpy as np
from pdf2image import convert_from_path
from main import RapidOCR
# OCR engine shared by every page processed below.
ocr_engine = RapidOCR()
# Root directory scanned recursively for input PDFs. It can be overridden with
# the OCR_DATA_PATH environment variable so the script is not tied to a single
# machine; the original location remains the default. (The last folder name,
# "OCR辨識失敗-部分樣本", means "OCR recognition failed - partial samples".)
dataPath = os.environ.get(
    'OCR_DATA_PATH',
    '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本',
)
from image_enhancement import enhance_image
def _collect_pdfs(root_dir):
    """Return the sorted paths of all PDF files found recursively under *root_dir*.

    The extension check is case-insensitive, so ``scan.PDF`` is included.
    Entries that ``os.walk`` reports but that are not readable regular files
    (e.g. broken symlinks) are skipped rather than aborting the scan.

    Parameters:
        root_dir (str): Directory to walk.

    Returns:
        list[str]: Full paths, sorted lexicographically for a stable order.
    """
    found = []
    for root, _dirs, files in os.walk(root_dir):
        for name in files:
            if not name.lower().endswith('.pdf'):
                continue
            path = os.path.join(root, name)
            # os.walk lists broken symlinks under ``files``; skip them instead of
            # crashing (the old ``assert`` also vanished under ``python -O``).
            if os.path.isfile(path):
                found.append(path)
    # sorted() returns a new list; the previous code discarded its result, so
    # the processing order silently depended on filesystem order.
    return sorted(found)


list_pdf = _collect_pdfs(dataPath)
def adaptive_threshold_to_rgb(image_rgb, block_size=11, c=2, adaptive_method=None):
    """Apply adaptive thresholding to the lightness of an RGB image.

    The image is converted to LAB and adaptive thresholding is applied to the
    L channel only. The thresholded L channel is then merged with the
    original A and B channels, and the result is converted back to RGB.

    Parameters:
        image_rgb (numpy.ndarray): Input RGB image. It should be 8-bit
            (``uint8``): ``cv2.adaptiveThreshold`` only accepts an 8-bit
            single-channel source, which is what an 8-bit LAB conversion yields
            for the L channel.
        block_size (int): Size of the pixel neighbourhood used to compute the
            local threshold. Must be odd and greater than 1. Defaults to 11.
        c (float): Constant subtracted from the (weighted) local mean.
            Defaults to 2.
        adaptive_method (int | None): ``cv2.ADAPTIVE_THRESH_GAUSSIAN_C`` or
            ``cv2.ADAPTIVE_THRESH_MEAN_C``. ``None`` selects the Gaussian method.

    Returns:
        numpy.ndarray: RGB image rebuilt from the thresholded L channel and the
        original A/B channels.

    Raises:
        ValueError: If *block_size* is not an odd value greater than 1.
    """
    # Fail early with a readable message instead of an OpenCV assertion.
    if block_size < 3 or block_size % 2 == 0:
        raise ValueError(f'block_size must be an odd integer > 1, got {block_size!r}')
    if adaptive_method is None:
        adaptive_method = cv2.ADAPTIVE_THRESH_GAUSSIAN_C

    # Threshold only lightness; A and B carry the colour and are kept as-is.
    image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(image_lab)
    thresholded_l = cv2.adaptiveThreshold(
        l_channel,
        maxValue=255,
        adaptiveMethod=adaptive_method,
        thresholdType=cv2.THRESH_BINARY,
        blockSize=block_size,
        C=c,
    )
    updated_lab = cv2.merge((thresholded_l, a_channel, b_channel))
    return cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB)
# Settings handed to enhance_image for every page. A fresh copy is passed on
# each call because it is not visible from here whether enhance_image mutates
# its parameter dict.
_ENHANCE_PARAMETERS = {
    'local_contrast': 1.2,  # 1.2x increase in details
    'mid_tones': 0.5,  # middle of range
    'tonal_width': 0.5,  # middle of range
    'areas_dark': 0.7,  # 70% improvement in dark areas
    'areas_bright': 0.5,  # 50% improvement in bright areas
    'brightness': 0.1,  # slight increase in overall brightness
    'saturation_degree': 1.2,  # 1.2x increase in color saturation
    'preserve_tones': True,
    'color_correction': True,
}


def _find_chevron_rows(texts):
    """Find the first two adjacent text lines that each contain '<' at least twice.

    Parameters:
        texts (list[str]): Recognised text lines, in OCR output order.

    Returns:
        tuple[int, int, int] | None: ``(j, count_j, count_j_plus_1)`` for the
        first matching pair ``texts[j]``, ``texts[j + 1]``, or ``None``.
    """
    # NOTE(review): '<'-filled consecutive lines look like an MRZ-style
    # pattern (e.g. passports) - presumably the intent; confirm.
    counts = [text.count('<') for text in texts]
    for j in range(len(counts) - 1):
        if counts[j] > 1 and counts[j + 1] > 1:
            return j, counts[j], counts[j + 1]
    return None


def _detect_with_rotations(image, name):
    """OCR *image* at 0, 90, 180 and 270 degrees clockwise until a chevron row pair is found.

    On the first match, the matching rows are printed (preceded by *name*) and
    the function stops.

    Parameters:
        image (numpy.ndarray): Image passed to ``ocr_engine``.
        name (str): Label printed before the matching rows.

    Returns:
        bool: True if one of the four orientations matched, else False.
    """
    candidate = image
    for attempt in range(4):
        if attempt:
            candidate = cv2.rotate(candidate, cv2.ROTATE_90_CLOCKWISE)
        result, _elapse = ocr_engine(candidate, use_det=True, use_cls=False, use_rec=True)
        if not result:
            continue
        texts = [r[1] for r in result]
        hit = _find_chevron_rows(texts)
        if hit is not None:
            j, count1, count2 = hit
            print(name)
            # The condition is "> 1", i.e. at least 2 occurrences per row.
            print("Consecutive rows with '<' at least 2 times each:")
            print(f"Row 1: {texts[j]} (Occurrences: {count1})")
            print(f"Row 2: {texts[j + 1]} (Occurrences: {count2})")
            return True
    return False


for pdf_f in list_pdf:
    bs_name_0 = os.path.splitext(os.path.basename(pdf_f))[0]
    # Only the first three pages are rasterised, at 300 dpi.
    images = convert_from_path(pdf_f, dpi=300, first_page=1, last_page=3)
    for i, image in enumerate(images):
        img = np.array(image)
        print(img.shape)
        img = enhance_image(img, dict(_ENHANCE_PARAMETERS), verbose=False)
        # NOTE(review): assumes enhance_image returns floats nominally in
        # [0, 1]. Clip before casting so out-of-range values saturate instead
        # of wrapping around (np.uint8 on e.g. 257.0 or -1.0 does not clamp).
        img = np.clip(img * 255., 0, 255).astype(np.uint8)
        enhanced_img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # OpenCV writes BGR
        out_name = f'{bs_name_0}_{i + 1}.jpg'
        if not cv2.imwrite(out_name, enhanced_img_bgr):
            print(f'Failed to write {out_name}')
        print(bs_name_0, i)
        # First try the enhanced page; if no orientation matches, retry on a
        # binarised copy of the same page.
        if not _detect_with_rotations(img, bs_name_0):
            _detect_with_rotations(adaptive_threshold_to_rgb(img), bs_name_0)