Spaces:
Running
Running
| import fitz # PyMuPDF | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torch.serialization | |
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to ``False``.

    Whatever ``weights_only`` value the caller passes (or omits) is replaced,
    restoring the classic behaviour of unpickling complete checkpoint objects.

    SECURITY: with ``weights_only=False`` the loader unpickles arbitrary
    objects, which can execute code. Only load checkpoints from trusted sources.
    """
    forced_kwargs = dict(kwargs, weights_only=False)
    return _original_torch_load(*args, **forced_kwargs)


# Install the wrapper globally so later calls that resolve ``torch.load`` at
# call time (including calls inside third-party libraries) go through it.
torch.load = patched_torch_load
| import json | |
| import argparse | |
| import os | |
| import re | |
| import torch.nn as nn | |
| from TorchCRF import CRF | |
| # from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config | |
| from transformers import LayoutLMv3Tokenizer, LayoutLMv3Model, LayoutLMv3Config | |
| from typing import List, Dict, Any, Optional, Union, Tuple | |
| from ultralytics import YOLO | |
| import glob | |
| import pytesseract | |
| from PIL import Image | |
| from scipy.signal import find_peaks | |
| from scipy.ndimage import gaussian_filter1d | |
| import sys | |
| import io | |
| import base64 | |
| import tempfile | |
| import time | |
| import shutil | |
| from sklearn.feature_extraction.text import CountVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import logging | |
| from transformers import TrOCRProcessor | |
| from optimum.onnxruntime import ORTModelForVision2Seq | |
| # ============================================================================ | |
| # --- TR-OCR/ORT MODEL INITIALIZATION --- | |
| # ============================================================================ | |
logging.basicConfig(level=logging.WARNING)

# Hugging Face model id of the formula-recognition (image -> LaTeX) model.
MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'


def _load_equation_model():
    """Load the TrOCR processor and the ONNX Runtime seq2seq model.

    Returns:
        ``(processor, model)`` when both load, otherwise ``(None, None)``, so
        that a half-initialised pair is never exposed to callers.
    """
    try:
        loaded_processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
        # use_cache=False avoids caching warnings/issues if the model is reloaded.
        loaded_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
        print("β ORTModelForVision2Seq and TrOCRProcessor initialized successfully for equation conversion.")
        return loaded_processor, loaded_model
    except Exception as e:
        print(f"β Error initializing TrOCR/ORT model. Equations will not be converted: {e}")
        return None, None


# Module-level handles read by get_latex_from_base64(); both are None on failure.
processor, ort_model = _load_equation_model()
| from typing import Optional | |
# Lone surrogates (U+D800-U+DFFF) and the non-characters U+FFFE / U+FFFF.
# These cannot be encoded as UTF-8. Compiled once rather than on every call.
_INVALID_CODEPOINTS_RE = re.compile(r'[\ud800-\udfff\ufffe\uffff]')


def sanitize_text(text: Optional[str]) -> str:
    """Replace code points that break UTF-8 encoding with spaces.

    Each lone surrogate (e.g. ``'\\udefd'``) and each U+FFFE / U+FFFF
    non-character is replaced by one space. The result therefore has the same
    length as the input. Callers are expected to strip it afterwards.

    Args:
        text: String to clean. ``None`` or any non-``str`` value yields ``""``.

    Returns:
        The sanitized string.
    """
    # isinstance() already rejects None, so no separate None check is needed.
    if not isinstance(text, str):
        return ""
    return _INVALID_CODEPOINTS_RE.sub(' ', text)
def get_latex_from_base64(base64_string: str) -> str:
    """Recognize a formula image with the TrOCR/ORT model and return LaTeX.

    The function decodes *base64_string* into an image and converts it to RGB.
    It runs the module-level ``processor`` / ``ort_model`` pair and decodes
    the first generated sequence.

    Runs of CR/LF in the model output are replaced by a single space, and
    the result is stripped. LaTeX treats a newline as whitespace. Deleting
    it outright (as this function previously did) could glue a control word
    to the following letters: ``\\alpha``, newline, ``x`` would become
    ``\\alphax``. No other normalisation is done. Interior spaces and
    backslashes are returned exactly as the model produced them.

    Args:
        base64_string: Base64-encoded image bytes in any format PIL can open.

    Returns:
        The LaTeX string. Instead of raising, it returns a bracketed marker
        string in these cases:

        - ``[MODEL_ERROR: ...]`` when the model failed to initialise.
        - ``[OCR_WARNING: ...]`` when decoding produced no sequence.
        - ``[TR_OCR_ERROR: ...]`` for any exception during decoding or inference.
    """
    if ort_model is None or processor is None:
        return "[MODEL_ERROR: Model not initialized]"
    try:
        image_bytes = base64.b64decode(base64_string)
        # The model expects 3-channel input; normalise palette/alpha/grayscale to RGB.
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        generated_ids = ort_model.generate(pixel_values)
        decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
        if not decoded:
            return "[OCR_WARNING: No formula found]"
        # Newline -> space (not removal) keeps LaTeX tokens separated.
        return re.sub(r'[\r\n]+', ' ', decoded[0]).strip()
    except Exception as e:
        # Best-effort: report the failure inline rather than aborting the page.
        print(f" β TR-OCR Recognition failed: {e}")
        return f"[TR_OCR_ERROR: Recognition failed: {e}]"
| # ============================================================================ | |
| # --- CONFIGURATION AND CONSTANTS --- | |
| # ============================================================================ | |
| # NOTE: Update these paths to match your environment before running! | |
# --- Model weights -----------------------------------------------------------
WEIGHTS_PATH = 'best.pt'                      # YOLO detector checkpoint
DEFAULT_LAYOUTLMV3_MODEL_PATH = "98.pth"      # LayoutLMv3 tagger checkpoint

# --- Output / scratch directories -------------------------------------------
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'
FIGURE_EXTRACTION_DIR = './figure_extraction'  # cropped figure/equation PNGs
TEMP_IMAGE_DIR = './temp_pdf_images'

# --- Detection parameters -----------------------------------------------------
CONF_THRESHOLD = 0.2                          # minimum YOLO confidence requested
TARGET_CLASSES = ['figure', 'equation']       # YOLO class names that get extracted
IOU_MERGE_THRESHOLD = 0.4                     # same-class boxes above this IoU are merged
IOA_SUPPRESSION_THRESHOLD = 0.7
LINE_TOLERANCE = 15

# --- Similarity ---------------------------------------------------------------
SIMILARITY_THRESHOLD = 0.10
RESOLUTION_MARGIN = 0.05

# --- Running counters ---------------------------------------------------------
# Incremented per detected component so FIGURE<n> / EQUATION<n> numbering is
# sequential across the whole PDF rather than per page.
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# --- LayoutLMv3 BIO tag set ----------------------------------------------------
# id -> label. Presumably must match the label order used when the checkpoint
# was trained. TODO(review): confirm against the training config.
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION",
    2: "I-QUESTION",
    3: "B-OPTION",
    4: "I-OPTION",
    5: "B-ANSWER",
    6: "I-ANSWER",
    7: "B-SECTION_HEADING",
    8: "I-SECTION_HEADING",
    9: "B-PASSAGE",
    10: "I-PASSAGE",
}
NUM_LABELS = len(ID_TO_LABEL)
| # ============================================================================ | |
| # --- PERFORMANCE OPTIMIZATION: OCR CACHE --- | |
| # ============================================================================ | |
class OCRCache:
    """In-memory store of per-page OCR word lists.

    Entries are keyed by the string ``"<pdf_path>:<page_num>"``. The cache
    lets the pipeline skip re-running Tesseract on a page it has already
    processed. Entries stay until :meth:`clear` is called; there is no size
    limit.
    """

    def __init__(self):
        # Maps "<pdf_path>:<page_num>" -> list of word tuples.
        self.cache = {}

    def get_key(self, pdf_path: str, page_num: int) -> str:
        """Build the dictionary key for one page of one PDF."""
        return "{}:{}".format(pdf_path, page_num)

    def has_ocr(self, pdf_path: str, page_num: int) -> bool:
        """Return True if OCR output is stored for this page."""
        key = self.get_key(pdf_path, page_num)
        return key in self.cache

    def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]:
        """Return the stored word list for this page, or None if absent."""
        key = self.get_key(pdf_path, page_num)
        return self.cache.get(key)

    def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list):
        """Store (or overwrite) the word list for this page."""
        key = self.get_key(pdf_path, page_num)
        self.cache[key] = ocr_data

    def clear(self):
        """Drop every cached entry."""
        self.cache.clear()


# Module-level instance consulted by get_word_data_for_detection().
_ocr_cache = OCRCache()
| # ============================================================================ | |
| # --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS --- | |
| # ============================================================================ | |
def calculate_iou(box1, box2):
    """Return the Intersection-over-Union of two ``(x1, y1, x2, y2)`` boxes.

    Returns 0 when the union area is not positive (e.g. two degenerate boxes).
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h
    area_1 = (ax2 - ax1) * (ay2 - ay1)
    area_2 = (bx2 - bx1) * (by2 - by1)
    union = float(area_1 + area_2 - inter)
    if union > 0:
        return inter / union
    return 0
def calculate_ioa(box1, box2):
    """Return the intersection area divided by the area of *box1*.

    Boxes are ``(x1, y1, x2, y2)``. The result measures how much of the first
    box is covered by the second. It is 0 when box1 has no positive area.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    inter = max(0, min(ax2, bx2) - max(ax1, bx1)) * max(0, min(ay2, by2) - max(ay1, by1))
    own_area = (ax2 - ax1) * (ay2 - ay1)
    if own_area > 0:
        return inter / own_area
    return 0
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Drop detections that lie (almost) entirely inside a larger detection.

    Boxes are visited from largest to smallest area. A box is suppressed when
    more than ``ioa_threshold`` of its own area is covered by an earlier
    (larger or equal-area) box that was itself kept. Containers therefore
    win over objects nested inside them, regardless of class.

    Args:
        detections: Dicts with a ``'coords'`` entry ``(x1, y1, x2, y2)``.
            An ``'area'`` key is written into every dict. The list itself is
            not reordered; it used to be sorted in place.
        ioa_threshold: Fraction of the smaller box's area that must overlap
            a kept box for the smaller one to be removed.

    Returns:
        The kept detections, ordered by area, largest first.
    """
    if not detections:
        return []

    for det in detections:
        x1, y1, x2, y2 = det['coords']
        det['area'] = (x2 - x1) * (y2 - y1)

    # Largest first so each container is processed before its contents.
    # sorted() is stable, so equal-area boxes keep their input order.
    ordered = sorted(detections, key=lambda d: d['area'], reverse=True)
    suppressed = [False] * len(ordered)
    kept = []
    for i, big in enumerate(ordered):
        if suppressed[i]:
            continue
        kept.append(big)
        bx1, by1, bx2, by2 = big['coords']
        for j in range(i + 1, len(ordered)):
            if suppressed[j]:
                continue
            small = ordered[j]
            sx1, sy1, sx2, sy2 = small['coords']
            overlap_w = max(0, min(bx2, sx2) - max(bx1, sx1))
            overlap_h = max(0, min(by2, sy2) - max(by1, sy1))
            small_area = small['area']
            # IoA relative to the SMALLER box: 1.0 means fully contained.
            if small_area > 0 and (overlap_w * overlap_h) / small_area > ioa_threshold:
                suppressed[j] = True
    return kept
def merge_overlapping_boxes(detections, iou_threshold):
    """Merge same-class detections whose IoU exceeds *iou_threshold*.

    Detections are visited in descending confidence, and each detection not
    yet merged seeds a group. A later detection of the same class is absorbed
    when its IoU with the seed box (not the growing union) is strictly above
    ``iou_threshold``. The group's box becomes the union bounding box.

    Args:
        detections: Dicts with ``'coords'`` ``(x1, y1, x2, y2)``, ``'class'``
            and ``'conf'``. The list is not modified; it used to be sorted
            in place.
        iou_threshold: Strict IoU lower bound for merging.

    Returns:
        New dicts ``{'coords', 'y1', 'class', 'conf'}``, one per group, in
        descending confidence order. ``'conf'`` is the seed's confidence.
    """
    if not detections:
        return []
    ordered = sorted(detections, key=lambda d: d['conf'], reverse=True)
    absorbed = [False] * len(ordered)
    merged = []
    for i, seed in enumerate(ordered):
        if absorbed[i]:
            continue
        seed_box = seed['coords']
        seed_class = seed['class']
        mx1, my1, mx2, my2 = seed_box
        for j in range(i + 1, len(ordered)):
            other = ordered[j]
            if absorbed[j] or other['class'] != seed_class:
                continue
            other_box = other['coords']
            if calculate_iou(seed_box, other_box) > iou_threshold:
                mx1 = min(mx1, other_box[0])
                my1 = min(my1, other_box[1])
                mx2 = max(mx2, other_box[2])
                my2 = max(my2, other_box[3])
                absorbed[j] = True
        merged.append({
            'coords': (mx1, my1, mx2, my2),
            'y1': my1, 'class': seed_class, 'conf': seed['conf'],
        })
    return merged
def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list:
    """Mask words covered by YOLO detections and add one solid block per detection.

    Each detection is converted from image pixels to PDF points by dividing by
    *scale_factor*. Words whose centre lies inside (edges inclusive) any
    converted box are dropped. The boxes are then appended as pseudo-words so
    the column detector sees them as opaque.

    Args:
        raw_word_data: ``(text, x1, y1, x2, y2)`` tuples in PDF points.
        yolo_detections: Dicts whose ``'coords'`` are ``(x1, y1, x2, y2)``
            in rendered-image pixels.
        scale_factor: Pixels per PDF point of the rendered image.

    Returns:
        The surviving word tuples followed by ``("BLOCK_<i>", x1, y1, x2, y2)``
        entries, one per detection. With no detections, the input list
        object itself is returned.
    """
    if not yolo_detections:
        return raw_word_data

    blocks = []
    for det in yolo_detections:
        bx1, by1, bx2, by2 = det['coords']
        blocks.append((bx1 / scale_factor, by1 / scale_factor,
                       bx2 / scale_factor, by2 / scale_factor))

    def _centre_is_covered(word):
        cx = (word[1] + word[3]) / 2
        cy = (word[2] + word[4]) / 2
        return any(x1 <= cx <= x2 and y1 <= cy <= y2 for x1, y1, x2, y2 in blocks)

    result = [word for word in raw_word_data if not _centre_is_covered(word)]
    result.extend((f"BLOCK_{idx}", *box) for idx, box in enumerate(blocks))
    return result
| # ============================================================================ | |
# --- OCR PREPROCESSING, COLUMN DETECTION AND WORD EXTRACTION HELPERS ---
| # ============================================================================ | |
def preprocess_image_for_ocr(img_np):
    """Binarize an image for Tesseract using Otsu's automatic threshold.

    A 3-dimensional array is treated as BGR and converted to grayscale first.
    Any other array is assumed to be grayscale already. Returns the
    single-channel 0/255 image produced by ``cv2.threshold``.
    """
    gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) if img_np.ndim == 3 else img_np
    # With THRESH_OTSU set, the supplied threshold (0) is ignored and Otsu
    # picks one from the histogram. Text ends up solid and background clear.
    _otsu_level, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    return binary
def calculate_vertical_gap_coverage(word_data: list, sep_x: int, page_height: float, gutter_width: int = 10) -> float:
    """Return the fraction of the text's vertical extent left open by a gutter at *sep_x*.

    How it works:

    - A gutter of width *gutter_width* centred on *sep_x* is tested against
      every word box ``(text, x1, y1, x2, y2)``.
    - The span from the smallest to the largest y value is sampled at one
      slot per coordinate unit.
    - Each word that horizontally overlaps the gutter closes the slots it
      covers.

    The result is ``open_slots / total_slots`` in [0, 1]. Callers compare it
    against their own acceptance threshold. ``page_height`` is accepted but
    not used.

    Returns 0.0 for empty input or a text span of zero or negative height.
    """
    if not word_data:
        return 0.0
    all_y = [w[2] for w in word_data] + [w[4] for w in word_data]
    span_top = min(all_y)
    span_bottom = max(all_y)
    text_height = span_bottom - span_top
    if text_height <= 0:
        return 0.0

    half_gutter = gutter_width / 2
    left_edge = sep_x - half_gutter
    right_edge = sep_x + half_gutter

    slots_open = np.ones(int(text_height) + 1, dtype=bool)
    n_slots = len(slots_open)
    base = int(span_top)
    for _, x1, y1, x2, y2 in word_data:
        if not (x2 > left_edge and x1 < right_edge):
            continue  # word does not touch the gutter
        start = max(0, int(y1) - base)
        stop = min(n_slots, int(y2) - base)
        if stop > start:
            slots_open[start:stop] = False
    return np.sum(slots_open) / n_slots
def calculate_x_gutters(word_data: list, params: Dict, page_height: float) -> List[int]:
    """Find column-separator x positions on a page.

    Candidates come from an x-histogram of word left/right edges. The
    histogram is smoothed and inverted so that sparse x ranges (gutters)
    become peaks. Each candidate must then pass two checks:

    1. Bridging density: the summed height of words that straddle the
       candidate line, divided by the total text height, must not exceed
       ``max_bridging_ratio``. Short crossers such as headers may cross;
       paragraphs may not.
    2. Vertical gap coverage: ``calculate_vertical_gap_coverage`` with a
       gutter of ``cluster_min_width`` must be at least ``min_gap_coverage``.

    Args:
        word_data: ``(text, x1, y1, x2, y2)`` tuples.
        params: Optional tuning knobs (defaults in parentheses):
            ``cluster_bin_size`` (5), ``cluster_smoothing`` (1),
            ``cluster_min_width`` (20), ``cluster_threshold_percentile`` (85),
            ``max_bridging_ratio`` (0.08), ``min_gap_coverage`` (0.80).
        page_height: Forwarded to the coverage check.

    Returns:
        Accepted separator x coordinates sorted ascending, or ``[]`` if none.
    """
    if not word_data:
        return []

    x_points = [coord for item in word_data for coord in (item[1], item[3])]
    max_x = max(x_points)
    # np.histogram needs at least one bin over a non-empty range.
    if max_x <= 0:
        return []

    y_coords = [item[2] for item in word_data] + [item[4] for item in word_data]
    total_text_height = max(y_coords) - min(y_coords)
    if total_text_height <= 0:
        return []

    bin_size = params.get('cluster_bin_size', 5)
    smoothing = params.get('cluster_smoothing', 1)
    min_width = params.get('cluster_min_width', 20)
    threshold_percentile = params.get('cluster_threshold_percentile', 85)
    max_bridging_ratio = params.get('max_bridging_ratio', 0.08)
    min_gap_coverage = params.get('min_gap_coverage', 0.80)

    num_bins = int(np.ceil(max_x / bin_size))
    hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
    smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=smoothing)
    # Invert so that low-density x ranges (gutters) become peaks.
    inverted_signal = np.max(smoothed_hist) - smoothed_hist
    peaks, _ = find_peaks(
        inverted_signal,
        height=np.max(inverted_signal) - np.percentile(smoothed_hist, threshold_percentile),
        # find_peaks rejects distance < 1 (min_width smaller than bin_size).
        distance=max(1, min_width / bin_size),
    )
    if not peaks.size:
        return []

    final_separators = []
    for x_coord in (int(bin_edges[p]) for p in peaks):
        # --- CHECK 1: bridging density (does the line cut through text?) ---
        bridging_height = sum(
            item[4] - item[2]
            for item in word_data
            if item[1] < x_coord < item[3]
        )
        bridging_ratio = bridging_height / total_text_height
        if bridging_ratio > max_bridging_ratio:
            print(
                f" β Separator X={x_coord} REJECTED: Bridging Ratio {bridging_ratio:.1%} "
                f"(>{max_bridging_ratio:.0%}) cuts through text.")
            continue

        # --- CHECK 2: vertical gap coverage (is the split clean?) ---
        coverage = calculate_vertical_gap_coverage(word_data, x_coord, page_height, gutter_width=min_width)
        if coverage >= min_gap_coverage:
            final_separators.append(x_coord)
            print(f" -> Separator X={x_coord} ACCEPTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")
        else:
            print(f" β Separator X={x_coord} REJECTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})")
    return sorted(final_separators)
def _tesseract_page_words(page):
    """OCR a page rendered at 4x zoom; return ``(text, x1, y1, x2, y2)`` in PDF points."""
    zoom = 4.0  # 72 dpi * 4 = 288 dpi render
    pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
    img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
    if pix.n == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
    elif pix.n == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    binary = preprocess_image_for_ocr(img)
    # --oem 3: default (LSTM) engine; --psm 6: treat the image as one uniform text block.
    data = pytesseract.image_to_data(binary, output_type=pytesseract.Output.DICT,
                                     config=r'--oem 3 --psm 6')
    words = []
    for i in range(len(data['level'])):
        token = data['text'][i].strip()
        if not token:
            continue
        left, top = data['left'][i], data['top'][i]
        # Divide by the zoom to map render pixels back to PDF points.
        words.append((
            token,
            left / zoom,
            top / zoom,
            (left + data['width'][i]) / zoom,
            (top + data['height'][i]) / zoom,
        ))
    return words


def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int,
                                top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
    """Return ``(text, x1, y1, x2, y2)`` word tuples in PDF points for layout analysis.

    Word source:

    - The page's native text layer is used whenever it yields any words.
    - Otherwise the page is OCR'd with Tesseract. The result is stored in
      ``_ocr_cache``, so a repeated call for the same ``(pdf_path, page_num)``
      skips OCR.
    - If OCR raises, the error is printed and ``[]`` is returned.

    Margin filtering: words whose top lies above ``top_margin_percent`` of the
    page height are dropped, as are words whose bottom lies below
    ``1 - bottom_margin_percent`` of it.
    """
    native = page.get_text("words")
    if native:
        words = [(w[4], w[0], w[1], w[2], w[3]) for w in native]
    elif _ocr_cache.has_ocr(pdf_path, page_num):
        words = _ocr_cache.get_ocr(pdf_path, page_num)
    else:
        try:
            words = _tesseract_page_words(page)
            _ocr_cache.set_ocr(pdf_path, page_num, words)
        except Exception as e:
            print(f" β OCR Error in detection phase: {e}")
            return []

    page_height = page.rect.height
    top_cut = page_height * top_margin_percent
    bottom_cut = page_height * (1 - bottom_margin_percent)
    return [w for w in words if w[2] >= top_cut and w[4] <= bottom_cut]
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    """Expose a PyMuPDF pixmap as an ``(height, width, n)`` uint8 array for OpenCV.

    RGBA (n == 4) and RGB (n == 3) pixmaps are converted to 3-channel BGR.
    Any other channel count is returned as the array built directly over
    ``pix.samples``.
    """
    pixels = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
    conversion = {4: cv2.COLOR_RGBA2BGR, 3: cv2.COLOR_RGB2BGR}.get(pix.n)
    if conversion is None:
        return pixels
    return cv2.cvtColor(pixels, conversion)
def _dump_native_words(fitz_page, raw_word_data, limit=50):
    """Print up to *limit* raw native words with their code points (debug aid)."""
    print(f"\n[DEBUG] Native Extraction (Page {fitz_page.number + 1}): Checking first {limit} words...")
    for idx, item in enumerate(raw_word_data[:limit]):
        word_text = item[4]
        # Drop unencodable code points (lone surrogates) so printing cannot fail.
        safe_text = word_text.encode('utf-8', 'ignore').decode('utf-8')
        unicode_points = [f"\\u{ord(c):04x}" for c in word_text]
        print(f" Word {idx}: '{safe_text}' -> Codes: {unicode_points}")
    print("----------------------------------------------------------------------\n")


def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0, debug: bool = False) -> list:
    """Convert a page's native PyMuPDF words into pixel-space OCR records.

    Each word from ``fitz_page.get_text("words")`` becomes the dict
    ``{'type': 'text', 'word', 'confidence', 'bbox', 'y0', 'x0'}``. Its
    ``bbox`` is ``[x1, y1, x2, y2]`` multiplied by *scale_factor* and
    truncated to int, so it lines up with an image rendered at that zoom.

    Text is cleaned of code points that cannot be UTF-8 encoded (e.g. lone
    surrogates) and of surrounding whitespace. Words that end up empty are
    skipped. Native text gets a fixed confidence of 99.0.

    Args:
        fitz_page: A PyMuPDF page. It needs ``get_text``, and ``number``
            when *debug* is set.
        scale_factor: Pixels per PDF point of the target image.
        debug: When True, print the first 50 raw words with their code
            points, which helps diagnose encoding problems. Off by default;
            the dump used to be printed for every page.

    Returns:
        List of word dicts, or ``[]`` if text extraction raises.
    """
    try:
        raw_word_data = fitz_page.get_text("words")
    except Exception as e:
        print(f" β PyMuPDF extraction failed completely: {e}")
        return []

    if debug:
        _dump_native_words(fitz_page, raw_word_data)

    DEFAULT_CONFIDENCE = 99.0
    converted_ocr_output = []
    for x1, y1, x2, y2, word, *_ in raw_word_data:
        # 'ignore' drops code points UTF-8 cannot encode (lone surrogates),
        # which would otherwise raise UnicodeEncodeError when written/printed.
        cleaned_word = word.encode('utf-8', 'ignore').decode('utf-8').strip()
        if not cleaned_word:
            continue
        x1_pix = int(x1 * scale_factor)
        y1_pix = int(y1 * scale_factor)
        x2_pix = int(x2 * scale_factor)
        y2_pix = int(y2 * scale_factor)
        converted_ocr_output.append({
            'type': 'text',
            'word': cleaned_word,
            'confidence': DEFAULT_CONFIDENCE,
            'bbox': [x1_pix, y1_pix, x2_pix, y2_pix],
            'y0': y1_pix, 'x0': x1_pix
        })
    return converted_ocr_output
| #=================================================================================================== | |
| #=================================================================================================== | |
| #=================================================================================================== | |
| import pandas as pd | |
| import pickle | |
| import os | |
| import time | |
| import json | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| import numpy as np | |
| from collections import defaultdict | |
# --- Model File Paths (Required for the Classifier to load) ---
VECTORIZER_FILE = 'tfidf_vectorizer_conditional.pkl'
SUBJECT_MODEL_FILE = 'subject_classifier_model_conditional.pkl'
CONDITIONAL_CONCEPT_MODELS_FILE = 'conditional_concept_models.pkl'


class HierarchicalClassifier:
    """Two-stage text classifier: a subject model, then a per-subject concept model.

    Attributes (all populated by :meth:`load_models`):
        vectorizer: Object whose ``transform([text])`` produces model features.
        subject_model: Level-1 model predicting the subject.
        conditional_concept_models: Dict mapping subject -> level-2 concept model.
        is_ready: True only after a successful :meth:`load_models`.
    """

    def __init__(self):
        self.vectorizer = None
        self.subject_model = None
        self.conditional_concept_models = {}
        self.is_ready = False

    def load_models(self):
        """Load the vectorizer, subject model and concept models from pickle files.

        All three files are read into locals first and committed to the
        instance only when every load succeeds. A reload therefore never
        leaves concept models from an earlier load mixed with new ones.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load model
        files from a trusted location.

        Returns:
            ``self.is_ready``: True on success. On failure it is False and an
            error message is printed.
        """
        start_time = time.time()
        try:
            with open(VECTORIZER_FILE, 'rb') as f:
                vectorizer = pickle.load(f)
            with open(SUBJECT_MODEL_FILE, 'rb') as f:
                subject_model = pickle.load(f)
            with open(CONDITIONAL_CONCEPT_MODELS_FILE, 'rb') as f:
                conditional_data = pickle.load(f)
            # The pickle maps subject -> dict; only each dict's 'model' entry is used.
            concept_models = {subject: data['model'] for subject, data in conditional_data.items()}
        except FileNotFoundError as e:
            print(f"Error: Required model file not found: {e.filename}.")
            self.is_ready = False
            return self.is_ready
        except Exception as e:
            print(f"An error occurred while loading models: {e}")
            self.is_ready = False
            return self.is_ready

        self.vectorizer = vectorizer
        self.subject_model = subject_model
        self.conditional_concept_models = concept_models
        print(f"Loaded models successfully in {time.time() - start_time:.2f} seconds.")
        self.is_ready = True
        return self.is_ready

    def _top_prediction(self, model, text_chunk):
        """Return ``(label, confidence)`` of *model*'s most probable class for *text_chunk*.

        Models without ``predict_proba`` report their hard prediction with
        confidence 1.0.
        """
        text_vector = self.vectorizer.transform([text_chunk]).astype(np.float64)
        if hasattr(model, 'predict_proba'):
            probabilities = model.predict_proba(text_vector)[0]
            top_index = np.argmax(probabilities)
            return model.classes_[top_index], probabilities[top_index]
        return model.predict(text_vector)[0], 1.0

    def predict_subject(self, text_chunk):
        """Predict the top subject (level 1); ``("Unknown", 0.0)`` if not loaded."""
        if not self.is_ready:
            return "Unknown", 0.0
        return self._top_prediction(self.subject_model, text_chunk)

    def predict_concept_hierarchical(self, text_chunk, predicted_subject):
        """Predict the top concept (level 2) using the model for *predicted_subject*.

        Returns ``("Unknown", 0.0)`` if not loaded. Returns
        ``("No_Conditional_Model_Found", 0.0)`` when there is no model for the
        subject, or the model exposes fewer than two classes (or no
        ``classes_``).
        """
        if not self.is_ready:
            return "Unknown", 0.0
        concept_model = self.conditional_concept_models.get(predicted_subject)
        if concept_model is None or len(getattr(concept_model, 'classes_', [])) <= 1:
            return "No_Conditional_Model_Found", 0.0
        return self._top_prediction(concept_model, text_chunk)
| # -------------------------------------------------------------------------------------- | |
| # --- The Requested Helper Function --- | |
def _question_with_options_text(entry):
    """Join an entry's question and option texts into one whitespace-normalized string."""
    parts = [entry.get('question') or '']
    options = entry.get('options') or {}
    if isinstance(options, dict):
        # String values are used as-is; otherwise fall back to the option key.
        for key, value in options.items():
            parts.append(value if isinstance(value, str) else key)
    elif isinstance(options, (list, tuple)):
        parts.extend(value for value in options if isinstance(value, str))
    return ' '.join(' '.join(str(part) for part in parts).split())


def post_process_json_with_inference(json_data, classifier):
    """Tag each question entry with its predicted subject and concept.

    Processing, for every dict in *json_data* that has a ``'question'`` key:

    - The question text and its option texts are joined into one
      whitespace-normalised string.
    - That string is passed to ``classifier.predict_subject`` and then to
      ``classifier.predict_concept_hierarchical``.
    - The entry gains ``'predicted_subject'`` and ``'predicted_concept'``
      dicts of the form ``{'label': ..., 'confidence': <float, 4 decimals>}``.
      Confidences are converted to built-in ``float``, so numpy scalars
      (e.g. float32) do not break later ``json.dump`` calls.
    - Entries with no usable text get label ``'Empty_Text'`` and confidence
      0.0.

    Entries are modified in place.

    Option handling: ``options`` may be a dict, where string values are used
    and the key stands in for a non-string value. It may also be a
    list/tuple, where only string items are used. A ``None`` question or
    ``None`` options is treated as empty.

    Args:
        json_data: List of entry dicts.
        classifier: Object with ``is_ready``, ``predict_subject(text)`` and
            ``predict_concept_hierarchical(text, subject)``, such as a loaded
            HierarchicalClassifier.

    Returns:
        *json_data* (the same list), unchanged if the classifier is not ready.
    """
    if not classifier.is_ready:
        print("Classifier not ready. Skipping inference.")
        return json_data

    # This print statement can be removed for silent pipeline integration
    print("\n--- Starting Subject/Concept Detection ---")
    for entry in json_data:
        if 'question' not in entry:
            continue

        full_text = _question_with_options_text(entry)
        if not full_text:
            entry['predicted_subject'] = {'label': 'Empty_Text', 'confidence': 0.0}
            entry['predicted_concept'] = {'label': 'Empty_Text', 'confidence': 0.0}
            continue

        # Stage 1: subject; stage 2: concept conditioned on that subject.
        subj_label, subj_conf = classifier.predict_subject(full_text)
        conc_label, conc_conf = classifier.predict_concept_hierarchical(full_text, subj_label)

        entry['predicted_subject'] = {
            'label': subj_label,
            'confidence': round(float(subj_conf), 4)
        }
        entry['predicted_concept'] = {
            'label': conc_label,
            'confidence': round(float(conc_conf), 4)
        }
    return json_data
| #=================================================================================================== | |
| #=================================================================================================== | |
| #=================================================================================================== | |
| def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str, | |
| page_num: int, fitz_page: fitz.Page, | |
| pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]: | |
| """ | |
| OPTIMIZED FLOW: | |
| 1. Run YOLO to find Equations/Tables. | |
| 2. Mask raw text with YOLO boxes. | |
| 3. Run Column Detection on the MASKED data. | |
| 4. Proceed with OCR (Native or High-Res Tesseract Fallback) and Output. | |
| """ | |
| global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT | |
| start_time_total = time.time() | |
| if original_img is None: | |
| print(f" β Invalid image for page {page_num}.") | |
| return None, None | |
| # ==================================================================== | |
| # --- STEP 1: YOLO DETECTION --- | |
| # ==================================================================== | |
| start_time_yolo = time.time() | |
| results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False) | |
| relevant_detections = [] | |
| if results and results[0].boxes: | |
| for box in results[0].boxes: | |
| class_id = int(box.cls[0]) | |
| class_name = model.names[class_id] | |
| if class_name in TARGET_CLASSES: | |
| x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) | |
| relevant_detections.append( | |
| {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])} | |
| ) | |
| merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD) | |
| print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.") | |
| # ==================================================================== | |
| # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) --- | |
| # ==================================================================== | |
| # Note: This uses the updated 'get_word_data_for_detection' which has its own optimizations | |
| raw_words_for_layout = get_word_data_for_detection( | |
| fitz_page, pdf_path, page_num, | |
| top_margin_percent=0.10, bottom_margin_percent=0.10 | |
| ) | |
| masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections, scale_factor=2.0) | |
| # ==================================================================== | |
| # --- STEP 3: COLUMN DETECTION --- | |
| # ==================================================================== | |
| page_width_pdf = fitz_page.rect.width | |
| page_height_pdf = fitz_page.rect.height | |
| column_detection_params = { | |
| 'cluster_bin_size': 2, 'cluster_smoothing': 2, | |
| 'cluster_min_width': 10, 'cluster_threshold_percentile': 85, | |
| } | |
| separators = calculate_x_gutters(masked_word_data, column_detection_params, page_height_pdf) | |
| page_separator_x = None | |
| if separators: | |
| central_min = page_width_pdf * 0.35 | |
| central_max = page_width_pdf * 0.65 | |
| central_separators = [s for s in separators if central_min <= s <= central_max] | |
| if central_separators: | |
| center_x = page_width_pdf / 2 | |
| page_separator_x = min(central_separators, key=lambda x: abs(x - center_x)) | |
| print(f" β Column Split Confirmed at X={page_separator_x:.1f}") | |
| else: | |
| print(" β οΈ Gutter found off-center. Ignoring.") | |
| else: | |
| print(" -> Single Column Layout Confirmed.") | |
| # ==================================================================== | |
| # --- STEP 4: COMPONENT EXTRACTION (Save Images) --- | |
| # ==================================================================== | |
| start_time_components = time.time() | |
| component_metadata = [] | |
| fig_count_page = 0 | |
| eq_count_page = 0 | |
| for detection in merged_detections: | |
| x1, y1, x2, y2 = detection['coords'] | |
| class_name = detection['class'] | |
| if class_name == 'figure': | |
| GLOBAL_FIGURE_COUNT += 1 | |
| counter = GLOBAL_FIGURE_COUNT | |
| component_word = f"FIGURE{counter}" | |
| fig_count_page += 1 | |
| elif class_name == 'equation': | |
| GLOBAL_EQUATION_COUNT += 1 | |
| counter = GLOBAL_EQUATION_COUNT | |
| component_word = f"EQUATION{counter}" | |
| eq_count_page += 1 | |
| else: | |
| continue | |
| component_crop = original_img[y1:y2, x1:x2] | |
| component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png" | |
| cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop) | |
| y_midpoint = (y1 + y2) // 2 | |
| component_metadata.append({ | |
| 'type': class_name, 'word': component_word, | |
| 'bbox': [int(x1), int(y1), int(x2), int(y2)], | |
| 'y0': int(y_midpoint), 'x0': int(x1) | |
| }) | |
| # ==================================================================== | |
| # --- STEP 5: HYBRID OCR (Native Text + Cached Tesseract Fallback) --- | |
| # ==================================================================== | |
| raw_ocr_output = [] | |
| scale_factor = 2.0 # Pipeline standard scale | |
| try: | |
| # Try getting native text first | |
| # NOTE: extract_native_words_and_convert MUST ALSO BE UPDATED TO USE sanitize_text | |
| raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor) | |
| except Exception as e: | |
| print(f" β Native text extraction failed: {e}") | |
| # If native text is missing, fall back to OCR | |
| if not raw_ocr_output: | |
| if _ocr_cache.has_ocr(pdf_path, page_num): | |
| print(f" β‘ Using cached Tesseract OCR for page {page_num}") | |
| cached_word_data = _ocr_cache.get_ocr(pdf_path, page_num) | |
| for word_tuple in cached_word_data: | |
| word_text, x1, y1, x2, y2 = word_tuple | |
| # Scale from PDF points to Pipeline Pixels (2.0) | |
| x1_pix = int(x1 * scale_factor) | |
| y1_pix = int(y1 * scale_factor) | |
| x2_pix = int(x2 * scale_factor) | |
| y2_pix = int(y2 * scale_factor) | |
| raw_ocr_output.append({ | |
| 'type': 'text', 'word': word_text, 'confidence': 95.0, | |
| 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], | |
| 'y0': y1_pix, 'x0': x1_pix | |
| }) | |
| else: | |
| # === START OF OPTIMIZED OCR BLOCK === | |
| try: | |
| # 1. Re-render Page at High Resolution (Zoom 4.0 = ~300 DPI) | |
| ocr_zoom = 4.0 | |
| pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)) | |
| # Convert PyMuPDF Pixmap to OpenCV format | |
| img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape(pix_ocr.height, pix_ocr.width, | |
| pix_ocr.n) | |
| if pix_ocr.n == 3: | |
| img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR) | |
| elif pix_ocr.n == 4: | |
| img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR) | |
| # 2. Preprocess (Binarization) | |
| processed_img = preprocess_image_for_ocr(img_ocr_np) | |
| # 3. Run Tesseract with Optimized Configuration | |
| custom_config = r'--oem 3 --psm 6' | |
| hocr_data = pytesseract.image_to_data( | |
| processed_img, | |
| output_type=pytesseract.Output.DICT, | |
| config=custom_config | |
| ) | |
| for i in range(len(hocr_data['level'])): | |
| text = hocr_data['text'][i] # Retrieve raw Tesseract text | |
| # --- FIX: SANITIZE TEXT AND THEN STRIP --- | |
| cleaned_text = sanitize_text(text).strip() | |
| if cleaned_text and hocr_data['conf'][i] > -1: | |
| # 4. Coordinate Mapping | |
| scale_adjustment = scale_factor / ocr_zoom | |
| x1 = int(hocr_data['left'][i] * scale_adjustment) | |
| y1 = int(hocr_data['top'][i] * scale_adjustment) | |
| w = int(hocr_data['width'][i] * scale_adjustment) | |
| h = int(hocr_data['height'][i] * scale_adjustment) | |
| x2 = x1 + w | |
| y2 = y1 + h | |
| raw_ocr_output.append({ | |
| 'type': 'text', | |
| 'word': cleaned_text, # Use the sanitized word | |
| 'confidence': float(hocr_data['conf'][i]), | |
| 'bbox': [x1, y1, x2, y2], | |
| 'y0': y1, | |
| 'x0': x1 | |
| }) | |
| except Exception as e: | |
| print(f" β Tesseract OCR Error: {e}") | |
| # === END OF OPTIMIZED OCR BLOCK === | |
| # ==================================================================== | |
| # --- STEP 6: OCR CLEANING AND MERGING --- | |
| # ==================================================================== | |
| items_to_sort = [] | |
| for ocr_word in raw_ocr_output: | |
| is_suppressed = False | |
| for component in component_metadata: | |
| # Do not include words that are inside figure/equation boxes | |
| ioa = calculate_ioa(ocr_word['bbox'], component['bbox']) | |
| if ioa > IOA_SUPPRESSION_THRESHOLD: | |
| is_suppressed = True | |
| break | |
| if not is_suppressed: | |
| items_to_sort.append(ocr_word) | |
| # Add figures/equations back into the flow as "words" | |
| items_to_sort.extend(component_metadata) | |
| # ==================================================================== | |
| # --- STEP 7: LINE-BASED SORTING --- | |
| # ==================================================================== | |
| items_to_sort.sort(key=lambda x: (x['y0'], x['x0'])) | |
| lines = [] | |
| for item in items_to_sort: | |
| placed = False | |
| for line in lines: | |
| y_ref = min(it['y0'] for it in line) | |
| if abs(y_ref - item['y0']) < LINE_TOLERANCE: | |
| line.append(item) | |
| placed = True | |
| break | |
| if not placed and item['type'] in ['equation', 'figure']: | |
| for line in lines: | |
| y_ref = min(it['y0'] for it in line) | |
| if abs(y_ref - item['y0']) < 20: | |
| line.append(item) | |
| placed = True | |
| break | |
| if not placed: | |
| lines.append([item]) | |
| for line in lines: | |
| line.sort(key=lambda x: x['x0']) | |
| final_output = [] | |
| for line in lines: | |
| for item in line: | |
| data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]} | |
| if 'tag' in item: data_item['tag'] = item['tag'] | |
| final_output.append(data_item) | |
| return final_output, page_separator_x | |
def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    """Run YOLO detection + OCR over every page of a PDF and save the result as JSON.

    Resets the global figure/equation counters and the OCR cache, renders each
    page at 2x zoom, passes it to ``preprocess_and_ocr_page`` and collects the
    per-page word lists (plus the detected column separator) into one JSON file.

    Args:
        pdf_path: Path to the input PDF.
        preprocessed_json_path: Destination of the combined JSON output. Its
            parent directory (if any) is created when missing.

    Returns:
        ``preprocessed_json_path`` on success; ``None`` if the PDF is missing or
        unreadable, no page produced data, or the JSON could not be written.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0
    _ocr_cache.clear()
    print("\n" + "=" * 80)
    print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)
    if not os.path.exists(pdf_path):
        print(f"β FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None

    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the output path actually has one.
    output_dir = os.path.dirname(preprocessed_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)

    model = YOLO(WEIGHTS_PATH)
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    try:
        doc = fitz.open(pdf_path)
        print(f"β Opened PDF: {pdf_name} ({doc.page_count} pages)")
    except Exception as e:
        print(f"β ERROR loading PDF file: {e}")
        return None

    all_pages_data = []
    total_pages_processed = 0
    # 2x zoom: the rest of the pipeline assumes boxes are in 2.0-scaled pixels.
    mat = fitz.Matrix(2.0, 2.0)
    print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]")
    try:
        for page_num_0_based in range(doc.page_count):
            page_num = page_num_0_based + 1
            print(f" -> Processing Page {page_num}/{doc.page_count}...")
            fitz_page = doc.load_page(page_num_0_based)
            try:
                pix = fitz_page.get_pixmap(matrix=mat)
                original_img = pixmap_to_numpy(pix)
            except Exception as e:
                print(f" β Error converting page {page_num} to image: {e}")
                continue
            final_output, page_separator_x = preprocess_and_ocr_page(
                original_img,
                model,
                pdf_path,
                page_num,
                fitz_page,
                pdf_name
            )
            if final_output is not None:
                all_pages_data.append({
                    "page_number": page_num,
                    "data": final_output,
                    "column_separator_x": page_separator_x
                })
                total_pages_processed += 1
            else:
                print(f" β Skipped page {page_num} due to processing error.")
    finally:
        # Release the document even if a page handler raises.
        doc.close()

    if not all_pages_data:
        print("β WARNING: No page data generated. Halting pipeline.")
        return None
    try:
        with open(preprocessed_json_path, 'w', encoding='utf-8') as f:
            json.dump(all_pages_data, f, indent=4)
        print(f"\n β Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
    except Exception as e:
        print(f"β ERROR saving combined JSON output: {e}")
        return None

    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---")
    print("=" * 80)
    return preprocessed_json_path
| # ============================================================================ | |
| # --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS --- | |
| # ============================================================================ | |
class LayoutLMv3ForTokenClassification(nn.Module):
    """LayoutLMv3 encoder followed by a linear emission head and a CRF layer.

    ``forward`` returns a scalar training loss (the negated mean of the CRF
    score) when ``labels`` are supplied, and otherwise the output of the CRF's
    ``viterbi_decode`` for the masked sequence.
    """

    def __init__(self, num_labels: int = NUM_LABELS):
        super().__init__()
        backbone_name = "microsoft/layoutlmv3-base"
        self.num_labels = num_labels
        backbone_config = LayoutLMv3Config.from_pretrained(backbone_name, num_labels=num_labels)
        # The attribute names below become the state_dict key prefixes
        # (layoutlmv3.*, classifier.*, crf.*); renaming them breaks checkpoint loading.
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained(backbone_name, config=backbone_config)
        self.classifier = nn.Linear(backbone_config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform the classifier weights and zero its bias (if present)."""
        head = self.classifier
        nn.init.xavier_uniform_(head.weight)
        if head.bias is not None:
            nn.init.zeros_(head.bias)

    def forward(self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor,
                labels: Optional[torch.Tensor] = None):
        """Compute CRF loss (with ``labels``) or decoded label sequences (without)."""
        hidden_states = self.layoutlmv3(
            input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        emissions = self.classifier(hidden_states)
        crf_mask = attention_mask.bool()
        if labels is None:
            return self.crf.viterbi_decode(emissions, mask=crf_mask)
        return -self.crf(emissions, labels, mask=crf_mask).mean()
| def _merge_integrity(all_token_data: List[Dict[str, Any]], | |
| column_separator_x: Optional[int]) -> List[List[Dict[str, Any]]]: | |
| """Splits the token data objects into column chunks based on a separator.""" | |
| if column_separator_x is None: | |
| print(" -> No column separator. Treating as one chunk.") | |
| return [all_token_data] | |
| left_column_tokens, right_column_tokens = [], [] | |
| for token_data in all_token_data: | |
| bbox_raw = token_data['bbox_raw_pdf_space'] | |
| center_x = (bbox_raw[0] + bbox_raw[2]) / 2 | |
| if center_x < column_separator_x: | |
| left_column_tokens.append(token_data) | |
| else: | |
| right_column_tokens.append(token_data) | |
| chunks = [c for c in [left_column_tokens, right_column_tokens] if c] | |
| print(f" -> Data split into {len(chunks)} column chunk(s) using separator X={column_separator_x}.") | |
| return chunks | |
def _plan_sub_chunks(token_counts: List[int], max_words: int,
                     max_tokens: int) -> List[Tuple[int, int]]:
    """Split a run of words into ``[start, end)`` ranges that each fit one encoder pass.

    Each range holds at most ``max_words`` words and, whenever possible, at most
    ``max_tokens`` sub-word tokens. A single word whose token count alone exceeds
    ``max_tokens`` gets a range to itself; the tokenizer will then truncate it.

    Args:
        token_counts: Number of sub-word tokens for each word, in order.
        max_words: Upper bound on words per range.
        max_tokens: Token budget per range (sequence length minus special tokens).

    Returns:
        Contiguous, non-overlapping ranges covering every index of ``token_counts``.
    """
    ranges: List[Tuple[int, int]] = []
    start = 0
    tokens_in_range = 0
    for idx, count in enumerate(token_counts):
        words_in_range = idx - start
        if words_in_range and (words_in_range >= max_words
                               or tokens_in_range + count > max_tokens):
            ranges.append((start, idx))
            start = idx
            tokens_in_range = 0
        tokens_in_range += count
    if start < len(token_counts):
        ranges.append((start, len(token_counts)))
    return ranges


def _extract_predictions(model_outputs: Any) -> Tuple[Optional[torch.Tensor], Optional[List[int]]]:
    """Normalize the token-classification model's return value.

    Recognized shapes, tried in order:

    1. An ``(emissions_tensor, viterbi_list)`` pair.
    2. An object exposing a ``.logits`` tensor.
    3. A tuple/list containing a logits tensor.
    4. A bare logits tensor.
    5. A batch-first ``List[List[int]]`` of decoded labels. This is presumably
       what the CRF model returns at inference time.

    Returns:
        ``(preds_tensor, decoded_token_labels)``. The first item is a 2-D
        ``[seq_len, num_labels]`` logits tensor or ``None``. The second is the
        label-id list of the first sequence, or ``None``. At least one of them
        is not ``None``.

    Raises:
        RuntimeError: If neither logits nor decoded labels are found, or if the
            logits are not 2-D after squeezing a batch dimension of 1.
    """
    logits_tensor = None
    decoded_labels_list = None

    # Case 1: (emissions, viterbi) pair.
    if isinstance(model_outputs, (tuple, list)) and len(model_outputs) == 2:
        first, second = model_outputs
        if isinstance(first, torch.Tensor):
            logits_tensor = first
        if isinstance(second, list):
            decoded_labels_list = second

    # Case 2: HF ModelOutput with .logits.
    if logits_tensor is None and hasattr(model_outputs, 'logits') and isinstance(model_outputs.logits, torch.Tensor):
        logits_tensor = model_outputs.logits

    # Case 3: tuple/list - prefer a 3-D (batch, seq, labels) tensor.
    if logits_tensor is None and isinstance(model_outputs, (tuple, list)):
        found_tensor = None
        for item in model_outputs:
            if isinstance(item, torch.Tensor):
                if item.dim() == 3:
                    logits_tensor = item
                    break
                if found_tensor is None:
                    found_tensor = item
        # Only accept a fallback tensor whose last dim matches the label count.
        if logits_tensor is None and found_tensor is not None:
            if found_tensor.dim() == 3 and found_tensor.shape[-1] == NUM_LABELS:
                logits_tensor = found_tensor
            elif found_tensor.dim() == 2 and found_tensor.shape[-1] == NUM_LABELS:
                logits_tensor = found_tensor.unsqueeze(0)

    # Case 4: a bare tensor.
    if logits_tensor is None and isinstance(model_outputs, torch.Tensor):
        logits_tensor = model_outputs

    # Case 5: already-decoded labels, batch first.
    if decoded_labels_list is None and isinstance(model_outputs, list) and model_outputs and isinstance(model_outputs[0], list):
        decoded_labels_list = model_outputs

    if logits_tensor is None and decoded_labels_list is None:
        try:
            elem_shapes = [(type(x), getattr(x, 'shape', None)) for x in model_outputs] if isinstance(model_outputs, (list, tuple)) else [(type(model_outputs), getattr(model_outputs, 'shape', None))]
        except Exception:
            elem_shapes = str(type(model_outputs))
        raise RuntimeError(f"Model output of type {type(model_outputs)} did not contain a valid logits tensor or decoded viterbi. Contents: {elem_shapes}")

    preds_tensor = None
    if logits_tensor is not None:
        # Squeeze a batch dimension of 1 so rows index tokens directly.
        if logits_tensor.dim() == 3 and logits_tensor.shape[0] == 1:
            preds_tensor = logits_tensor.squeeze(0)
        else:
            preds_tensor = logits_tensor
        if preds_tensor.dim() != 2:
            raise RuntimeError(f"Unexpected logits tensor shape: {tuple(preds_tensor.shape)}")

    decoded_token_labels = None
    if decoded_labels_list:  # guard: an empty list has no first sequence to take
        decoded_token_labels = decoded_labels_list[0] if isinstance(decoded_labels_list[0], list) else decoded_labels_list

    if preds_tensor is None and decoded_token_labels is None:
        raise RuntimeError("No logits and no decoded labels available for mapping.")
    return preds_tensor, decoded_token_labels


def _map_token_predictions_to_words(word_ids: List[Optional[int]], sequence_length: int,
                                    content_token_length: int, num_words: int,
                                    preds_tensor: Optional[torch.Tensor],
                                    decoded_token_labels: Optional[List[int]]) -> Dict[int, int]:
    """Collapse token-level predictions into one label id per word.

    The first sub-token of each word determines its label. ``word_ids[t]``
    is the local word index of token position ``t``, or ``None`` for special
    and padding tokens. Only positions below ``sequence_length`` are used.
    """
    word_idx_to_pred_id: Dict[int, int] = {}
    if preds_tensor is not None:
        for token_idx, word_idx in enumerate(word_ids[:sequence_length]):
            if word_idx is not None and word_idx < num_words and word_idx not in word_idx_to_pred_id:
                word_idx_to_pred_id[word_idx] = torch.argmax(preds_tensor[token_idx]).item()
        return word_idx_to_pred_id

    if decoded_token_labels is None:
        raise RuntimeError("No logits and no decoded labels available for mapping.")

    # Decoded labels may cover only content tokens, so alignment starts after
    # CLS. They may instead cover the whole masked sequence, starting at 0.
    # Any other length falls back to skipping CLS.
    decoded_len = len(decoded_token_labels)
    if decoded_len == content_token_length:
        decoded_start = 1
    elif decoded_len == sequence_length:
        decoded_start = 0
    else:
        decoded_start = 1

    token_limit = min(sequence_length, len(word_ids))
    for offset, label_id in enumerate(decoded_token_labels):
        tok_idx = decoded_start + offset
        if tok_idx >= token_limit:
            break
        word_idx = word_ids[tok_idx]
        if word_idx is not None and word_idx < num_words and word_idx not in word_idx_to_pred_id:
            word_idx_to_pred_id[word_idx] = int(label_id)
    return word_idx_to_pred_id


def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
                                    preprocessed_json_path: str,
                                    column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    """Label every preprocessed word of a PDF with the LayoutLMv3 + CRF model.

    For each page in the preprocessed JSON:

    1. Word boxes, stored at 2x zoom, are scaled back to PDF points.
    2. The boxes are normalized to LayoutLMv3's 0-1000 grid.
    3. The words are split into column chunks using the saved
       ``column_separator_x``.
    4. Each column chunk is split into sub-chunks that fit the 512-token
       encoder window.

    Each sub-chunk is run through the model. The prediction for a word's first
    sub-token becomes that word's label. A word that receives no token
    prediction falls back to label id 0.

    Args:
        pdf_path: The PDF the preprocessed JSON came from; used for page sizes.
        model_path: Checkpoint holding either a raw state dict or a dict with a
            ``model_state_dict`` entry.
        preprocessed_json_path: JSON with ``page_number``, ``data`` and
            ``column_separator_x`` per page, as written by
            ``run_single_pdf_preprocessing``.
        column_detection_params: Unused; kept for backward compatibility.

    Returns:
        One ``{"page_number", "data"}`` dict per page that produced
        predictions. Each ``data`` entry is a ``{"word", "bbox",
        "predicted_label", "page_number"}`` dict, with ``bbox`` in PDF points.
        Returns ``[]`` if the model, the JSON or the PDF cannot be loaded.
    """
    print("\n" + "=" * 80)
    print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---")
    print("=" * 80)
    tokenizer = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" -> Using device: {device}")
    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        # SECURITY NOTE: torch.load is patched at import time to force
        # weights_only=False, which unpickles arbitrary objects. Only load
        # trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        model_state = checkpoint.get('model_state_dict', checkpoint)
        # Saved state dicts may prefix backbone keys with 'layoutlm.'; rename them
        # to match this module's 'layoutlmv3.' attribute.
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value
                            for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
        print(f"β LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.")
    except Exception as e:
        print(f"β FATAL ERROR during LayoutLMv3 model loading: {e}")
        return []
    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
        print(f"β Loaded preprocessed data with {len(preprocessed_data)} pages.")
    except Exception as e:
        print(f"β Error loading preprocessed JSON: {e}")
        return []
    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"β Error loading PDF: {e}")
        return []

    final_page_predictions: List[Dict[str, Any]] = []
    CHUNK_SIZE = 500  # max words per encoder pass
    MAX_SEQ_LEN = 512  # encoder window, including CLS and SEP
    # Budget each pass by token count as well as word count. Anything past
    # MAX_SEQ_LEN is truncated by the tokenizer, and those words used to fall
    # back to label 0 silently.
    max_content_tokens = MAX_SEQ_LEN - 2
    scale_factor = 2.0  # preprocessing stores boxes at 2x zoom
    try:
        for page_data in preprocessed_data:
            page_num_1_based = page_data['page_number']
            page_num_0_based = page_num_1_based - 1
            page_raw_predictions = []
            print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***")
            fitz_page = doc.load_page(page_num_0_based)
            page_width, page_height = fitz_page.rect.width, fitz_page.rect.height
            print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).")

            all_token_data = []
            for item in page_data['data']:
                raw_yolo_bbox = item['bbox']
                bbox_pdf = [int(coord / scale_factor) for coord in raw_yolo_bbox[:4]]
                normalized_bbox = [
                    max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
                    max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
                    max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
                    max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
                ]
                all_token_data.append({
                    "word": item['word'],
                    "bbox_raw_pdf_space": bbox_pdf,
                    "bbox_normalized": normalized_bbox,
                    "item_original_data": item
                })
            if not all_token_data:
                continue

            column_separator_x = page_data.get('column_separator_x', None)
            if column_separator_x is not None:
                print(f" -> Using SAVED column separator: X={column_separator_x}")
            else:
                print(" -> No column separator found. Assuming single chunk.")
            token_chunks = _merge_integrity(all_token_data, column_separator_x)
            total_chunks = len(token_chunks)

            for chunk_idx, chunk_tokens in enumerate(token_chunks):
                if not chunk_tokens:
                    continue
                # Coerce to str and drop anything that does not survive a UTF-8 round trip.
                chunk_words = [
                    str(t['word']).encode('utf-8', errors='ignore').decode('utf-8')
                    for t in chunk_tokens
                ]
                chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens]
                # Tokenize each word once. The manual word_ids below assume the
                # tokenizer splits each word exactly as tokenizer.tokenize(word).
                word_pieces = [tokenizer.tokenize(word) for word in chunk_words]
                sub_ranges = _plan_sub_chunks([len(p) for p in word_pieces],
                                              CHUNK_SIZE, max_content_tokens)
                total_sub_chunks = len(sub_ranges)

                for sub_chunk_idx, (start, end) in enumerate(sub_ranges, start=1):
                    sub_words = chunk_words[start:end]
                    sub_bboxes = chunk_normalized_bboxes[start:end]
                    sub_tokens_data = chunk_tokens[start:end]
                    print(f" -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. Running Inference...")

                    # Local word index for every content sub-token, in order.
                    manual_word_ids = [local_idx
                                       for local_idx, pieces in enumerate(word_pieces[start:end])
                                       for _ in pieces]

                    encoded_input = tokenizer(
                        sub_words,
                        boxes=sub_bboxes,
                        truncation=True,
                        padding="max_length",
                        max_length=MAX_SEQ_LEN,
                        is_split_into_words=True,
                        return_tensors="pt"
                    )
                    if encoded_input['input_ids'].shape[0] == 0:
                        print(f" -> Warning: Sub-chunk {sub_chunk_idx} encoded to an empty sequence. Skipping.")
                        continue

                    # Align word ids to encoded positions: CLS, content, SEP, then padding.
                    sequence_length = int(torch.sum(encoded_input['attention_mask']).item())
                    content_token_length = max(0, sequence_length - 2)
                    word_ids: List[Optional[int]] = [None] + manual_word_ids[:content_token_length]
                    if sequence_length > 1:
                        word_ids.append(None)
                    word_ids = (word_ids + [None] * MAX_SEQ_LEN)[:MAX_SEQ_LEN]

                    with torch.no_grad():
                        model_outputs = model(
                            encoded_input['input_ids'].to(device),
                            encoded_input['bbox'].to(device),
                            encoded_input['attention_mask'].to(device)
                        )
                    preds_tensor, decoded_token_labels = _extract_predictions(model_outputs)
                    word_idx_to_pred_id = _map_token_predictions_to_words(
                        word_ids, sequence_length, content_token_length, len(sub_words),
                        preds_tensor, decoded_token_labels
                    )

                    for local_idx, original_token in enumerate(sub_tokens_data):
                        pred_id = word_idx_to_pred_id.get(local_idx, 0)  # default to 0
                        page_raw_predictions.append({
                            "word": original_token['word'],
                            "bbox": original_token['bbox_raw_pdf_space'],
                            "predicted_label": ID_TO_LABEL[pred_id],
                            "page_number": page_num_1_based
                        })

            if page_raw_predictions:
                final_page_predictions.append({
                    "page_number": page_num_1_based,
                    "data": page_raw_predictions
                })
                print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. ***")
    finally:
        # Release the PDF even if inference raises mid-document.
        doc.close()

    print("\n" + "=" * 80)
    print("--- LAYOUTLMV3 INFERENCE COMPLETE ---")
    print("=" * 80)
    return final_page_predictions
| # ============================================================================ | |
| # --- PHASE 3: BIO TO STRUCTURED JSON DECODER --- | |
| # ============================================================================ | |
| # def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]: | |
| # print("\n" + "=" * 80) | |
| # print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---") | |
| # print("=" * 80) | |
| # try: | |
| # with open(input_path, 'r', encoding='utf-8') as f: | |
| # predictions_by_page = json.load(f) | |
| # except Exception as e: | |
| # print(f"β Error loading raw prediction file: {e}") | |
| # return None | |
| # predictions = [] | |
| # for page_item in predictions_by_page: | |
| # if isinstance(page_item, dict) and 'data' in page_item: | |
| # predictions.extend(page_item['data']) | |
| # structured_data = [] | |
| # current_item = None | |
| # current_option_key = None | |
| # current_passage_buffer = [] | |
| # current_text_buffer = [] | |
| # first_question_started = False | |
| # last_entity_type = None | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # def finalize_passage_to_item(item, passage_buffer): | |
| # if passage_buffer: | |
| # passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip() | |
| # if item.get('passage'): | |
| # item['passage'] += ' ' + passage_text | |
| # else: | |
| # item['passage'] = passage_text | |
| # passage_buffer.clear() | |
| # for item in predictions: | |
| # word = item['word'] | |
| # label = item['predicted_label'] | |
| # entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None | |
| # current_text_buffer.append(word) | |
| # previous_entity_type = last_entity_type | |
| # is_passage_label = (entity_type == 'PASSAGE') | |
| # if not first_question_started: | |
| # if label != 'B-QUESTION' and not is_passage_label: | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # continue | |
| # if is_passage_label: | |
| # current_passage_buffer.append(word) | |
| # last_entity_type = 'PASSAGE' | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # continue | |
| # if label == 'B-QUESTION': | |
| # if not first_question_started: | |
| # header_text = ' '.join(current_text_buffer[:-1]).strip() | |
| # if header_text or current_passage_buffer: | |
| # metadata_item = {'type': 'METADATA', 'passage': ''} | |
| # finalize_passage_to_item(metadata_item, current_passage_buffer) | |
| # if header_text: metadata_item['text'] = header_text | |
| # structured_data.append(metadata_item) | |
| # first_question_started = True | |
| # current_text_buffer = [word] | |
| # if current_item is not None: | |
| # finalize_passage_to_item(current_item, current_passage_buffer) | |
| # current_item['text'] = ' '.join(current_text_buffer[:-1]).strip() | |
| # structured_data.append(current_item) | |
| # current_text_buffer = [word] | |
| # current_item = { | |
| # 'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': '' | |
| # } | |
| # current_option_key = None | |
| # last_entity_type = 'QUESTION' | |
| # just_finished_i_option = False | |
| # is_in_new_passage = False | |
| # continue | |
| # if current_item is not None: | |
| # if is_in_new_passage: | |
| # # π Robust Initialization and Appending for 'new_passage' | |
| # if 'new_passage' not in current_item: | |
| # current_item['new_passage'] = word | |
| # else: | |
| # current_item['new_passage'] += f' {word}' | |
| # if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'): | |
| # is_in_new_passage = False | |
| # if label.startswith(('B-', 'I-')): last_entity_type = entity_type | |
| # continue | |
| # is_in_new_passage = False | |
| # if label.startswith('B-'): | |
| # if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']: | |
| # finalize_passage_to_item(current_item, current_passage_buffer) | |
| # current_passage_buffer = [] | |
| # last_entity_type = entity_type | |
| # if entity_type == 'PASSAGE': | |
| # if previous_entity_type == 'OPTION' and just_finished_i_option: | |
| # current_item['new_passage'] = word # Initialize the new passage start | |
| # is_in_new_passage = True | |
| # else: | |
| # current_passage_buffer.append(word) | |
| # elif entity_type == 'OPTION': | |
| # current_option_key = word | |
| # current_item['options'][current_option_key] = word | |
| # just_finished_i_option = False | |
| # elif entity_type == 'ANSWER': | |
| # current_item['answer'] = word | |
| # current_option_key = None | |
| # just_finished_i_option = False | |
| # elif entity_type == 'QUESTION': | |
| # current_item['question'] += f' {word}' | |
| # just_finished_i_option = False | |
| # elif label.startswith('I-'): | |
| # if entity_type == 'QUESTION': | |
| # current_item['question'] += f' {word}' | |
| # elif entity_type == 'PASSAGE': | |
| # if previous_entity_type == 'OPTION' and just_finished_i_option: | |
| # current_item['new_passage'] = word # Initialize the new passage start | |
| # is_in_new_passage = True | |
| # else: | |
| # if not current_passage_buffer: last_entity_type = 'PASSAGE' | |
| # current_passage_buffer.append(word) | |
| # elif entity_type == 'OPTION' and current_option_key is not None: | |
| # current_item['options'][current_option_key] += f' {word}' | |
| # just_finished_i_option = True | |
| # elif entity_type == 'ANSWER': | |
| # current_item['answer'] += f' {word}' | |
| # just_finished_i_option = (entity_type == 'OPTION') | |
| # elif label == 'O': | |
| # if last_entity_type == 'QUESTION': | |
| # current_item['question'] += f' {word}' | |
| # just_finished_i_option = False | |
| # if current_item is not None: | |
| # finalize_passage_to_item(current_item, current_passage_buffer) | |
| # current_item['text'] = ' '.join(current_text_buffer).strip() | |
| # structured_data.append(current_item) | |
| # for item in structured_data: | |
| # item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip() | |
| # if 'new_passage' in item: | |
| # item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip() | |
| # try: | |
| # with open(output_path, 'w', encoding='utf-8') as f: | |
| # json.dump(structured_data, f, indent=2, ensure_ascii=False) | |
| # except Exception: | |
| # pass | |
| # return structured_data | |
def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
    """Decode word-level BIO predictions into structured question items.

    The input JSON is expected to be a list of page entries. Entries that are
    dicts with a ``'data'`` list of ``{'word': ..., 'predicted_label': ...}``
    records are used; any other entries are ignored. Words are walked in order
    and grouped into items of the form::

        {'question': str, 'options': {label_word: str}, 'answer': str,
         'passage': str, 'text': str, ['new_passage': str]}

    Words seen before the first ``B-QUESTION`` become a leading
    ``{'type': 'METADATA', 'passage': ..., 'text': ...}`` item (only created
    when there is such text). ``'O'`` labels contribute only to ``'text'``.
    A PASSAGE span that starts right after an option's ``I-OPTION`` words is
    stored in ``'new_passage'`` rather than ``'passage'`` (presumably context
    for the following question; it is resolved later by context linking).

    Args:
        input_path: Path to the raw predictions JSON file.
        output_path: Where the structured list is written as JSON. A write
            failure is reported and otherwise ignored.

    Returns:
        The list of structured items (possibly empty), or ``None`` if the
        input file could not be loaded.
    """
    print("\n" + "=" * 80)
    print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
    print(f"Source: {input_path}")
    print("=" * 80)
    start_time = time.time()
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions_by_page = json.load(f)
        print(f"β Successfully loaded raw predictions ({len(predictions_by_page)} pages found)")
    except Exception as e:
        print(f"β Error loading raw prediction file: {e}")
        return None

    # Flatten the per-page word lists into a single reading-order sequence.
    predictions = []
    for page_item in predictions_by_page:
        if isinstance(page_item, dict) and 'data' in page_item:
            predictions.extend(page_item['data'])
    total_words = len(predictions)
    print(f"π Total words to process: {total_words}")

    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []    # PASSAGE words not yet attached to an item
    current_text_buffer = []       # every word belonging to the current item
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False  # True right after I-OPTION words
    is_in_new_passage = False       # collecting a passage that followed options

    def finalize_passage_to_item(item, passage_buffer):
        """Append buffered passage words to item['passage'] and empty the buffer."""
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            print(f" β³ [Buffer] Finalizing passage ({len(passage_buffer)} words) into current item")
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    def append_to_new_passage(item, word):
        """Extend item['new_passage'] with word, creating the field if absent.

        Extending (rather than assigning) keeps earlier new-passage text when a
        second Option -> Passage boundary occurs within the same question.
        """
        if 'new_passage' in item:
            item['new_passage'] += f' {word}'
        else:
            item['new_passage'] = word

    for idx, item in enumerate(predictions):
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')
        if label.startswith('B-'):
            print(f"[Word {idx}/{total_words}] Found Label: {label} | Word: '{word}'")

        # Before the first question only passage words are kept (buffered);
        # everything else is recorded solely in the text buffer.
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            print(f"π Detection: New Question Started at word {idx}")
            if not first_question_started:
                # Everything before this word is document header/preamble.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    print(f" -> Creating METADATA item for text found before first question")
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]
            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                print(f" -> Saved Question. Total structured items so far: {len(structured_data)}")
                current_text_buffer = [word]
            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is None:
            continue

        if is_in_new_passage:
            ends_new_passage = label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE')
            if label.startswith(('B-', 'I-')):
                last_entity_type = entity_type
            if not ends_new_passage:
                append_to_new_passage(current_item, word)
                continue
            # The terminating token belongs to the next entity (e.g. the next
            # option label). Fall through and process it normally instead of
            # swallowing it into new_passage; otherwise the option is never
            # created and its I-OPTION words land on the previous option.
            print(f" β³ [State] Exiting new_passage mode at label {label}")
        is_in_new_passage = False

        if label.startswith('B-'):
            if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                # A new structural entity closes any pending passage.
                finalize_passage_to_item(current_item, current_passage_buffer)
            last_entity_type = entity_type
            if entity_type == 'PASSAGE':
                if previous_entity_type == 'OPTION' and just_finished_i_option:
                    print(f" β³ [State] Transitioning to new_passage (Option -> Passage boundary)")
                    append_to_new_passage(current_item, word)
                    is_in_new_passage = True
                else:
                    current_passage_buffer.append(word)
            elif entity_type == 'OPTION':
                current_option_key = word
                current_item['options'][current_option_key] = word
                just_finished_i_option = False
            elif entity_type == 'ANSWER':
                current_item['answer'] = word
                current_option_key = None
                just_finished_i_option = False
            elif entity_type == 'QUESTION':
                current_item['question'] += f' {word}'
                just_finished_i_option = False
        elif label.startswith('I-'):
            if entity_type == 'QUESTION':
                current_item['question'] += f' {word}'
            elif entity_type == 'PASSAGE':
                if previous_entity_type == 'OPTION' and just_finished_i_option:
                    append_to_new_passage(current_item, word)
                    is_in_new_passage = True
                else:
                    if not current_passage_buffer:
                        last_entity_type = 'PASSAGE'
                    current_passage_buffer.append(word)
            elif entity_type == 'OPTION' and current_option_key is not None:
                current_item['options'][current_option_key] += f' {word}'
            elif entity_type == 'ANSWER':
                current_item['answer'] += f' {word}'
            just_finished_i_option = (entity_type == 'OPTION')
        # 'O' labels are intentionally ignored beyond the text buffer.

    if current_item is not None:
        print(f"π Finalizing the very last item...")
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    for entry in structured_data:
        # METADATA items only receive 'text' when header text existed; default
        # to '' so this normalization cannot raise KeyError.
        entry['text'] = re.sub(r'\s{2,}', ' ', entry.get('text', '')).strip()
        if 'new_passage' in entry:
            entry['new_passage'] = re.sub(r'\s{2,}', ' ', entry['new_passage']).strip()

    print(f"πΎ Saving {len(structured_data)} items to {output_path}")
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structured_data, f, indent=2, ensure_ascii=False)
        print(f"β Decoding Complete. Total time: {time.time() - start_time:.2f}s")
    except Exception as e:
        # Best-effort persistence: the in-memory result is still returned.
        print(f"β οΈ Error saving final JSON: {e}")
    return structured_data
def create_query_text(entry: Dict[str, Any]) -> str:
    """Build the text used to match a question against candidate passages.

    The question (when truthy) comes first, followed by every non-empty string
    value of the ``options`` dict and then of the ``options_text`` dict. Fields
    that are missing or not dicts are ignored. Parts are joined by single spaces.
    """
    pieces = [entry["question"]] if entry.get("question") else []
    for field_name in ("options", "options_text"):
        mapping = entry.get(field_name)
        if not (mapping and isinstance(mapping, dict)):
            continue
        pieces.extend(text for text in mapping.values() if text and isinstance(text, str))
    return " ".join(pieces)
def calculate_similarity(doc1: str, doc2: str) -> float:
    """Return the cosine similarity of the word-count vectors of two texts.

    Leading list markers on each line (``1.``, ``(A)``, ``b)``) are stripped
    first so numbering does not contribute shared tokens. Tokens are
    lower-cased words of two or more characters with English stop words
    removed (``CountVectorizer``); raw counts, not TF-IDF weights, are compared.

    Args:
        doc1: First text (e.g. a passage).
        doc2: Second text (e.g. a question plus its options).

    Returns:
        A score in ``[0, 1]``. ``0.0`` when either text is empty, nothing survives
        token filtering, or vectorization fails.
    """
    if not doc1 or not doc2:
        return 0.0
    # Strip only genuine enumerators. The previous pattern ``[\(\d\w]+\.?``
    # matched *any* leading word (``\w`` includes letters), silently dropping
    # the first content word of every line, e.g. "Photosynthesis is ..." -> "is ...".
    enumerator = re.compile(r'^\s*\(?(?:\d+|[A-Za-z])[.)]\s*', flags=re.MULTILINE)
    corpus = [enumerator.sub('', doc1), enumerator.sub('', doc2)]
    try:
        vectorizer = CountVectorizer(stop_words='english', lowercase=True, token_pattern=r'(?u)\b\w\w+\b')
        count_matrix = vectorizer.fit_transform(corpus)
        if count_matrix.shape[1] == 0:
            return 0.0
        vectors = count_matrix.toarray()
        # Cast the numpy scalar to a plain float to honour the annotation.
        return float(cosine_similarity(vectors[0:1], vectors[1:2])[0][0])
    except Exception:
        # Best-effort scoring: e.g. CountVectorizer raises ValueError
        # ("empty vocabulary") when every token is a stop word.
        return 0.0
def process_context_linking(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Attach passage context to questions in place and return ``data``.

    Pass 1 records the items that carry their own non-blank ``passage`` or
    ``new_passage``.

    Pass 2 walks the items in order while tracking the most recent standard
    passage and the most recent ``new_passage``. Each non-METADATA question
    whose question+options text is at least 5 characters is scored against both:

    - the new passage wins only if it beats the standard one by more than
      ``RESOLUTION_MARGIN`` and reaches ``SIMILARITY_THRESHOLD``;
    - otherwise the standard passage is used if it reaches the threshold.

    Decay: after 2 consecutive unlinked questions the standard passage is
    forgotten, to avoid false positives further down.

    Pass 3 clears ``passage`` (and ``new_passage`` when present) on pass-1
    items whose own question scores below the threshold. It then fills
    single-item gaps whose two neighbours share the same passage.
    """
    print("\n" + "=" * 80)
    print("--- STARTING CONTEXT LINKING (WITH DECAY LOGIC) ---")
    print("=" * 80)
    if not data:
        return []

    def has_text(record: Dict[str, Any], field: str) -> bool:
        value = record.get(field)
        return bool(value and value.strip())

    # --- Pass 1: items that define their own passage ---
    definer_positions = []
    for idx, record in enumerate(data):
        owns_passage = has_text(record, "passage")
        owns_new_passage = has_text(record, "new_passage")
        if owns_passage or owns_new_passage:
            definer_positions.append(idx)

    # --- Pass 2: propagate context with decay ---
    MAX_CONSECUTIVE_FAILURES = 2
    active_passage = None
    active_new_passage = None
    miss_streak = 0
    for idx, record in enumerate(data):
        kind = record.get("type", "Question")
        if has_text(record, "passage"):
            active_passage = record["passage"]
            miss_streak = 0  # fresh explicit context
        if has_text(record, "new_passage"):
            active_new_passage = record["new_passage"]
        if not record.get("question") or kind == "METADATA":
            continue
        query = create_query_text(record)
        if len(query.strip()) < 5:
            continue  # too short to score meaningfully
        old_score = calculate_similarity(active_passage, query) if active_passage else 0.0
        new_score = calculate_similarity(active_new_passage, query) if active_new_passage else 0.0
        # Drop lone surrogates (PDF extraction artefacts) so printing cannot raise.
        preview = (record['question'][:30] + '...').encode('utf-8', errors='ignore').decode('utf-8')

        if active_new_passage and (new_score > old_score + RESOLUTION_MARGIN) and (
                new_score >= SIMILARITY_THRESHOLD):
            record["passage"] = active_new_passage
            print(f" [Linker] π Q{idx} ('{preview}') -> NEW PASSAGE (Score: {new_score:.3f})")
            continue  # the decay counter only tracks the standard passage
        if active_passage and (old_score >= SIMILARITY_THRESHOLD):
            record["passage"] = active_passage
            print(f" [Linker] β Q{idx} ('{preview}') -> STANDARD PASSAGE (Score: {old_score:.3f})")
            miss_streak = 0
            continue

        if not active_passage:
            print(f" [Linker] β οΈ Q{idx} NOT LINKED (No active context).")
            continue
        miss_streak += 1
        print(f" [Linker] β οΈ Q{idx} NOT LINKED. (Failures: {miss_streak}/{MAX_CONSECUTIVE_FAILURES})")
        if miss_streak >= MAX_CONSECUTIVE_FAILURES:
            print(f" [Linker] ποΈ Context dropped due to {miss_streak} consecutive misses.")
            active_passage = None
            miss_streak = 0

    # --- Pass 3a: drop weak self-links ---
    print(" [Linker] Running Cleanup & Interpolation...")
    for idx in definer_positions:
        record = data[idx]
        if not record.get("question") or record.get("type") == "METADATA":
            continue
        candidate = record.get("passage") or record.get("new_passage")
        if not candidate:
            continue
        if calculate_similarity(candidate, create_query_text(record)) < SIMILARITY_THRESHOLD:
            record["passage"] = ""
            if "new_passage" in record:
                record["new_passage"] = ""
            print(f" [Cleanup] Removed weak link for Q{idx}")

    # --- Pass 3b: fill one-item gaps between identical neighbouring passages ---
    for idx in range(1, len(data) - 1):
        record = data[idx]
        if not (record.get("question") and not record.get("passage")):
            continue
        before = data[idx - 1].get("passage")
        after = data[idx + 1].get("passage")
        if before and after and (before == after) and before.strip():
            record["passage"] = before
            print(f" [Linker] π₯ͺ Q{idx} Interpolated from neighbors.")
    return data
def correct_misaligned_options(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Move an EQUATION/FIGURE tag back into an option that lost it.

    Decoding sometimes leaves an option holding only its label while the next
    option carries two image tags, e.g.
    ``{'(A)': '(A)', '(B)': '(B) EQUATION1 EQUATION2'}``. For each adjacent
    pair where the first value equals its key and the second (after its
    label) contains exactly two tags, the first tag is moved into the empty
    option: ``{'(A)': '(A) EQUATION1', '(B)': '(B) EQUATION2'}``. Any non-tag
    text in the second option is preserved.

    METADATA items and items with fewer than two options are skipped. Options
    are modified in place and the same list is returned.
    """
    print("\n" + "=" * 80)
    print("--- 5. STARTING POST-PROCESSING: OPTION ALIGNMENT CORRECTION ---")
    print("=" * 80)
    tag_pattern = re.compile(r'(EQUATION\d+|FIGURE\d+)')
    corrected_count = 0
    for item in structured_data:
        if item.get('type') in ['METADATA']:
            continue
        options = item.get('options')
        if not options or len(options) < 2:
            continue
        option_keys = list(options.keys())
        # Values are re-read each step, so a fix on one pair is visible to the next.
        for current_key, next_key in zip(option_keys, option_keys[1:]):
            current_value = options[current_key].strip()
            next_value = options[next_key].strip()
            if current_value != current_key:
                continue  # current option already has content
            content_in_next = next_value.replace(next_key, '', 1).strip()
            tags_in_next = tag_pattern.findall(content_in_next)
            if len(tags_in_next) != 2:
                continue
            # Remove only the moved (first) tag. Rebuilding the option from the
            # second tag alone used to discard any other text it contained.
            remainder = ' '.join(tag_pattern.sub('', content_in_next, count=1).split())
            options[current_key] = f"{current_key} {tags_in_next[0]}".strip()
            options[next_key] = f"{next_key} {remainder}".strip()
            corrected_count += 1
    print(f"β Option alignment correction finished. Total corrections: {corrected_count}.")
    return structured_data
def get_base64_for_file(filepath: str) -> Optional[str]:
    """Return the Base64 text of a file's bytes, with no ``data:`` URI prefix.

    Any failure is printed and reported as ``None``.
    """
    try:
        with open(filepath, "rb") as handle:
            raw_bytes = handle.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except Exception as exc:
        print(f"Error reading and encoding file {filepath}: {exc}")
        return None
def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[ Dict[str, Any]]:
    """Attach figure Base64 and equation LaTeX to items that reference them.

    ``figure_extraction_dir`` is scanned for ``*_figureN.png`` and
    ``*_equationN.png`` files. For every FIGUREn / EQUATIONn tag found
    (case-insensitively) in an item's question, passage, option texts or
    new_passage, the item gains a lower-cased key (``figure1``,
    ``equation2``, ...) holding:

    - figures: the raw Base64 string of the PNG (``None`` if it can't be read);
    - equations: the string from ``get_latex_from_base64`` (clean LaTeX or its
      ``[P2T_ERROR``/``[P2T_WARNING`` message), or
      ``"[FILE_ERROR: Could not read image file]"`` if the image can't be read.

    Each image is converted at most once per call and the result reused.
    Context linking copies one passage into many questions, so without this
    cache the slow equation OCR would run again for every copy.

    Items are modified in place; the returned list holds the same objects.
    """
    print("\n" + "=" * 80)
    print("--- 4. STARTING IMAGE EMBEDDING (Base64) / EQUATION TO LATEX CONVERSION ---")
    print("=" * 80)
    if not structured_data:
        return []
    image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png"))
    image_lookup = {}
    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
    filename_regex = re.compile(r'_(figure|equation)(\d+)\.png$', re.IGNORECASE)
    for filepath in image_files:
        match = filename_regex.search(os.path.basename(filepath))
        if match:
            image_lookup[f"{match.group(1).upper()}{match.group(2)}"] = filepath
    print(f" -> Found {len(image_lookup)} image components.")

    converted: Dict[str, Optional[str]] = {}  # tag -> embedded value (per-call cache)

    def convert_tag(tag: str) -> Optional[str]:
        """Produce the value to embed for one tag (reads/OCRs the image once)."""
        base64_code = get_base64_for_file(image_lookup[tag])
        if 'EQUATION' in tag:
            if not base64_code:
                print(f" β File read error for {tag}.")
                return "[FILE_ERROR: Could not read image file]"
            latex_output = get_latex_from_base64(base64_code)
            if latex_output.startswith(('[P2T_ERROR', '[P2T_WARNING')):
                # The error message itself is embedded so callers can see why.
                print(f" β οΈ Failed to convert {tag} to LaTeX. Embedding error message.")
            else:
                print(f" β Embedded Clean LaTeX for {tag}")
            return latex_output
        # FIGURE tag: embed raw Base64 (None on read failure, as before).
        if base64_code is None:
            print(f" β File read error for {tag}.")
        else:
            print(f" β Embedded Base64 for {tag}")
        return base64_code

    final_structured_data = []
    for item in structured_data:
        text_fields = [item.get('question', ''), item.get('passage', '')]
        if 'options' in item:
            text_fields.extend(item['options'].values())
        if 'new_passage' in item:
            text_fields.append(item['new_passage'])
        unique_tags_to_embed = set()
        for text in text_fields:
            if not text:
                continue
            for match in tag_regex.finditer(text):
                tag = match.group(0).upper()
                if tag in image_lookup:
                    unique_tags_to_embed.add(tag)
        for tag in sorted(unique_tags_to_embed):
            if tag not in converted:
                converted[tag] = convert_tag(tag)
            item[tag.lower()] = converted[tag]  # e.g. figure1 / equation1
        final_structured_data.append(item)
    print(f"β Image embedding complete.")
    return final_structured_data
| # ============================================================================ | |
| # --- MAIN FUNCTION --- | |
| # ============================================================================ | |
| # def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str) -> Optional[ | |
| # List[Dict[str, Any]]]: | |
| # def run_document_pipeline( input_pdf_path: str, layoutlmv3_model_path: str, structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]: | |
| # if not os.path.exists(input_pdf_path): return None | |
| # print("\n" + "#" * 80) | |
| # print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###") | |
| # print("#" * 80) | |
| # pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0] | |
| # temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}") | |
| # os.makedirs(temp_pipeline_dir, exist_ok=True) | |
| # preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json") | |
| # raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json") | |
| # structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json") | |
| # final_result = None | |
| # try: | |
| # # Phase 1: Preprocessing with YOLO First + Masking | |
| # preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path) | |
| # if not preprocessed_json_path_out: return None | |
| # # Phase 2: Inference | |
| # page_raw_predictions_list = run_inference_and_get_raw_words( | |
| # input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out | |
| # ) | |
| # if not page_raw_predictions_list: return None | |
| # # --- DEBUG STEP: SAVE RAW PREDICTIONS --- | |
| # # Save raw predictions to the temporary file | |
| # with open(raw_output_path, 'w', encoding='utf-8') as f: | |
| # json.dump(page_raw_predictions_list, f, indent=4) | |
| # # Explicitly copy/save the raw predictions to the user-specified debug path | |
| # # if raw_predictions_output_path: | |
| # # shutil.copy(raw_output_path, raw_predictions_output_path) | |
| # # print(f"\nβ DEBUG: Raw predictions saved to: {raw_predictions_output_path}") | |
| # # ---------------------------------------- | |
| # # Phase 3: Decoding | |
| # structured_data_list = convert_bio_to_structured_json_relaxed( | |
| # raw_output_path, structured_intermediate_output_path | |
| # ) | |
| # if not structured_data_list: return None | |
| # structured_data_list = correct_misaligned_options(structured_data_list) | |
| # structured_data_list = process_context_linking(structured_data_list) | |
| # # Phase 4: Embedding / Equation to LaTeX Conversion | |
| # final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR) | |
| # #================================================================================ | |
| # # --- NEW FINAL STEP: HIERARCHICAL CLASSIFICATION TAGGING --- | |
| # #================================================================================ | |
| # print("\n" + "=" * 80) | |
| # print("--- FINAL STEP: HIERARCHICAL SUBJECT/CONCEPT TAGGING ---") | |
| # print("=" * 80) | |
| # # 1. Initialize and Load the Classifier | |
| # classifier = HierarchicalClassifier() | |
| # if classifier.load_models(): | |
| # # 2. Run Classification on the *Final* Result | |
| # # The function modifies the list in place and returns it | |
| # final_result = post_process_json_with_inference( | |
| # final_result, classifier | |
| # ) | |
| # print("β Classification complete. Tags added to final output.") | |
| # else: | |
| # print("β Classification model loading failed. Outputting un-tagged data.") | |
| # # ==================================================================== | |
| # except Exception as e: | |
| # print(f"β FATAL ERROR: {e}") | |
| # import traceback | |
| # traceback.print_exc() | |
| # return None | |
| # finally: | |
| # try: | |
| # for f in glob.glob(os.path.join(temp_pipeline_dir, '*')): | |
| # os.remove(f) | |
| # os.rmdir(temp_pipeline_dir) | |
| # except Exception: | |
| # pass | |
| # print("\n" + "#" * 80) | |
| # print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###") | |
| # print("#" * 80) | |
| # return final_result | |
def classify_question_type(item: Dict[str, Any]) -> str:
    """Classify a question as ``'MCQ'`` or ``'DESCRIPTIVE'`` from its options.

    An item is an MCQ when at least one option value contains text beyond its
    leading label (e.g. ``{'A': 'A Paris'}``). Options that are only a label
    (``{'A': 'A'}``), empty, or non-string do not count. Questions without
    usable options, including integer-answer ones, are reported as
    ``'DESCRIPTIVE'``; no separate ``'INTEGER'`` value is produced.

    Args:
        item: Question dict, optionally with an ``'options'`` mapping of option
            label -> option text (the text normally starts with the label).

    Returns:
        ``'MCQ'`` or ``'DESCRIPTIVE'``.
    """
    options = item.get('options', {})
    if not options:
        return 'DESCRIPTIVE'

    def has_content(key: str, value: Any) -> bool:
        # Remove only the first occurrence of the label. Replacing every
        # occurrence also erased matching answer text (e.g. key '1', value '1 11').
        return isinstance(value, str) and bool(value.replace(key, '', 1).strip())

    return 'MCQ' if any(has_content(key, value) for key, value in options.items()) else 'DESCRIPTIVE'
def add_question_type_validation(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Set ``question_type`` on every item and print a per-type summary.

    METADATA items (``type == 'METADATA'``) get ``'METADATA'``. Every other
    item gets the result of ``classify_question_type``. Items are updated in
    place and the same list is returned.
    """
    print("\n" + "=" * 80)
    print("--- ADDING QUESTION TYPE VALIDATION ---")
    print("=" * 80)
    tallies = {'MCQ': 0, 'OTHER': 0, 'METADATA': 0}
    for entry in structured_data:
        if entry.get('type', 'Question') == 'METADATA':
            entry['question_type'] = 'METADATA'
            tallies['METADATA'] += 1
            continue
        verdict = classify_question_type(entry)
        entry['question_type'] = verdict
        tallies['MCQ' if verdict == 'MCQ' else 'OTHER'] += 1
    summary_lines = [
        f" β Classification Complete:",
        f" - MCQ Questions: {tallies['MCQ']}",
        f" - Descriptive/Integer Questions: {tallies['OTHER']}",
        f" - Metadata Entries: {tallies['METADATA']}",
        f" - Total Entries: {len(structured_data)}",
    ]
    for line in summary_lines:
        print(line)
    return structured_data
| import time | |
| import traceback | |
| import glob | |
| # def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]: | |
| # if not os.path.exists(input_pdf_path): | |
| # print(f"β ERROR: File not found: {input_pdf_path}") | |
| # return None | |
| # print("\n" + "#" * 80) | |
| # print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###") | |
| # print(f"Input: {input_pdf_path}") | |
| # print("#" * 80) | |
| # overall_start = time.time() | |
| # pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0] | |
| # temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}") | |
| # os.makedirs(temp_pipeline_dir, exist_ok=True) | |
| # preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json") | |
| # raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json") | |
| # # If the user didn't provide a path, create one in the temp directory | |
| # if structured_intermediate_output_path is None: | |
| # structured_intermediate_output_path = os.path.join( | |
| # temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json" | |
| # ) | |
| # final_result = None | |
| # try: | |
| # # --- Phase 1: Preprocessing --- | |
| # print(f"\n[Step 1/5] Preprocessing (YOLO + Masking)...") | |
| # p1_start = time.time() | |
| # preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path) | |
| # if not preprocessed_json_path_out: | |
| # print("β FAILED at Step 1: Preprocessing returned None.") | |
| # return None | |
| # print(f"β Step 1 Complete ({time.time() - p1_start:.2f}s)") | |
| # # --- Phase 2: Inference --- | |
| # print(f"\n[Step 2/5] Inference (LayoutLMv3)...") | |
| # p2_start = time.time() | |
| # page_raw_predictions_list = run_inference_and_get_raw_words( | |
| # input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out | |
| # ) | |
| # if not page_raw_predictions_list: | |
| # print("β FAILED at Step 2: Inference returned no data.") | |
| # return None | |
| # # Save raw predictions for Step 3 | |
| # with open(raw_output_path, 'w', encoding='utf-8') as f: | |
| # json.dump(page_raw_predictions_list, f, indent=4) | |
| # print(f"β Step 2 Complete ({time.time() - p2_start:.2f}s)") | |
| # # --- Phase 3: Decoding --- | |
| # print(f"\n[Step 3/5] Decoding (BIO to Structured JSON)...") | |
| # p3_start = time.time() | |
| # structured_data_list = convert_bio_to_structured_json_relaxed( | |
| # raw_output_path, structured_intermediate_output_path | |
| # ) | |
| # if not structured_data_list: | |
| # print("β FAILED at Step 3: BIO conversion failed.") | |
| # return None | |
| # # Logic adjustments | |
| # print("... Correcting misalignments and linking context ...") | |
| # structured_data_list = correct_misaligned_options(structured_data_list) | |
| # structured_data_list = process_context_linking(structured_data_list) | |
| # print(f"β Step 3 Complete ({time.time() - p3_start:.2f}s)") | |
| # # --- Phase 4: Base64 & LaTeX --- | |
| # print(f"\n[Step 4/5] Finalizing Layout (Base64 Images & LaTeX)...") | |
| # p4_start = time.time() | |
| # final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR) | |
| # if not final_result: | |
| # print("β FAILED at Step 4: Final formatting failed.") | |
| # return None | |
| # print(f"β Step 4 Complete ({time.time() - p4_start:.2f}s)") | |
| # # --- ADD THIS NEW STEP HERE --- | |
| # print(f"\n[Step 4.5/5] Adding Question Type Classification...") | |
| # p4_5_start = time.time() | |
| # final_result = add_question_type_validation(final_result) | |
| # print(f"β Step 4.5 Complete ({time.time() - p4_5_start:.2f}s)") | |
| # # --- END OF NEW STEP --- | |
| # # --- Phase 5: Hierarchical Tagging --- | |
| # print(f"\n[Step 5/5] AI Classification (Subject/Concept Tagging)...") | |
| # p5_start = time.time() | |
| # classifier = HierarchicalClassifier() | |
| # if classifier.load_models(): | |
| # final_result = post_process_json_with_inference(final_result, classifier) | |
| # print(f"β Step 5 Complete: Tags added ({time.time() - p5_start:.2f}s)") | |
| # else: | |
| # print("β οΈ WARNING: Classifier models failed to load. Skipping tags.") | |
| # except Exception as e: | |
| # print(f"\nβΌοΈ FATAL PIPELINE EXCEPTION:") | |
| # print(f"Error Message: {str(e)}") | |
| # traceback.print_exc() | |
| # return None | |
| # finally: | |
| # print(f"\nCleaning up temporary files in {temp_pipeline_dir}...") | |
| # try: | |
| # for f in glob.glob(os.path.join(temp_pipeline_dir, '*')): | |
| # os.remove(f) | |
| # os.rmdir(temp_pipeline_dir) | |
| # print("π§Ή Cleanup successful.") | |
| # except Exception as e: | |
| # print(f"β οΈ Cleanup failed: {e}") | |
| # total_time = time.time() - overall_start | |
| # print("\n" + "#" * 80) | |
| # print(f"### PIPELINE COMPLETE | Total Time: {total_time:.2f}s ###") | |
| # print("#" * 80) | |
| # return final_result | |
def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, structured_intermediate_output_path: Optional[str] = None) -> Optional[List[Dict[str, Any]]]:
    """Run the full document-analysis pipeline on a single PDF.

    Steps, in order:
      1. ``run_single_pdf_preprocessing`` -> preprocessed JSON.
      2. ``run_inference_and_get_raw_words`` (LayoutLMv3) -> raw predictions,
         dumped to a JSON file.
      3. ``convert_bio_to_structured_json_relaxed``, then
         ``correct_misaligned_options`` and ``process_context_linking``.
      4. ``embed_images_as_base64_in_memory`` using ``FIGURE_EXTRACTION_DIR``.
      4.5. ``add_question_type_validation``.
      5. ``HierarchicalClassifier`` tagging via ``post_process_json_with_inference``
         (skipped with a warning if ``load_models()`` returns a falsy value).
      Finally, items whose ``'type'`` is ``'METADATA'`` are dropped.

    Intermediate files live in a fresh temporary directory created for this
    call; the directory is removed before the function returns.

    Args:
        input_pdf_path: Path of the PDF to analyse.
        layoutlmv3_model_path: Model path forwarded to the inference step.
        structured_intermediate_output_path: Output path for the decoding step's
            structured JSON. When ``None`` it is placed inside the temporary
            directory and is deleted with it.

    Returns:
        The final list of item dicts, or ``None`` when the input file does not
        exist, a step returned no data, or any step raised an exception.
    """
    # Local import: the module's top-of-file imports do not include traceback,
    # and a NameError raised inside the except-handler would otherwise escape
    # this function instead of letting it return None.
    import traceback

    if not os.path.exists(input_pdf_path):
        print(f"β ERROR: File not found: {input_pdf_path}")
        return None
    print("\n" + "#" * 80)
    print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###")
    print(f"Input: {input_pdf_path}")
    print("#" * 80)
    overall_start = time.time()
    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    # mkdtemp creates a unique, empty directory per call. A name built only from
    # the PDF name and the PID would be shared (and then deleted) by two
    # concurrent runs of the same PDF in one process.
    temp_pipeline_dir = tempfile.mkdtemp(prefix=f"pipeline_run_{pdf_name}_")
    preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
    if structured_intermediate_output_path is None:
        structured_intermediate_output_path = os.path.join(
            temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json"
        )
    final_result = None
    try:
        # --- Phase 1: Preprocessing ---
        print("\n[Step 1/5] Preprocessing (YOLO + Masking)...")
        p1_start = time.time()
        preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
        if not preprocessed_json_path_out:
            print("β FAILED at Step 1: Preprocessing returned None.")
            return None
        print(f"β Step 1 Complete ({time.time() - p1_start:.2f}s)")

        # --- Phase 2: Inference ---
        print("\n[Step 2/5] Inference (LayoutLMv3)...")
        p2_start = time.time()
        page_raw_predictions_list = run_inference_and_get_raw_words(
            input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out
        )
        if not page_raw_predictions_list:
            print("β FAILED at Step 2: Inference returned no data.")
            return None
        # The decoding step reads its input from a file, so persist predictions.
        with open(raw_output_path, 'w', encoding='utf-8') as raw_fh:
            json.dump(page_raw_predictions_list, raw_fh, indent=4)
        print(f"β Step 2 Complete ({time.time() - p2_start:.2f}s)")

        # --- Phase 3: Decoding ---
        print("\n[Step 3/5] Decoding (BIO to Structured JSON)...")
        p3_start = time.time()
        structured_data_list = convert_bio_to_structured_json_relaxed(
            raw_output_path, structured_intermediate_output_path
        )
        if not structured_data_list:
            print("β FAILED at Step 3: BIO conversion failed.")
            return None
        print("... Correcting misalignments and linking context ...")
        structured_data_list = correct_misaligned_options(structured_data_list)
        structured_data_list = process_context_linking(structured_data_list)
        print(f"β Step 3 Complete ({time.time() - p3_start:.2f}s)")

        # --- Phase 4: Base64 & LaTeX ---
        print("\n[Step 4/5] Finalizing Layout (Base64 Images & LaTeX)...")
        p4_start = time.time()
        final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR)
        if not final_result:
            print("β FAILED at Step 4: Final formatting failed.")
            return None
        print(f"β Step 4 Complete ({time.time() - p4_start:.2f}s)")

        # --- Phase 4.5: Question Type Classification ---
        print("\n[Step 4.5/5] Adding Question Type Classification...")
        p4_5_start = time.time()
        final_result = add_question_type_validation(final_result)
        print(f"β Step 4.5 Complete ({time.time() - p4_5_start:.2f}s)")

        # --- Phase 5: Hierarchical Tagging (best-effort) ---
        print("\n[Step 5/5] AI Classification (Subject/Concept Tagging)...")
        p5_start = time.time()
        classifier = HierarchicalClassifier()
        if classifier.load_models():
            final_result = post_process_json_with_inference(final_result, classifier)
            print(f"β Step 5 Complete: Tags added ({time.time() - p5_start:.2f}s)")
        else:
            print("β οΈ WARNING: Classifier models failed to load. Skipping tags.")

        # --- Post-processing: drop METADATA entries from the output ---
        print("\n[Post-Processing] Removing METADATA entries...")
        initial_count = len(final_result)
        final_result = [item for item in final_result if item.get('type') != 'METADATA']
        removed_count = initial_count - len(final_result)
        print(f"β Removed {removed_count} METADATA entries. {len(final_result)} questions remain.")
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and signal it
        # to the caller via None rather than propagating.
        print("\nβΌοΈ FATAL PIPELINE EXCEPTION:")
        print(f"Error Message: {str(e)}")
        traceback.print_exc()
        return None
    finally:
        print(f"\nCleaning up temporary files in {temp_pipeline_dir}...")
        try:
            # rmtree also removes subdirectories and dotfiles; per-file removal
            # followed by os.rmdir failed whenever a step left either behind.
            shutil.rmtree(temp_pipeline_dir)
            print("π§Ή Cleanup successful.")
        except OSError as e:
            # Cleanup is best-effort: report it but never mask the result.
            print(f"β οΈ Cleanup failed: {e}")
    total_time = time.time() - overall_start
    print("\n" + "#" * 80)
    print(f"### PIPELINE COMPLETE | Total Time: {total_time:.2f}s ###")
    print("#" * 80)
    return final_result
if __name__ == "__main__":
    # CLI entry point: analyse one PDF and write the result next to the
    # current working directory as <pdf-stem>_final_output_embedded.json.
    parser = argparse.ArgumentParser(description="Complete Pipeline")
    parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF")
    parser.add_argument(
        "--layoutlmv3_model_path",
        type=str,
        default=DEFAULT_LAYOUTLMV3_MODEL_PATH,
        help="Model Path",
    )
    # Debug option: nothing in this entry point reads it. It is still declared
    # so that existing invocations passing it continue to parse.
    parser.add_argument(
        "--raw_preds_path",
        type=str,
        default='BIO_debug.json',
        help="Debug path for raw BIO tag predictions (JSON).",
    )
    args = parser.parse_args()

    pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0]
    final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json")

    final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path)

    # Guard clause: None or an empty list counts as failure.
    if not final_json_data:
        print("\nβ Pipeline Failed.")
        sys.exit(1)

    # Standard JSON serialization: each backslash in a LaTeX string is written
    # as a JSON escape, so json.load restores the exact in-memory text.
    json_str = json.dumps(final_json_data, indent=2, ensure_ascii=False)
    with open(final_output_path, 'w', encoding='utf-8') as f:
        f.write(json_str)
    print(f"\nβ Final Data Saved: {final_output_path}")