| # import base64 | |
| # from PIL import Image | |
| # import re | |
| # import fitz # PyMuPDF | |
| # import numpy as np | |
| # import cv2 | |
| # import torch | |
| # import torch.serialization | |
| # import os | |
| # import time | |
| # from typing import Optional, Tuple, List, Dict, Any | |
| # from ultralytics import YOLO | |
| # import logging | |
| # import gradio as gr | |
| # import shutil | |
| # import tempfile | |
| # import io | |
| # # ============================================================================ | |
| # # --- Global Patches and Setup --- | |
| # # ============================================================================ | |
| # # Patch torch.load to prevent weights_only error with older models | |
| # _original_torch_load = torch.load | |
| # def patched_torch_load(*args, **kwargs): | |
| # kwargs["weights_only"] = False | |
| # return _original_torch_load(*args, **kwargs) | |
| # torch.load = patched_torch_load | |
| # logging.basicConfig(level=logging.WARNING) | |
| # # ============================================================================ | |
| # # --- CONFIGURATION AND CONSTANTS --- | |
| # # ============================================================================ | |
| # WEIGHTS_PATH = 'best.pt' | |
| # SCALE_FACTOR = 2.0 | |
| # # OUTPUT_DIR = "yolo_extracted_regions" | |
| # # OUTPUT_DIR = os.path.join(tempfile.gettempdir(), "yolo_extracted_regions") | |
| # from transformers import TrOCRProcessor | |
| # from optimum.onnxruntime import ORTModelForVision2Seq | |
| # MODEL_NAME = 'breezedeus/pix2text-mfr-1.5' | |
| # processor = TrOCRProcessor.from_pretrained(MODEL_NAME) | |
| # ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False) | |
| # # Detection parameters | |
| # CONF_THRESHOLD = 0.2 | |
| # TARGET_CLASSES = ['figure', 'equation'] | |
| # IOU_MERGE_THRESHOLD = 0.4 | |
| # IOA_SUPPRESSION_THRESHOLD = 0.7 | |
| # # Global counters (Reset per run) | |
| # GLOBAL_FIGURE_COUNT = 0 | |
| # GLOBAL_EQUATION_COUNT = 0 | |
| # # ============================================================================ | |
| # # --- BOX COMBINATION LOGIC (Retained for detection accuracy) --- | |
| # # ============================================================================ | |
| # def calculate_iou(box1, box2): | |
| # x1_a, y1_a, x2_a, y2_a = box1 | |
| # x1_b, y1_b, x2_b, y2_b = box2 | |
| # x_left = max(x1_a, x1_b) | |
| # y_top = max(y1_a, y1_b) | |
| # x_right = min(x2_a, x2_b) | |
| # y_bottom = min(y2_a, y2_b) | |
| # intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) | |
| # box_a_area = (x2_a - x1_a) * (y2_a - y1_a) | |
| # box_b_area = (x2_b - x1_b) * (y2_b - y1_b) | |
| # union_area = float(box_a_area + box_b_area - intersection_area) | |
| # return intersection_area / union_area if union_area > 0 else 0 | |
| # def filter_nested_boxes(detections, ioa_threshold=0.80): | |
| # if not detections: return [] | |
| # for d in detections: | |
| # x1, y1, x2, y2 = d['coords'] | |
| # d['area'] = (x2 - x1) * (y2 - y1) | |
| # detections.sort(key=lambda x: x['area'], reverse=True) | |
| # keep_indices = [] | |
| # is_suppressed = [False] * len(detections) | |
| # for i in range(len(detections)): | |
| # if is_suppressed[i]: continue | |
| # keep_indices.append(i) | |
| # box_a = detections[i]['coords'] | |
| # for j in range(i + 1, len(detections)): | |
| # if is_suppressed[j]: continue | |
| # box_b = detections[j]['coords'] | |
| # x_left = max(box_a[0], box_b[0]) | |
| # y_top = max(box_a[1], box_b[1]) | |
| # x_right = min(box_a[2], box_b[2]) | |
| # y_bottom = min(box_a[3], box_b[3]) | |
| # intersection = max(0, x_right - x_left) * max(0, y_bottom - y_top) | |
| # area_b = detections[j]['area'] | |
| # if area_b > 0 and intersection / area_b > ioa_threshold: | |
| # is_suppressed[j] = True | |
| # return [detections[i] for i in keep_indices] | |
| # def merge_overlapping_boxes(detections, iou_threshold): | |
| # if not detections: return [] | |
| # detections.sort(key=lambda d: d['conf'], reverse=True) | |
| # merged_detections = [] | |
| # is_merged = [False] * len(detections) | |
| # for i in range(len(detections)): | |
| # if is_merged[i]: continue | |
| # current_box = detections[i]['coords'] | |
| # current_class = detections[i]['class'] | |
| # merged_x1, merged_y1, merged_x2, merged_y2 = current_box | |
| # for j in range(i + 1, len(detections)): | |
| # if is_merged[j] or detections[j]['class'] != current_class: continue | |
| # other_box = detections[j]['coords'] | |
| # iou = calculate_iou(current_box, other_box) | |
| # if iou > iou_threshold: | |
| # merged_x1 = min(merged_x1, other_box[0]) | |
| # merged_y1 = min(merged_y1, other_box[1]) | |
| # merged_x2 = max(merged_x2, other_box[2]) | |
| # merged_y2 = max(merged_y2, other_box[3]) | |
| # is_merged[j] = True | |
| # merged_detections.append({ | |
| # 'coords': (merged_x1, merged_y1, merged_x2, merged_y2), | |
| # 'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf'] | |
| # }) | |
| # return merged_detections | |
| # # ============================================================================ | |
| # # --- UTILITY FUNCTIONS --- | |
| # # ============================================================================ | |
| # def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray: | |
| # """Converts a PyMuPDF Pixmap to a NumPy array for OpenCV/YOLO.""" | |
| # img = np.frombuffer(pix.samples, dtype=np.uint8).reshape( | |
| # (pix.h, pix.w, pix.n) | |
| # ) | |
| # if pix.n == 4: | |
| # img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB) | |
| # elif pix.n == 1: | |
| # img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) | |
| # return img | |
| # def run_yolo_detection_and_count( | |
| # image: np.ndarray, model: YOLO, page_num: int | |
| # ) -> Tuple[int, int, List[Dict[str, str]]]: | |
| # global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT | |
| # yolo_detections = [] | |
| # page_equations = 0 | |
| # page_figures = 0 | |
| # detected_items = [] | |
| # try: | |
| # results = model.predict(image, conf=CONF_THRESHOLD, verbose=False) | |
| # if results and results[0].boxes: | |
| # for box in results[0].boxes.data.tolist(): | |
| # x1, y1, x2, y2, conf, cls_id = box | |
| # cls_name = model.names[int(cls_id)] | |
| # if cls_name in TARGET_CLASSES: | |
| # yolo_detections.append({ | |
| # 'coords': (x1, y1, x2, y2), | |
| # 'class': cls_name, | |
| # 'conf': conf | |
| # }) | |
| # except Exception as e: | |
| # logging.error(f"YOLO inference failed on page {page_num}: {e}") | |
| # return 0, 0, [] | |
| # merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD) | |
| # final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD) | |
| # for det in final_detections: | |
| # bbox = det["coords"] | |
| # if det["class"] == "equation": | |
| # GLOBAL_EQUATION_COUNT += 1 | |
| # page_equations += 1 | |
| # b64 = crop_and_convert_to_base64(image, bbox) | |
| # detected_items.append({ | |
| # "type": "equation", | |
| # "id": f"EQUATION{GLOBAL_EQUATION_COUNT}", | |
| # "base64": b64 | |
| # }) | |
| # elif det["class"] == "figure": | |
| # GLOBAL_FIGURE_COUNT += 1 | |
| # page_figures += 1 | |
| # b64 = crop_and_convert_to_base64(image, bbox) | |
| # detected_items.append({ | |
| # "type": "figure", | |
| # "id": f"FIGURE{GLOBAL_FIGURE_COUNT}", | |
| # "base64": b64 | |
| # }) | |
| # logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}") | |
| # return page_equations, page_figures, detected_items | |
| # def get_latex_from_base64(base64_string: str) -> str: | |
| # if ort_model is None or processor is None: | |
| # return "[MODEL_ERROR: Model not initialized]" | |
| # try: | |
| # image_data = base64.b64decode(base64_string) | |
| # image = Image.open(io.BytesIO(image_data)).convert('RGB') | |
| # pixel_values = processor(images=image, return_tensors="pt").pixel_values | |
| # generated_ids = ort_model.generate(pixel_values) | |
| # raw_text = processor.batch_decode(generated_ids, skip_special_tokens=True) | |
| # if not raw_text: | |
| # return "[OCR_WARNING: No formula found]" | |
| # latex = raw_text[0] | |
| # latex = re.sub(r'[\r\n]+', '', latex) | |
| # return latex | |
| # except Exception as e: | |
| # return f"[TR_OCR_ERROR: {e}]" | |
| # def extract_images_from_page_in_memory(page) -> Dict[str, str]: | |
| # """ | |
| # Extract images from a page and return: | |
| # { "EQUATION1": base64_string, "FIGURE1": base64_string } | |
| # """ | |
| # image_map = {} | |
| # image_list = page.get_images(full=True) | |
| # for idx, img in enumerate(image_list, start=1): | |
| # xref = img[0] | |
| # base = page.parent.extract_image(xref) | |
| # image_bytes = base["image"] | |
| # base64_img = base64.b64encode(image_bytes).decode("utf-8") | |
| # # Convention: first image = FIGURE1, second image = EQUATION1 etc | |
| # # You can tune this if needed | |
| # image_map[f"FIGURE{idx}"] = base64_img | |
| # return image_map | |
| # def embed_images_as_base64_in_memory(structured_data, detected_items): | |
| # tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE) | |
| # item_lookup = {d["id"]: d for d in detected_items} | |
| # final_data = [] | |
| # for item in structured_data: | |
| # text_fields = [ | |
| # item.get('question', ''), | |
| # item.get('passage', ''), | |
| # item.get('new_passage', '') | |
| # ] | |
| # if 'options' in item: | |
| # text_fields.extend(item['options'].values()) | |
| # used_tags = set() | |
| # for text in text_fields: | |
| # for m in tag_regex.finditer(text or ""): | |
| # used_tags.add(m.group(0).upper()) | |
| # for tag in used_tags: | |
| # base_key = tag.lower().replace(" ", "") | |
| # if tag not in item_lookup: | |
| # item[base_key] = "[MISSING_IMAGE]" | |
| # continue | |
| # entry = item_lookup[tag] | |
| # if entry["type"] == "equation": | |
| # item[base_key] = get_latex_from_base64(entry["base64"]) | |
| # else: | |
| # item[base_key] = entry["base64"] | |
| # final_data.append(item) | |
| # return final_data | |
| # def crop_and_convert_to_base64(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> str: | |
| # x1, y1, x2, y2 = map(int, bbox) | |
| # h, w, _ = image.shape | |
| # x1 = max(0, x1) | |
| # y1 = max(0, y1) | |
| # x2 = min(w, x2) | |
| # y2 = min(h, y2) | |
| # crop = image[y1:y2, x1:x2] | |
| # _, buffer = cv2.imencode(".png", crop) | |
| # return base64.b64encode(buffer).decode("utf-8") | |
| # # ============================================================================ | |
| # # --- MAIN DOCUMENT PROCESSING FUNCTION (Fixed for JSON serialization) --- | |
| # # ============================================================================ | |
| # # NOTE: The return signature now uses Dict[str, int] for the equation counts | |
| # def run_single_pdf_preprocessing(pdf_path: str) -> Tuple[int, int, int, str, float, Dict[str, int], List[str]]: | |
| # """ | |
| # Runs the pipeline, returns counts, report, total time, page counts dict (str keys), and empty list. | |
| # """ | |
| # global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT | |
| # start_time = time.time() | |
| # log_messages = [] | |
| # all_saved_images = [] | |
| # all_base64_images: List[str] = [] | |
| # # Dictionary to store {page_number (int): equation_count (int)} | |
| # equation_counts_per_page: Dict[int, int] = {} | |
| # # Reset globals | |
| # GLOBAL_FIGURE_COUNT = 0 | |
| # GLOBAL_EQUATION_COUNT = 0 | |
| # # if os.path.exists(OUTPUT_DIR): | |
| # # shutil.rmtree(OUTPUT_DIR) | |
| # # os.makedirs(OUTPUT_DIR, exist_ok=True) | |
| # # 1. Validation and Model Loading | |
| # t0 = time.time() | |
| # if not os.path.exists(pdf_path): | |
| # report = f"β FATAL ERROR: Input PDF not found at {pdf_path}." | |
| # return 0, 0, 0, report, time.time() - start_time, {}, [] | |
| # try: | |
| # model = YOLO(WEIGHTS_PATH) | |
| # logging.warning(f"β Loaded YOLO model from: {WEIGHTS_PATH}") | |
| # except Exception as e: | |
| # report = f"β ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)" | |
| # return 0, 0, 0, report, time.time() - start_time, {}, [] | |
| # t1 = time.time() | |
| # log_messages.append(f"Model Loading Time: {t1-t0:.4f}s") | |
| # # 2. PDF Loading | |
| # t2 = time.time() | |
| # try: | |
| # doc = fitz.open(pdf_path) | |
| # total_pages = doc.page_count | |
| # logging.warning(f"β Opened PDF with {doc.page_count} pages") | |
| # except Exception as e: | |
| # report = f"β ERROR loading PDF file: {e}" | |
| # return 0, 0, 0, report, time.time() - start_time, {}, [] | |
| # t3 = time.time() | |
| # log_messages.append(f"PDF Initialization Time: {t3-t2:.4f}s") | |
| # mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR) | |
| # # 3. Page Processing and Detection Loop | |
| # t4 = time.time() | |
| # for page_num_0_based in range(doc.page_count): | |
| # page_start_time = time.time() | |
| # fitz_page = doc.load_page(page_num_0_based) | |
| # page_num = page_num_0_based + 1 | |
| # # Render page to image for YOLO | |
| # try: | |
| # pix_start = time.time() | |
| # pix = fitz_page.get_pixmap(matrix=mat) | |
| # original_img = pixmap_to_numpy(pix) | |
| # pix_time = time.time() - pix_start | |
| # except Exception as e: | |
| # logging.error(f"Error converting page {page_num} to image: {e}. Skipping.") | |
| # continue | |
| # # Core Detection | |
| # detect_start = time.time() | |
| # # page_equations, _ = run_yolo_detection_and_count(original_img, model, page_num) | |
| # page_equations, _, page_images = run_yolo_detection_and_count(original_img, model, page_num) | |
| # all_saved_images.extend(page_images) | |
| # detect_time = time.time() - detect_start | |
| # # Store the count in the dictionary (INT keys) | |
| # equation_counts_per_page[page_num] = page_equations | |
| # page_total_time = time.time() - page_start_time | |
| # log_messages.append(f"Page {page_num} Time: Total={page_total_time:.4f}s (Render={pix_time:.4f}s, Detect={detect_time:.4f}s)") | |
| # doc.close() | |
| # t5 = time.time() | |
| # detection_loop_time = t5 - t4 | |
| # log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {detection_loop_time:.4f}s") | |
| # # FIX APPLIED HERE: Convert integer keys to string keys for JSON serialization | |
| # equation_counts_per_page_str_keys: Dict[str, int] = { | |
| # str(k): v for k, v in equation_counts_per_page.items() | |
| # } | |
| # # 4. Final Report Generation | |
| # total_execution_time = t5 - start_time | |
| # report = ( | |
| # f"β **YOLO Counting Complete!**\n\n" | |
| # f"**1) Total Pages Detected in PDF:** **{total_pages}**\n" | |
| # f"**2) Total Equations Detected:** **{GLOBAL_EQUATION_COUNT}**\n" | |
| # f"**3) Total Figures Detected:** **{GLOBAL_FIGURE_COUNT}**\n" | |
| # f"---\n" | |
| # f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n" | |
| # f"### Detailed Step Timing\n" | |
| # f"```\n" | |
| # + "\n".join(log_messages) + | |
| # f"\n```" | |
| # ) | |
| # # Return the dictionary with string keys | |
| # # return total_pages, GLOBAL_EQUATION_COUNT, GLOBAL_FIGURE_COUNT, report, total_execution_time, equation_counts_per_page_str_keys, [] | |
| # return total_pages, GLOBAL_EQUATION_COUNT, GLOBAL_FIGURE_COUNT, report, total_execution_time, equation_counts_per_page_str_keys, all_saved_images | |
| # # ============================================================================ | |
| # # --- GRADIO INTERFACE FUNCTION (Updated) --- | |
| # # ============================================================================ | |
| # def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[str]]: | |
| # """ | |
| # Gradio wrapper function to handle file upload and return results. | |
| # """ | |
| # if pdf_file is None: | |
| # # Return an empty dict with string keys | |
| # return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, [] | |
| # pdf_path = pdf_file.name | |
| # try: | |
| # # Unpack the new return value: equation_counts_per_page (with string keys) | |
| # # num_pages, num_equations, num_figures, report, total_time, equation_counts_per_page, _ = run_single_pdf_preprocessing( | |
| # # pdf_path | |
| # # ) | |
| # # num_pages, num_equations, num_figures, report, total_time, equation_counts_per_page, images = run_single_pdf_preprocessing(pdf_path) | |
| # num_pages, num_equations, num_figures, report, total_time, equation_counts_per_page, images = run_single_pdf_preprocessing(pdf_path) | |
| # # Return results (6 items now) | |
| # # return str(num_pages), str(num_equations), str(num_figures), report, equation_counts_per_page, [] | |
| # return str(num_pages), str(num_equations), str(num_figures), report, equation_counts_per_page, images | |
| # except Exception as e: | |
| # error_msg = f"An unexpected error occurred: {e}" | |
| # logging.error(error_msg, exc_info=True) | |
| # # Return an empty dict on error | |
| # return "Error", "Error", "Error", error_msg, {}, [] | |
| # # ============================================================================ | |
| # # --- GRADIO INTERFACE DEFINITION (Updated) --- | |
| # # ============================================================================ | |
| # if __name__ == "__main__": | |
| # if not os.path.exists(WEIGHTS_PATH): | |
| # logging.error(f"β FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.") | |
| # input_file = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"]) | |
| # # Outputs | |
| # output_pages = gr.Textbox(label="Total Pages in PDF", interactive=False) | |
| # output_equations = gr.Textbox(label="Total Equations Detected", interactive=False) | |
| # output_figures = gr.Textbox(label="Total Figures Detected", interactive=False) | |
| # output_report = gr.Markdown(label="Processing Summary and Timing") | |
| # # NEW OUTPUT: JSON component for structured data | |
| # output_page_counts = gr.JSON(label="Equation Count Per Page (Dictionary)") | |
| # # Gradio Gallery is retained but will receive an empty list [] | |
| # output_gallery = gr.Gallery( | |
| # label="Detected Equations (Disabled for Speed)", | |
| # columns=5, | |
| # height="auto", | |
| # object_fit="contain", | |
| # allow_preview=False | |
| # ) | |
| # interface = gr.Interface( | |
| # fn=gradio_process_pdf, | |
| # inputs=input_file, | |
| # # Outputs list remains the same, but the JSON component now receives string keys. | |
| # outputs=[ | |
| # output_pages, | |
| # output_equations, | |
| # output_figures, | |
| # output_report, | |
| # output_page_counts, | |
| # output_gallery | |
| # ], | |
| # title="π YOLO Counting with Per-Page Data & Timing", | |
| # description=( | |
| # "Upload a PDF to run YOLO detection. The results include total counts, a breakdown of " | |
| # "equation counts per page (in JSON format), and detailed timing." | |
| # ), | |
| # ) | |
| # print("\nStarting Gradio application...") | |
| # # interface.launch(inbrowser=True) | |
| # interface.launch( | |
| # inbrowser=True, | |
| # # allowed_paths=[OUTPUT_DIR] | |
| # ) | |
| import base64 | |
| from PIL import Image | |
| import re | |
| import fitz # PyMuPDF | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import torch.serialization | |
| import os | |
| import time | |
| from typing import Optional, Tuple, List, Dict, Any, Union | |
| from ultralytics import YOLO | |
| import logging | |
| import gradio as gr | |
| import shutil | |
| import tempfile | |
| import io | |
# ============================================================================
# --- Global Patches and Setup ---
# ============================================================================
# Patch torch.load to prevent weights_only error with older models: every call
# routed through torch.load after this point runs with weights_only=False.
# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects, and this override is process-wide -- it also overrides callers that
# pass weights_only=True explicitly. Only load checkpoints from trusted sources.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    All positional and keyword arguments are forwarded unchanged, except that
    any caller-supplied ``weights_only`` value is replaced by ``False``.
    """
    kwargs["weights_only"] = False
    return _original_torch_load(*args, **kwargs)


torch.load = patched_torch_load

# WARNING level is used because this module reports its progress messages
# through logging.warning.
logging.basicConfig(level=logging.WARNING)
# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
WEIGHTS_PATH = 'best.pt'
SCALE_FACTOR = 2.0

# --- OCR Model Initialization (Retained but not used in the main loop for counting) ---
MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
# Note: These models are kept global but unused in the main flow,
# as the user did not explicitly ask to remove the heavy OCR dependency yet.
# The third-party imports live inside the try block so that a missing
# transformers/optimum installation degrades to "OCR disabled"
# (processor = ort_model = None) instead of aborting the module import.
try:
    from transformers import TrOCRProcessor
    from optimum.onnxruntime import ORTModelForVision2Seq

    processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
    ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
except Exception as e:
    logging.warning(f"OCR model loading failed (expected if dependencies are missing): {e}")
    processor = None
    ort_model = None

# Detection parameters
CONF_THRESHOLD = 0.2
TARGET_CLASSES = ['figure', 'equation']
IOU_MERGE_THRESHOLD = 0.4
IOA_SUPPRESSION_THRESHOLD = 0.7
# --- REMOVED GLOBAL COUNTERS ---
# GLOBAL_FIGURE_COUNT = 0
# GLOBAL_EQUATION_COUNT = 0
| # ============================================================================ | |
| # --- BOX COMBINATION LOGIC (Retained) --- | |
| # ============================================================================ | |
def calculate_iou(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Args:
        box1: Box given as ``(x1, y1, x2, y2)``.
        box2: Box in the same layout.

    Returns:
        Intersection area divided by union area as a float, or ``0.0`` when
        the union area is not positive (for example two zero-area boxes).
    """
    # NOTE(review): the area formulas assume x1 <= x2 and y1 <= y2 -- confirm
    # that callers never pass flipped corners.
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2

    # Overlap extents are clamped at zero so disjoint boxes contribute nothing.
    overlap_w = max(0, min(x2_a, x2_b) - max(x1_a, x1_b))
    overlap_h = max(0, min(y2_a, y2_b) - max(y1_a, y1_b))
    intersection_area = overlap_w * overlap_h

    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    box_b_area = (x2_b - x1_b) * (y2_b - y1_b)
    union_area = float(box_a_area + box_b_area - intersection_area)
    return intersection_area / union_area if union_area > 0 else 0.0
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Drop detections that lie mostly inside a larger kept detection.

    Boxes are visited from largest to smallest area. Each box that has not
    been suppressed is kept, and every smaller box whose intersection with it
    covers more than ``ioa_threshold`` of the smaller box's own area is
    suppressed. The detection's ``'class'`` is not consulted, so a box can be
    suppressed by a larger box of a different class.

    Args:
        detections: List of dicts, each with a ``'coords'`` entry laid out as
            ``(x1, y1, x2, y2)``. The list is sorted in place by area
            (descending) and every dict gains an ``'area'`` key.
        ioa_threshold: Fraction of the smaller box that must be covered
            before it is suppressed (strictly greater-than comparison).

    Returns:
        The kept detection dicts (same objects as passed in), largest first.
        An empty input yields an empty list.
    """
    if not detections:
        return []

    for det in detections:
        x1, y1, x2, y2 = det['coords']
        det['area'] = (x2 - x1) * (y2 - y1)
    detections.sort(key=lambda d: d['area'], reverse=True)

    kept = []
    suppressed = [False] * len(detections)
    for i, outer in enumerate(detections):
        if suppressed[i]:
            continue
        kept.append(outer)
        ox1, oy1, ox2, oy2 = outer['coords']
        for j in range(i + 1, len(detections)):
            if suppressed[j]:
                continue
            inner = detections[j]
            ix1, iy1, ix2, iy2 = inner['coords']
            overlap_w = max(0, min(ox2, ix2) - max(ox1, ix1))
            overlap_h = max(0, min(oy2, iy2) - max(oy1, iy1))
            inner_area = inner['area']
            # Zero-area boxes are never suppressed (avoids division by zero).
            if inner_area > 0 and (overlap_w * overlap_h) / inner_area > ioa_threshold:
                suppressed[j] = True
    return kept
def merge_overlapping_boxes(detections, iou_threshold):
    """Merge same-class detections that overlap a higher-confidence box.

    Detections are visited in descending confidence order. Each detection
    that has not been absorbed yet becomes a seed; every later detection of
    the same class whose IoU with the seed box exceeds ``iou_threshold`` is
    absorbed, and the output box is the bounding rectangle of the seed and
    everything it absorbed.

    Args:
        detections: List of dicts with ``'coords'`` as ``(x1, y1, x2, y2)``,
            ``'class'`` and ``'conf'`` entries. The list is sorted in place
            by confidence (descending).
        iou_threshold: IoU above which (strictly) two boxes are merged.

    Returns:
        A list of new dicts with keys ``'coords'`` (merged rectangle),
        ``'y1'`` (its top edge), ``'class'`` and ``'conf'`` (both taken from
        the seed). An empty input yields an empty list.
    """
    if not detections:
        return []

    detections.sort(key=lambda d: d['conf'], reverse=True)
    merged_detections = []
    is_merged = [False] * len(detections)
    for i, seed in enumerate(detections):
        if is_merged[i]:
            continue
        seed_box = seed['coords']
        seed_class = seed['class']
        merged_x1, merged_y1, merged_x2, merged_y2 = seed_box
        for j in range(i + 1, len(detections)):
            candidate = detections[j]
            if is_merged[j] or candidate['class'] != seed_class:
                continue
            other_box = candidate['coords']
            # IoU is measured against the original seed box, not the growing
            # merged rectangle.
            if calculate_iou(seed_box, other_box) > iou_threshold:
                merged_x1 = min(merged_x1, other_box[0])
                merged_y1 = min(merged_y1, other_box[1])
                merged_x2 = max(merged_x2, other_box[2])
                # Grow the bottom edge from the accumulated value; comparing
                # other_box[3] with itself would discard the seed's (and any
                # earlier absorbed box's) bottom edge and could shrink the box.
                merged_y2 = max(merged_y2, other_box[3])
                is_merged[j] = True
        merged_detections.append({
            'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
            'y1': merged_y1, 'class': seed_class, 'conf': seed['conf']
        })
    return merged_detections
| # ============================================================================ | |
| # --- UTILITY FUNCTIONS --- | |
| # ============================================================================ | |
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    """Converts a PyMuPDF Pixmap to a NumPy array for OpenCV/YOLO.

    The pixmap's sample buffer is viewed as a ``(pix.h, pix.w, pix.n)`` uint8
    array. Four-channel input is converted with ``COLOR_RGBA2RGB`` and
    single-channel input with ``COLOR_GRAY2RGB``; any other channel count is
    returned as-is.

    Args:
        pix: Source pixmap; only ``samples``, ``h``, ``w`` and ``n`` are read.

    Returns:
        The image array, three channels in RGB order for 1-, 3- (presumably
        RGB -- confirm) and 4-channel inputs.
    """
    # NOTE(review): OpenCV encoders, and (per the Ultralytics docs) NumPy
    # inputs to YOLO, conventionally expect BGR channel order, while this
    # returns RGB -- confirm the intended colour contract.
    # NOTE(review): a 4-channel pixmap is assumed to be RGBA; a CMYK pixmap
    # would also report n == 4 -- confirm callers only render RGB pixmaps.
    img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
        (pix.h, pix.w, pix.n)
    )
    if pix.n == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    elif pix.n == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    return img
def crop_and_convert_to_base64(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> str:
    """Crop ``bbox`` out of ``image`` and return the crop as a base64 PNG string.

    Args:
        image: Image array whose first two axes are (height, width); a
            trailing channel axis is optional.
        bbox: ``(x1, y1, x2, y2)`` in pixel coordinates. Values are truncated
            to ints and clamped to the image bounds.

    Returns:
        The PNG-encoded crop as base64 text (no ``data:`` prefix).

    Raises:
        ValueError: If the clamped box covers no pixels or PNG encoding
            fails.
    """
    x1, y1, x2, y2 = map(int, bbox)
    h, w = image.shape[:2]
    x1 = max(0, x1)
    y1 = max(0, y1)
    x2 = min(w, x2)
    y2 = min(h, y2)
    # An empty slice would otherwise fail inside cv2.imencode with an opaque
    # OpenCV assertion; report the offending box instead.
    if x2 <= x1 or y2 <= y1:
        raise ValueError(
            f"Bounding box {tuple(bbox)} is empty after clamping to a {w}x{h} image"
        )
    crop = image[y1:y2, x1:x2]
    # NOTE(review): cv2.imencode treats 3-channel input as BGR. If callers
    # pass RGB data (pixmap_to_numpy converts to RGB), red and blue end up
    # swapped in the PNG -- confirm the intended channel order.
    ok, buffer = cv2.imencode(".png", crop)
    if not ok:
        raise ValueError("PNG encoding of the cropped region failed")
    return base64.b64encode(buffer).decode("utf-8")
| # --- NEW: Function to format base64 for Gradio Gallery --- | |
def base64_to_gradio_gallery_tuple(base64_str: str, label: str) -> Tuple[str, str]:
    """Converts raw base64 to a data URI tuple for Gradio Gallery.

    Args:
        base64_str: Base64 text of an image; it is embedded verbatim and
            labelled as ``image/png`` without being inspected.
        label: Caption paired with the image.

    Returns:
        ``('data:image/png;base64,<base64_str>', label)``.
    """
    return (f"data:image/png;base64,{base64_str}", label)
| # --- UPDATED: run_yolo_detection_and_count to use passed counters --- | |
def run_yolo_detection_and_count(
    image: np.ndarray, model: YOLO, page_num: int,
    current_eq_count: int, current_fig_count: int
) -> Tuple[int, int, List[Dict[str, str]], int, int]:
    """
    Performs YOLO detection on one page image and returns page counts,
    detected items, and the updated running counters.

    Raw detections whose class name is in TARGET_CLASSES are merged with
    merge_overlapping_boxes (IOU_MERGE_THRESHOLD) and then pruned with
    filter_nested_boxes (IOA_SUPPRESSION_THRESHOLD). Each surviving box is
    cropped from ``image`` and base64-encoded.

    Args:
        image: Page image passed to ``model.predict`` and used for cropping.
        model: Loaded YOLO model.
        page_num: Page number, used only in log messages.
        current_eq_count: Equations counted before this page.
        current_fig_count: Figures counted before this page.

    Returns:
        ``(page_equations, page_figures, detected_items, eq_counter,
        fig_counter)``. ``detected_items`` holds dicts with ``"type"``
        (``"equation"`` or ``"figure"``), ``"id"`` (``"EQUATION<n>"`` /
        ``"FIGURE<n>"``, numbered by the running counter) and ``"base64"``.
        If inference raises, the error is logged and
        ``(0, 0, [], current_eq_count, current_fig_count)`` is returned.
    """
    # Use the passed counters as starting points for this page
    eq_counter = current_eq_count
    fig_counter = current_fig_count
    page_equations = 0
    page_figures = 0
    detected_items = []
    yolo_detections = []
    try:
        results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
        if results and results[0].boxes:
            for box in results[0].boxes.data.tolist():
                x1, y1, x2, y2, conf, cls_id = box
                cls_name = model.names[int(cls_id)]
                if cls_name in TARGET_CLASSES:
                    yolo_detections.append({
                        'coords': (x1, y1, x2, y2),
                        'class': cls_name,
                        'conf': conf
                    })
    except Exception as e:
        logging.error(f"YOLO inference failed on page {page_num}: {e}")
        return 0, 0, [], eq_counter, fig_counter

    merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
    final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)

    for det in final_detections:
        kind = det["class"]
        if kind == "equation":
            eq_counter += 1
            page_equations += 1
            item_id = f"EQUATION{eq_counter}"
        elif kind == "figure":
            fig_counter += 1
            page_figures += 1
            item_id = f"FIGURE{fig_counter}"
        else:
            # Classes other than equation/figure are neither counted nor cropped.
            continue
        detected_items.append({
            "type": kind,
            "id": item_id,
            "base64": crop_and_convert_to_base64(image, det["coords"])
        })

    logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
    # Return page counts, detected items, and the UPDATED total counters
    return page_equations, page_figures, detected_items, eq_counter, fig_counter
| # --- Other unused functions (get_latex_from_base64, etc.) are kept but not modified as | |
| # the focus is on the concurrency and Gradio Gallery fix. --- | |
def get_latex_from_base64(base64_string: str) -> str:
    """Run the formula-OCR model on a base64-encoded image and return its text.

    This function never raises; failures are reported through placeholder
    strings so callers can embed the result directly:

    * ``"[MODEL_ERROR: Model not initialized]"`` when the module-level
      ``ort_model`` or ``processor`` is ``None``.
    * ``"[OCR_WARNING: No formula found]"`` when decoding yields no text.
    * ``"[TR_OCR_ERROR: <exception>]"`` when decoding or inference raises.

    Args:
        base64_string: Base64 text of an image readable by PIL.

    Returns:
        The first decoded sequence with all CR/LF characters removed, or one
        of the placeholder strings above.
    """
    if ort_model is None or processor is None:
        return "[MODEL_ERROR: Model not initialized]"
    try:
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        generated_ids = ort_model.generate(pixel_values)
        raw_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        if not raw_text:
            return "[OCR_WARNING: No formula found]"
        # Strip line breaks so the returned text is a single line.
        return re.sub(r'[\r\n]+', '', raw_text[0])
    except Exception as e:
        return f"[TR_OCR_ERROR: {e}]"
def embed_images_as_base64_in_memory(structured_data, detected_items):
    """Attach detected figure/equation payloads to the items that mention them.

    Each item's ``'question'``, ``'passage'``, ``'new_passage'`` and option
    texts are scanned (case-insensitively) for tags such as ``figure3`` or
    ``EQUATION12``. For every tag found, a key with the lower-cased tag name
    is added to the item:

    * ``"[MISSING_IMAGE]"`` when no detected item has that id,
    * the result of ``get_latex_from_base64`` for equation entries,
    * the raw base64 string for any other entry.

    Args:
        structured_data: Iterable of dicts; they are modified in place.
            ``'options'``, when present and truthy, is expected to be a dict
            whose values are the option texts.
        detected_items: Dicts with ``"id"`` (upper-case tag), ``"type"`` and
            ``"base64"`` entries.

    Returns:
        A list containing the same item objects, in input order.
    """
    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
    item_lookup = {d["id"]: d for d in detected_items}
    # OCR is the expensive step; run it once per equation tag per call even
    # if several items reference the same equation.
    latex_by_tag = {}
    final_data = []
    for item in structured_data:
        text_fields = [
            item.get('question', ''),
            item.get('passage', ''),
            item.get('new_passage', '')
        ]
        # A missing or None 'options' entry contributes no text.
        options = item.get('options') or {}
        text_fields.extend(options.values())

        used_tags = set()
        for text in text_fields:
            # None and other non-string fields are skipped.
            if isinstance(text, str):
                used_tags.update(m.group(0).upper() for m in tag_regex.finditer(text))

        # Sorted so the order in which keys are added to the item is
        # deterministic across runs.
        for tag in sorted(used_tags):
            base_key = tag.lower().replace(" ", "")
            entry = item_lookup.get(tag)
            if entry is None:
                item[base_key] = "[MISSING_IMAGE]"
            elif entry["type"] == "equation":
                if tag not in latex_by_tag:
                    latex_by_tag[tag] = get_latex_from_base64(entry["base64"])
                item[base_key] = latex_by_tag[tag]
            else:
                item[base_key] = entry["base64"]
        final_data.append(item)
    return final_data
| # ============================================================================ | |
| # --- MAIN DOCUMENT PROCESSING FUNCTION (Fixed for concurrency) --- | |
| # ============================================================================ | |
| # --- UPDATED return type for clarity --- | |
| def run_single_pdf_preprocessing( | |
| pdf_path: str | |
| ) -> Tuple[int, int, int, str, float, Dict[str, int], List[Tuple[str, str]]]: | |
| """ | |
| Runs the pipeline, returns counts, report, total time, page counts dict (str keys), | |
| and a list of (image_data_uri, label) for the Gradio gallery. | |
| """ | |
| # --- INITIALIZE LOCAL COUNTERS --- | |
| start_time = time.time() | |
| log_messages = [] | |
| # This list now holds (data_uri, label) tuples for Gradio | |
| all_gradio_gallery_items: List[Tuple[str, str]] = [] | |
| # Dictionary to store {page_number (int): equation_count (int)} | |
| equation_counts_per_page: Dict[int, int] = {} | |
| # --- USE LOCAL COUNTERS FOR THREAD SAFETY --- | |
| total_figure_count = 0 | |
| total_equation_count = 0 | |
| # 1. Validation and Model Loading | |
| t0 = time.time() | |
| if not os.path.exists(pdf_path): | |
| report = f"β FATAL ERROR: Input PDF not found at {pdf_path}." | |
| # Return empty list of tuples for gallery on error | |
| return 0, 0, 0, report, time.time() - start_time, {}, [] | |
| try: | |
| model = YOLO(WEIGHTS_PATH) | |
| logging.warning(f"β Loaded YOLO model from: {WEIGHTS_PATH}") | |
| except Exception as e: | |
| report = f"β ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)" | |
| return 0, 0, 0, report, time.time() - start_time, {}, [] | |
| t1 = time.time() | |
| log_messages.append(f"Model Loading Time: {t1-t0:.4f}s") | |
| # 2. PDF Loading | |
| t2 = time.time() | |
| try: | |
| doc = fitz.open(pdf_path) | |
| total_pages = doc.page_count | |
| logging.warning(f"β Opened PDF with {doc.page_count} pages") | |
| except Exception as e: | |
| report = f"β ERROR loading PDF file: {e}" | |
| return 0, 0, 0, report, time.time() - start_time, {}, [] | |
| t3 = time.time() | |
| log_messages.append(f"PDF Initialization Time: {t3-t2:.4f}s") | |
| mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR) | |
| # 3. Page Processing and Detection Loop | |
| t4 = time.time() | |
| for page_num_0_based in range(doc.page_count): | |
| page_start_time = time.time() | |
| fitz_page = doc.load_page(page_num_0_based) | |
| page_num = page_num_0_based + 1 | |
| # Render page to image for YOLO | |
| try: | |
| pix_start = time.time() | |
| pix = fitz_page.get_pixmap(matrix=mat) | |
| original_img = pixmap_to_numpy(pix) | |
| pix_time = time.time() - pix_start | |
| except Exception as e: | |
| logging.error(f"Error converting page {page_num} to image: {e}. Skipping.") | |
| continue | |
| # Core Detection | |
| detect_start = time.time() | |
| # --- PASSING AND RECEIVING THE COUNTERS HERE (Concurrency Fix) --- | |
| ( | |
| page_equations, | |
| page_figures, | |
| page_images_dicts, | |
| total_equation_count, | |
| total_figure_count | |
| ) = run_yolo_detection_and_count( | |
| original_img, | |
| model, | |
| page_num, | |
| total_equation_count, | |
| total_figure_count | |
| ) | |
| # --- FORMATTING FOR GRADIO GALLERY (Gradio Format Fix) --- | |
| for item in page_images_dicts: | |
| gradio_tuple = base64_to_gradio_gallery_tuple(item["base64"], item["id"]) | |
| all_gradio_gallery_items.append(gradio_tuple) | |
| detect_time = time.time() - detect_start | |
| # Store the count in the dictionary (INT keys) | |
| equation_counts_per_page[page_num] = page_equations | |
| page_total_time = time.time() - page_start_time | |
| log_messages.append(f"Page {page_num} Time: Total={page_total_time:.4f}s (Render={pix_time:.4f}s, Detect={detect_time:.4f}s)") | |
| doc.close() | |
| t5 = time.time() | |
| detection_loop_time = t5 - t4 | |
| log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {detection_loop_time:.4f}s") | |
| # Convert integer keys to string keys for JSON serialization | |
| equation_counts_per_page_str_keys: Dict[str, int] = { | |
| str(k): v for k, v in equation_counts_per_page.items() | |
| } | |
| # 4. Final Report Generation | |
| total_execution_time = t5 - start_time | |
| report = ( | |
| f"β **YOLO Counting Complete!**\n\n" | |
| f"**1) Total Pages Detected in PDF:** **{total_pages}**\n" | |
| f"**2) Total Equations Detected:** **{total_equation_count}**\n" # Uses local final count | |
| f"**3) Total Figures Detected:** **{total_figure_count}**\n" # Uses local final count | |
| f"---\n" | |
| f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n" | |
| f"### Detailed Step Timing\n" | |
| f"```\n" | |
| + "\n".join(log_messages) + | |
| f"\n```" | |
| ) | |
| # Return the dictionary with string keys and the properly formatted gallery items | |
| return total_pages, total_equation_count, total_figure_count, report, total_execution_time, equation_counts_per_page_str_keys, all_gradio_gallery_items | |
| # ============================================================================ | |
| # --- GRADIO INTERFACE FUNCTION (Updated) --- | |
| # ============================================================================ | |
| # --- UPDATED return type for clarity --- | |
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[Tuple[str, str]]]:
    """
    Gradio wrapper: run the PDF pipeline on an upload and shape the outputs.

    Args:
        pdf_file: The upload as delivered by Gradio. With
            ``gr.File(type="filepath")`` this is a plain path string; a
            ``os.PathLike`` or a file-like object exposing ``.name`` is
            accepted as well. ``None`` means nothing was uploaded.

    Returns:
        A 6-tuple ``(pages, equations, figures, report, per_page_counts,
        gallery_items)``. The three counts are returned as strings. When no
        file was supplied they are ``"N/A"``; when processing raises they are
        ``"Error"`` and ``report`` carries the error message. In both of those
        cases the per-page dict and the gallery list are empty.
    """
    if pdf_file is None:
        # Nothing uploaded: empty dict / empty gallery keep the output widgets valid.
        return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []

    try:
        # A "filepath" upload arrives as a str, which has no `.name`; only
        # file-like wrappers need the attribute lookup. Resolving the path
        # inside the try block means a malformed upload object is reported
        # through the normal error outputs instead of escaping the callback.
        if isinstance(pdf_file, (str, os.PathLike)):
            pdf_path = os.fspath(pdf_file)
        else:
            pdf_path = pdf_file.name

        (
            num_pages,
            num_equations,
            num_figures,
            report,
            _total_time,  # already embedded in `report`; not shown separately
            equation_counts_per_page,
            gallery_items,
        ) = run_single_pdf_preprocessing(pdf_path)

        return (
            str(num_pages),
            str(num_equations),
            str(num_figures),
            report,
            equation_counts_per_page,
            gallery_items,
        )
    except Exception as e:
        # Top-level UI boundary: log the traceback and surface the message to the user.
        error_msg = f"An unexpected error occurred: {e}"
        logging.error(error_msg, exc_info=True)
        return "Error", "Error", "Error", error_msg, {}, []
| # ============================================================================ | |
| # --- GRADIO INTERFACE DEFINITION (Updated) --- | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Script entry point: build the Gradio UI around `gradio_process_pdf` and launch it.
    # A missing weights file is only logged here; the interface is still started.
    weights_missing = not os.path.exists(WEIGHTS_PATH)
    if weights_missing:
        logging.error(f"β FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.")

    # --- Input: a single PDF upload, restricted to .pdf files ---
    input_file = gr.File(
        label="Upload PDF Document",
        type="filepath",
        file_types=[".pdf"],
    )

    # --- Outputs: three read-only counters, created in display order ---
    output_pages, output_equations, output_figures = (
        gr.Textbox(label=caption, interactive=False)
        for caption in (
            "Total Pages in PDF",
            "Total Equations Detected",
            "Total Figures Detected",
        )
    )
    output_report = gr.Markdown(label="Processing Summary and Timing")
    # JSON widget for the per-page equation counts dictionary.
    output_page_counts = gr.JSON(label="Equation Count Per Page (Dictionary)")
    # Gallery fed with the (image, label) items returned by the callback.
    gallery_options = dict(
        label="Detected Items (Gallery Format Fix Applied)",
        columns=5,
        height="auto",
        object_fit="contain",
        allow_preview=False,
    )
    output_gallery = gr.Gallery(**gallery_options)

    # Order must match the 6-tuple returned by `gradio_process_pdf`.
    output_components = [
        output_pages,
        output_equations,
        output_figures,
        output_report,
        output_page_counts,
        output_gallery,
    ]

    interface = gr.Interface(
        fn=gradio_process_pdf,
        inputs=input_file,
        outputs=output_components,
        title="π YOLO Counting with Per-Page Data & Timing (Concurrency Fix)",
        description="Upload a PDF to run YOLO detection. The concurrency bug and Gradio Gallery display error have been fixed.",
    )

    print("\nStarting Gradio application...")
    interface.launch(inbrowser=True)