import os
import tempfile
import time

import cv2
import gradio as gr
import numpy as np
from pdf2image import convert_from_path

from main import RapidOCR
from image_enhancement import enhance_image

# Initialize OCR engine once.
ocr_engine = RapidOCR()

# Enhancement settings handed to enhance_image() for every uploaded PDF.
DEFAULT_ENHANCE_PARAMS = {
    'local_contrast': 1.2,
    'mid_tones': 0.5,
    'tonal_width': 0.5,
    'areas_dark': 0.7,
    'areas_bright': 0.5,
    'brightness': 0.1,
    'saturation_degree': 1.2,
    'preserve_tones': True,
    'color_correction': True,
}


def adaptive_threshold_to_rgb(image_rgb):
    """
    Binarize the lightness of an RGB image while keeping its chroma channels.

    The image is converted to LAB, Gaussian adaptive thresholding
    (block size 11, constant 2) is applied to the L channel only, and the
    result is converted back to RGB.

    Parameters:
        image_rgb (numpy.ndarray): Input RGB image.

    Returns:
        thresholded_rgb (numpy.ndarray): RGB image after thresholding the L channel.
    """
    image_lab = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2LAB)
    l_channel, a_channel, b_channel = cv2.split(image_lab)
    thresholded_l = cv2.adaptiveThreshold(
        l_channel,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11,
        2,
    )
    updated_lab = cv2.merge((thresholded_l, a_channel, b_channel))
    return cv2.cvtColor(updated_lab, cv2.COLOR_LAB2RGB)


def ocr_detect(image, ocr_engine):
    """
    Run OCR on the image and look for two consecutive rows that each contain
    more than one '<' character.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: OCR engine instance.

    Returns:
        detected (bool): True if such a pair of rows was found, else False.
        row1 (str): The first row of the pair, or None.
        row2 (str): The second row of the pair, or None.
    """
    result, _ = ocr_engine(image, use_det=True, use_cls=False, use_rec=True)
    if not result:
        return False, None, None

    # Index 1 of every OCR entry is used as the recognized text.
    rows = [entry[1] for entry in result]
    for current_row, next_row in zip(rows, rows[1:]):
        if current_row.count("<") > 1 and next_row.count("<") > 1:
            return True, current_row, next_row
    return False, None, None


def rotate_until_detect(image, ocr_engine, max_attempts=4):
    """
    Try OCR detection on the image, rotating it 90° clockwise after every
    failed attempt, for at most max_attempts attempts.

    Parameters:
        image (numpy.ndarray): Input image.
        ocr_engine: OCR engine instance.
        max_attempts (int): Maximum number of detection attempts.

    Returns:
        image (numpy.ndarray): The image in the orientation that succeeded. On
            failure, the image after max_attempts rotations (the original
            orientation again for the default of 4).
        detected (bool): True if OCR detection succeeded.
        row1, row2 (str, str): The two detected rows (if found; otherwise None).
    """
    for _ in range(max_attempts):
        detected, row1, row2 = ocr_detect(image, ocr_engine)
        if detected:
            return image, True, row1, row2
        image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
    return image, False, None, None


def process_pdf(pdf_f, ocr_engine, enhance_params, *, dpi=300, first_page=1, last_page=3):
    """
    Process a single PDF file by converting pages, enhancing images, and
    attempting OCR detections.

    Every converted page is enhanced and tried in up to four orientations. If
    that fails, an adaptively thresholded copy of the page is tried the same
    way. Processing stops at the first page that yields a detection.

    Parameters:
        pdf_f (str): File path of the PDF.
        ocr_engine: The OCR engine instance.
        enhance_params (dict): Parameters for image enhancement.
        dpi (int): Rendering resolution passed to convert_from_path.
        first_page (int): First page (1-based) to convert.
        last_page (int): Last page (1-based) to convert.

    Returns:
        (pdf_success, detected_rows):
            pdf_success (bool): True if detection succeeded in any page.
            detected_rows (tuple): (row1, row2) from the successful page,
                or (None, None) if not.
    """
    images = convert_from_path(
        pdf_f, dpi=dpi, first_page=first_page, last_page=last_page
    )

    for pil_image in images:
        img = np.array(pil_image)
        img = enhance_image(img, enhance_params, verbose=False)
        # NOTE(review): enhance_image presumably returns floats in [0, 1] - confirm.
        # Clip before the cast so slightly out-of-range values saturate instead of
        # wrapping around in uint8.
        img = np.clip(img * 255.0, 0, 255).astype(np.uint8)

        _, detected, row1, row2 = rotate_until_detect(img, ocr_engine)
        if not detected:
            # Second chance: retry on a thresholded copy of the same page.
            adaptive_img = adaptive_threshold_to_rgb(img)
            _, detected, row1, row2 = rotate_until_detect(adaptive_img, ocr_engine)

        if detected:
            return True, (row1, row2)

    return False, (None, None)


# Legacy batch driver, kept for reference only (not executed).
# def main():
#     # Define the folder containing PDFs.
#     # dataPath = '/home/tung/Tung_Works/OCR_code/OCR-20250423T073748Z-001/OCR/OCR辨識失敗-部分樣本'
#     dataPath = 'C:/Users/Duy/Downloads/passport/'
#     result_file = os.path.join(dataPath,'results.txt')
#     list_pdf = [
#         os.path.join(root, file)
#         for root, _, files in os.walk(dataPath)
#         for file in files if file.endswith('.pdf')
#     ]
#     enhance_params = {
#         'local_contrast': 1.2, 'mid_tones': 0.5, 'tonal_width': 0.5, 'areas_dark': 0.7,
#         'areas_bright': 0.5, 'brightness': 0.1, 'saturation_degree': 1.2,
#         'preserve_tones': True, 'color_correction': True,
#     }
#     # Open the result file for writing
#     with open(result_file, 'w') as f:
#         for pdf_f in list_pdf:
#             pdf_name = os.path.basename(pdf_f)
#             print(f"Processing {pdf_f}...")
#             success, detected_rows = process_pdf(pdf_f, ocr_engine, enhance_params)
#             if success:
#                 f.write(f"--- PDF: {pdf_name} ---\n")
#                 f.write("Success\n")
#                 f.write(f"Row 1: {detected_rows[0]}\n")
#                 f.write(f"Row 2: {detected_rows[1]}\n\n")
#                 print(f"Success: {pdf_name}")
#                 print("Row 1:", detected_rows[0])
#                 print("Row 2:", detected_rows[1])
#             else:
#                 f.write(f"--- PDF: {pdf_name} ---\n")
#                 f.write("No successful detection\n\n")
#                 print(f"No detection: {pdf_name}")
#     print(f"Results written to {result_file}")


def handle_file_upload(file_bytes):
    """
    Gradio callback: run the OCR pipeline on an uploaded PDF.

    The upload is written to a uniquely named file in a "tmp" directory next
    to this script, because process_pdf works on a file path. The file is
    removed again once processing has finished, whether or not it succeeded.

    Parameters:
        file_bytes (bytes): Raw content of the uploaded PDF. None or empty
            input (no file selected) is reported as an error.

    Returns:
        (row1, row2): The two detected rows, or ("Error", "Error") when
            nothing was detected or no file was provided.
    """
    if not file_bytes:
        return ("Error", "Error")

    current_dir = os.path.dirname(os.path.abspath(__file__))

    # Create the tmp directory if it does not exist yet.
    tmp_dir = os.path.join(current_dir, "tmp")
    os.makedirs(tmp_dir, exist_ok=True)

    # A unique name avoids collisions between uploads arriving in the same second.
    # delete=False so the file can be reopened by path after it is closed.
    tmp_file = tempfile.NamedTemporaryFile(
        dir=tmp_dir, prefix="uploaded_", suffix=".pdf", delete=False
    )
    save_path = tmp_file.name
    try:
        # Save the binary payload as a PDF file.
        with tmp_file:
            tmp_file.write(file_bytes)

        pdf_success, detected_rows = process_pdf(
            save_path, ocr_engine, dict(DEFAULT_ENHANCE_PARAMS)
        )
    finally:
        # Best-effort cleanup: a failed delete must not mask the OCR result
        # or the original exception.
        try:
            os.remove(save_path)
        except OSError:
            pass

    return detected_rows if pdf_success else ("Error", "Error")


if __name__ == '__main__':
    demo = gr.Interface(
        fn=handle_file_upload,
        inputs=gr.File(type="binary", file_types=[".pdf"], label="Select your PDF"),
        outputs=[
            gr.Textbox(label="Row 1"),
            gr.Textbox(label="Row 2"),
        ],
        title="PDF Information Extractor",
        description="Upload a PDF file to get basic information.",
        allow_flagging="never"
    )
    # NOTE(review): share=True requests a public share link for the app -
    # confirm this exposure is intended for documents handled here.
    demo.launch(share=True)