# App / demo_v5.py
# Uploaded by LinhKL2002 via huggingface_hub
# (commit 4dbe5d1 "Upload folder using huggingface_hub", verified).
import os
import tempfile
import time

import cv2
import gradio as gr
import numpy as np
from pdf2image import convert_from_path

from main import RapidOCR
from image_enhancement import enhance_image
# Initialize the OCR engine once at import time; the same module-level instance is
# reused for every request (handle_file_upload passes it to process_pdf).
ocr_engine = RapidOCR()
def adaptive_threshold_to_rgb(image_rgb):
    """
    Binarize the lightness of an RGB image while leaving its colour channels alone.

    The image is converted to LAB, the L channel is replaced by a Gaussian adaptive
    threshold of itself (blockSize=11, C=2), and the result is converted back to RGB.

    Parameters:
        image_rgb (numpy.ndarray): Input RGB image. OpenCV's adaptiveThreshold needs an
            8-bit single-channel source, so this presumably expects uint8 input
            (TODO confirm for all callers).
    Returns:
        numpy.ndarray: RGB image whose L channel has been thresholded.
    """
    block_size, offset = 11, 2
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB))
    binary_lightness = cv2.adaptiveThreshold(
        lightness,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        block_size,
        offset,
    )
    relit_lab = cv2.merge((binary_lightness, chroma_a, chroma_b))
    return cv2.cvtColor(relit_lab, cv2.COLOR_LAB2RGB)
def ocr_detect(image, ocr_engine):
    """
    Run OCR on the image and look for two consecutive text rows that each contain at
    least two '<' characters (presumably passport MRZ-style lines — confirm with owner).

    Rows are taken in the order the OCR engine returns them; the text of each result
    entry is read from index 1.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: OCR engine instance, called as
            ``ocr_engine(image, use_det=True, use_cls=False, use_rec=True)`` and expected
            to return ``(result, elapsed)``.
    Returns:
        detected (bool): True if such a pair of rows was found, else False.
        row1 (str | None): The first row of the pair, or None.
        row2 (str | None): The second row of the pair, or None.
    """
    # Angle classification is disabled; orientation is presumably handled by the
    # caller rotating the image (see rotate_until_detect).
    result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True)
    if not result:  # covers both None and an empty result list
        return False, None, None

    min_chevrons = 2  # each row must contain at least this many '<' characters
    texts = [entry[1] for entry in result]
    for upper, lower in zip(texts, texts[1:]):
        if upper.count("<") >= min_chevrons and lower.count("<") >= min_chevrons:
            return True, upper, lower
    return False, None, None
def rotate_until_detect(image, ocr_engine, max_attempts=4):
    """
    Try OCR detection, turning the image 90° clockwise after each miss, for at most
    max_attempts tries.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: OCR engine instance forwarded to ocr_detect.
        max_attempts (int): Maximum number of detection attempts.
    Returns:
        image (numpy.ndarray): On success, the orientation that produced the match.
            On failure, the image after max_attempts quarter turns (a rotation also
            follows the last failed attempt, so the default of 4 returns the
            original orientation).
        detected (bool): True if OCR detection succeeded.
        row1, row2 (str, str): The two detected rows, or (None, None).
    """
    oriented = image
    for _ in range(max_attempts):
        found, first_row, second_row = ocr_detect(oriented, ocr_engine)
        if found:
            return oriented, True, first_row, second_row
        # Quarter turn clockwise before the next try (also runs after the final miss).
        oriented = cv2.rotate(oriented, cv2.ROTATE_90_CLOCKWISE)
    return oriented, False, None, None
def _to_uint8(image):
    """Scale a float image (presumably in [0, 1]) to uint8 [0, 255], clipping out-of-range values."""
    # Casting out-of-range floats directly to uint8 is undefined in NumPy (values may
    # wrap), so clip first. astype truncates exactly like the old np.uint8(img * 255.).
    return np.clip(image * 255.0, 0, 255).astype(np.uint8)


def _candidate_images(img):
    """Yield OCR candidates for one page: the enhanced image, then its adaptive-threshold variant."""
    yield img
    # Produced lazily so thresholding only runs when the plain image fails.
    yield adaptive_threshold_to_rgb(img)


def process_pdf(pdf_f, ocr_engine, enhance_params, *, dpi=300, max_pages=3):
    """
    Process a single PDF file by converting pages, enhancing images,
    and attempting OCR detections. A PDF is considered successful if at least one page
    yields two consecutive rows detected. Returns the (row1, row2) pair on success.

    Each page is first tried as-is (with rotations), then as an adaptive-threshold
    variant; the first success stops processing.

    Parameters:
        pdf_f (str): File path of the PDF.
        ocr_engine: The OCR engine instance.
        enhance_params (dict): Parameters for image enhancement.
        dpi (int): Rasterization resolution passed to convert_from_path (default 300).
        max_pages (int | None): Only pages 1..max_pages are examined (default 3);
            passed to convert_from_path as last_page.
    Returns:
        (pdf_success, detected_rows):
        pdf_success (bool): True if detection succeeded in any page.
        detected_rows (tuple): (row1, row2) from the successful page, or (None, None) if not.
    """
    images = convert_from_path(pdf_f, dpi=dpi, first_page=1, last_page=max_pages)
    for pil_image in images:
        # enhance_image presumably returns floats in [0, 1] (hence the *255 scaling).
        enhanced = enhance_image(np.array(pil_image), enhance_params, verbose=False)
        page = _to_uint8(enhanced)
        for candidate in _candidate_images(page):
            _, detected, row1, row2 = rotate_until_detect(candidate, ocr_engine)
            if detected:
                return True, (row1, row2)
    return False, (None, None)
# def main():
# # Define the folder containing PDFs.
# # dataPath = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本'
# dataPath = 'C:/Users/Duy/Downloads/passport/'
# result_file = os.path.join(dataPath,'results.txt')
# list_pdf = [
# os.path.join(root, file)
# for root, _, files in os.walk(dataPath)
# for file in files if file.endswith('.pdf')
# ]
# enhance_params = {
# 'local_contrast': 1.2, 'mid_tones': 0.5, 'tonal_width': 0.5, 'areas_dark': 0.7,
# 'areas_bright': 0.5, 'brightness': 0.1, 'saturation_degree': 1.2,
# 'preserve_tones': True, 'color_correction': True,
# }
# # Open the result file for writing
# with open(result_file, 'w') as f:
# for pdf_f in list_pdf:
# pdf_name = os.path.basename(pdf_f)
# print(f"Processing {pdf_f}...")
# success, detected_rows = process_pdf(pdf_f, ocr_engine, enhance_params)
# if success:
# f.write(f"--- PDF: {pdf_name} ---\n")
# f.write("Success\n")
# f.write(f"Row 1: {detected_rows[0]}\n")
# f.write(f"Row 2: {detected_rows[1]}\n\n")
# print(f"Success: {pdf_name}")
# print("Row 1:", detected_rows[0])
# print("Row 2:", detected_rows[1])
# else:
# f.write(f"--- PDF: {pdf_name} ---\n")
# f.write("No successful detection\n\n")
# print(f"No detection: {pdf_name}")
# print(f"Results written to {result_file}")
def handle_file_upload(file_bytes):
    """
    Gradio callback: stage the uploaded PDF bytes on disk and extract the two detected rows.

    Parameters:
        file_bytes (bytes | None): Raw PDF contents from ``gr.File(type="binary")``;
            presumably None when the user submits without selecting a file (TODO confirm).
    Returns:
        tuple: (row1, row2) on success, otherwise ("Error", "Error").
    """
    if not file_bytes:
        # Nothing uploaded (or an empty file): report failure instead of crashing.
        return ("Error", "Error")

    enhance_params = {
        'local_contrast': 1.2, 'mid_tones': 0.5, 'tonal_width': 0.5, 'areas_dark': 0.7,
        'areas_bright': 0.5, 'brightness': 0.1, 'saturation_degree': 1.2,
        'preserve_tones': True, 'color_correction': True,
    }

    # Stage uploads in a "tmp" directory next to this script (created if missing).
    current_dir = os.path.dirname(os.path.abspath(__file__))
    tmp_dir = os.path.join(current_dir, "tmp")
    os.makedirs(tmp_dir, exist_ok=True)

    # mkstemp guarantees a unique filename; the previous second-resolution timestamp
    # let two uploads in the same second overwrite each other's file.
    fd, save_path = tempfile.mkstemp(prefix="uploaded_", suffix=".pdf", dir=tmp_dir)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(file_bytes)
        pdf_success, detected_rows = process_pdf(save_path, ocr_engine, enhance_params)
    finally:
        # Delete the staged upload (it may contain personal documents) so uploads do
        # not accumulate on disk, whether or not processing succeeded.
        os.remove(save_path)
    return detected_rows if pdf_success else ("Error", "Error")
if __name__ == '__main__':
    # One PDF upload in; the two detected rows out as separate text boxes.
    pdf_upload = gr.File(type="binary", file_types=[".pdf"], label="Select your PDF")
    row_boxes = [gr.Textbox(label="Row 1"), gr.Textbox(label="Row 2")]
    demo = gr.Interface(
        fn=handle_file_upload,
        inputs=pdf_upload,
        outputs=row_boxes,
        title="PDF Information Extractor",
        description="Upload a PDF file to get basic information.",
        allow_flagging="never",
    )
    # share=True requests a public Gradio share link in addition to the local server.
    demo.launch(share=True)