Update app.py
Browse files
app.py
CHANGED
|
@@ -585,6 +585,520 @@
|
|
| 585 |
|
| 586 |
|
| 587 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 588 |
import base64
|
| 589 |
from PIL import Image
|
| 590 |
import re
|
|
@@ -603,14 +1117,12 @@ import io
|
|
| 603 |
import json
|
| 604 |
|
| 605 |
# ============================================================================
|
| 606 |
-
# --- Global Setup and Configuration ---
|
| 607 |
# ============================================================================
|
| 608 |
|
| 609 |
-
# Configure logging to write to a string buffer for display in the report
|
| 610 |
log_stream = io.StringIO()
|
| 611 |
logging.basicConfig(level=logging.WARNING, stream=log_stream, format='%(levelname)s:%(message)s')
|
| 612 |
|
| 613 |
-
# Patch torch.load to prevent weights_only error with older models
|
| 614 |
_original_torch_load = torch.load
|
| 615 |
def patched_torch_load(*args, **kwargs):
|
| 616 |
kwargs["weights_only"] = False
|
|
@@ -620,7 +1132,6 @@ torch.load = patched_torch_load
|
|
| 620 |
WEIGHTS_PATH = 'best.pt'
|
| 621 |
SCALE_FACTOR = 2.0
|
| 622 |
|
| 623 |
-
# --- OCR Model Initialization ---
|
| 624 |
from transformers import TrOCRProcessor
|
| 625 |
from optimum.onnxruntime import ORTModelForVision2Seq
|
| 626 |
|
|
@@ -635,14 +1146,13 @@ except Exception as e:
|
|
| 635 |
ort_model = None
|
| 636 |
OCR_MODEL_LOADED = False
|
| 637 |
|
| 638 |
-
# Detection parameters
|
| 639 |
CONF_THRESHOLD = 0.2
|
| 640 |
TARGET_CLASSES = ['figure', 'equation']
|
| 641 |
IOU_MERGE_THRESHOLD = 0.4
|
| 642 |
IOA_SUPPRESSION_THRESHOLD = 0.7
|
| 643 |
|
| 644 |
# ============================================================================
|
| 645 |
-
# --- BOX COMBINATION LOGIC (FIXED) ---
|
| 646 |
# ============================================================================
|
| 647 |
|
| 648 |
def calculate_iou(box1, box2):
|
|
@@ -685,9 +1195,9 @@ def filter_nested_boxes(detections, ioa_threshold=0.80):
|
|
| 685 |
return [detections[i] for i in keep_indices]
|
| 686 |
|
| 687 |
|
|
|
|
| 688 |
def merge_overlapping_boxes(detections, iou_threshold):
|
| 689 |
if not detections: return []
|
| 690 |
-
# 1. Sort by confidence (YOLO standard)
|
| 691 |
detections.sort(key=lambda d: d['conf'], reverse=True)
|
| 692 |
merged_detections = []
|
| 693 |
is_merged = [False] * len(detections)
|
|
@@ -709,16 +1219,15 @@ def merge_overlapping_boxes(detections, iou_threshold):
|
|
| 709 |
is_merged[j] = True
|
| 710 |
merged_detections.append({
|
| 711 |
'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
|
| 712 |
-
# 'y1' is retained for clarity, though 'coords' contains it
|
| 713 |
'y1': merged_y1,
|
| 714 |
'class': current_class,
|
| 715 |
'conf': detections[i]['conf']
|
| 716 |
})
|
| 717 |
|
| 718 |
-
# --- FIX IMPLEMENTATION: READING ORDER SORT ---
|
| 719 |
-
# Sort primarily by y1 (vertical position), secondarily by x1 (horizontal position).
|
| 720 |
-
# This correctly handles two-column layouts like Q.10 options (A), (B), (C), (D)
|
| 721 |
-
merged_detections.sort(key=lambda d:
|
| 722 |
|
| 723 |
return merged_detections
|
| 724 |
|
|
@@ -786,6 +1295,7 @@ def get_latex_from_base64(base64_string: str) -> str:
|
|
| 786 |
return f"[TR_OCR_ERROR: {e}]"
|
| 787 |
|
| 788 |
|
|
|
|
| 789 |
def run_yolo_detection_and_count(
|
| 790 |
image: np.ndarray, model: YOLO, page_num: int,
|
| 791 |
current_eq_count: int, current_fig_count: int
|
|
@@ -800,7 +1310,7 @@ def run_yolo_detection_and_count(
|
|
| 800 |
|
| 801 |
detected_items: List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]] = []
|
| 802 |
yolo_detections = []
|
| 803 |
-
|
| 804 |
try:
|
| 805 |
results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
|
| 806 |
if results and results[0].boxes:
|
|
@@ -817,10 +1327,11 @@ def run_yolo_detection_and_count(
|
|
| 817 |
logging.error(f"ERROR: YOLO inference failed on page {page_num}: {e}")
|
| 818 |
return [], eq_counter, fig_counter
|
| 819 |
|
|
|
|
| 820 |
merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
|
| 821 |
final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
|
| 822 |
|
| 823 |
-
# Note: final_detections is now sorted by
|
| 824 |
|
| 825 |
for det in final_detections:
|
| 826 |
bbox = det["coords"]
|
|
@@ -1070,9 +1581,9 @@ if __name__ == "__main__":
|
|
| 1070 |
output_structured_latex,
|
| 1071 |
output_gallery
|
| 1072 |
],
|
| 1073 |
-
title="π YOLO Detection & Math OCR Pipeline (
|
| 1074 |
description=(
|
| 1075 |
-
"Upload a PDF. YOLO detects equations/figures, and OCR converts equations to LaTeX.
|
| 1076 |
),
|
| 1077 |
)
|
| 1078 |
|
|
|
|
| 585 |
|
| 586 |
|
| 587 |
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# import base64
|
| 591 |
+
# from PIL import Image
|
| 592 |
+
# import re
|
| 593 |
+
# import fitz # PyMuPDF
|
| 594 |
+
# import numpy as np
|
| 595 |
+
# import cv2
|
| 596 |
+
# import torch
|
| 597 |
+
# import torch.serialization
|
| 598 |
+
# import os
|
| 599 |
+
# import time
|
| 600 |
+
# from typing import Optional, Tuple, List, Dict, Any, Union
|
| 601 |
+
# from ultralytics import YOLO
|
| 602 |
+
# import logging
|
| 603 |
+
# import gradio as gr
|
| 604 |
+
# import io
|
| 605 |
+
# import json
|
| 606 |
+
|
| 607 |
+
# # ============================================================================
|
| 608 |
+
# # --- Global Setup and Configuration ---
|
| 609 |
+
# # ============================================================================
|
| 610 |
+
|
| 611 |
+
# # Configure logging to write to a string buffer for display in the report
|
| 612 |
+
# log_stream = io.StringIO()
|
| 613 |
+
# logging.basicConfig(level=logging.WARNING, stream=log_stream, format='%(levelname)s:%(message)s')
|
| 614 |
+
|
| 615 |
+
# # Patch torch.load to prevent weights_only error with older models
|
| 616 |
+
# _original_torch_load = torch.load
|
| 617 |
+
# def patched_torch_load(*args, **kwargs):
|
| 618 |
+
# kwargs["weights_only"] = False
|
| 619 |
+
# return _original_torch_load(*args, **kwargs)
|
| 620 |
+
# torch.load = patched_torch_load
|
| 621 |
+
|
| 622 |
+
# WEIGHTS_PATH = 'best.pt'
|
| 623 |
+
# SCALE_FACTOR = 2.0
|
| 624 |
+
|
| 625 |
+
# # --- OCR Model Initialization ---
|
| 626 |
+
# from transformers import TrOCRProcessor
|
| 627 |
+
# from optimum.onnxruntime import ORTModelForVision2Seq
|
| 628 |
+
|
| 629 |
+
# MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
|
| 630 |
+
# try:
|
| 631 |
+
# processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
|
| 632 |
+
# ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
|
| 633 |
+
# OCR_MODEL_LOADED = True
|
| 634 |
+
# except Exception as e:
|
| 635 |
+
# logging.warning(f"OCR model loading failed: {e}")
|
| 636 |
+
# processor = None
|
| 637 |
+
# ort_model = None
|
| 638 |
+
# OCR_MODEL_LOADED = False
|
| 639 |
+
|
| 640 |
+
# # Detection parameters
|
| 641 |
+
# CONF_THRESHOLD = 0.2
|
| 642 |
+
# TARGET_CLASSES = ['figure', 'equation']
|
| 643 |
+
# IOU_MERGE_THRESHOLD = 0.4
|
| 644 |
+
# IOA_SUPPRESSION_THRESHOLD = 0.7
|
| 645 |
+
|
| 646 |
+
# # ============================================================================
|
| 647 |
+
# # --- BOX COMBINATION LOGIC (FIXED) ---
|
| 648 |
+
# # ============================================================================
|
| 649 |
+
|
| 650 |
+
# def calculate_iou(box1, box2):
|
| 651 |
+
# x1_a, y1_a, x2_a, y2_a = box1
|
| 652 |
+
# x1_b, y1_b, x2_b, y2_b = box2
|
| 653 |
+
# x_left = max(x1_a, x1_b)
|
| 654 |
+
# y_top = max(y1_a, y1_b)
|
| 655 |
+
# x_right = min(x2_a, x2_b)
|
| 656 |
+
# y_bottom = min(y2_a, y2_b)
|
| 657 |
+
# intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
|
| 658 |
+
# box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
|
| 659 |
+
# box_b_area = (x2_b - x1_b) * (y2_b - y1_b)
|
| 660 |
+
# union_area = float(box_a_area + box_b_area - intersection_area)
|
| 661 |
+
# return intersection_area / union_area if union_area > 0 else 0
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
# def filter_nested_boxes(detections, ioa_threshold=0.80):
|
| 665 |
+
# if not detections: return []
|
| 666 |
+
# for d in detections:
|
| 667 |
+
# x1, y1, x2, y2 = d['coords']
|
| 668 |
+
# d['area'] = (x2 - x1) * (y2 - y1)
|
| 669 |
+
# detections.sort(key=lambda x: x['area'], reverse=True)
|
| 670 |
+
# keep_indices = []
|
| 671 |
+
# is_suppressed = [False] * len(detections)
|
| 672 |
+
# for i in range(len(detections)):
|
| 673 |
+
# if is_suppressed[i]: continue
|
| 674 |
+
# keep_indices.append(i)
|
| 675 |
+
# box_a = detections[i]['coords']
|
| 676 |
+
# for j in range(i + 1, len(detections)):
|
| 677 |
+
# if is_suppressed[j]: continue
|
| 678 |
+
# box_b = detections[j]['coords']
|
| 679 |
+
# x_left = max(box_a[0], box_b[0])
|
| 680 |
+
# y_top = max(box_a[1], box_b[1])
|
| 681 |
+
# x_right = min(box_a[2], box_b[2])
|
| 682 |
+
# y_bottom = min(box_a[3], box_b[3])
|
| 683 |
+
# intersection = max(0, x_right - x_left) * max(0, y_bottom - y_top)
|
| 684 |
+
# area_b = detections[j]['area']
|
| 685 |
+
# if area_b > 0 and intersection / area_b > ioa_threshold:
|
| 686 |
+
# is_suppressed[j] = True
|
| 687 |
+
# return [detections[i] for i in keep_indices]
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
# def merge_overlapping_boxes(detections, iou_threshold):
|
| 691 |
+
# if not detections: return []
|
| 692 |
+
# # 1. Sort by confidence (YOLO standard)
|
| 693 |
+
# detections.sort(key=lambda d: d['conf'], reverse=True)
|
| 694 |
+
# merged_detections = []
|
| 695 |
+
# is_merged = [False] * len(detections)
|
| 696 |
+
|
| 697 |
+
# for i in range(len(detections)):
|
| 698 |
+
# if is_merged[i]: continue
|
| 699 |
+
# current_box = detections[i]['coords']
|
| 700 |
+
# current_class = detections[i]['class']
|
| 701 |
+
# merged_x1, merged_y1, merged_x2, merged_y2 = current_box
|
| 702 |
+
# for j in range(i + 1, len(detections)):
|
| 703 |
+
# if is_merged[j] or detections[j]['class'] != current_class: continue
|
| 704 |
+
# other_box = detections[j]['coords']
|
| 705 |
+
# iou = calculate_iou(current_box, other_box)
|
| 706 |
+
# if iou > iou_threshold:
|
| 707 |
+
# merged_x1 = min(merged_x1, other_box[0])
|
| 708 |
+
# merged_y1 = min(merged_y1, other_box[1])
|
| 709 |
+
# merged_x2 = max(merged_x2, other_box[2])
|
| 710 |
+
# merged_y2 = max(other_box[3], other_box[3])
|
| 711 |
+
# is_merged[j] = True
|
| 712 |
+
# merged_detections.append({
|
| 713 |
+
# 'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
|
| 714 |
+
# # 'y1' is retained for clarity, though 'coords' contains it
|
| 715 |
+
# 'y1': merged_y1,
|
| 716 |
+
# 'class': current_class,
|
| 717 |
+
# 'conf': detections[i]['conf']
|
| 718 |
+
# })
|
| 719 |
+
|
| 720 |
+
# # --- FIX IMPLEMENTATION: READING ORDER SORT ---
|
| 721 |
+
# # Sort primarily by y1 (vertical position), secondarily by x1 (horizontal position).
|
| 722 |
+
# # This correctly handles two-column layouts like Q.10 options (A), (B), (C), (D)
|
| 723 |
+
# merged_detections.sort(key=lambda d: (d['coords'][1], d['coords'][0]))
|
| 724 |
+
|
| 725 |
+
# return merged_detections
|
| 726 |
+
|
| 727 |
+
# # ============================================================================
|
| 728 |
+
# # --- UTILITY FUNCTIONS (Retained) ---
|
| 729 |
+
# # ============================================================================
|
| 730 |
+
|
| 731 |
+
# def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
|
| 732 |
+
# """Converts a PyMuPDF Pixmap to a NumPy array for OpenCV/YOLO."""
|
| 733 |
+
# img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
|
| 734 |
+
# (pix.h, pix.w, pix.n)
|
| 735 |
+
# )
|
| 736 |
+
# if pix.n == 4:
|
| 737 |
+
# img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
|
| 738 |
+
# elif pix.n == 1:
|
| 739 |
+
# img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
| 740 |
+
# return img
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
# def crop_and_convert_to_pil(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> Image.Image:
|
| 744 |
+
# """Crops the numpy array and returns a PIL Image object."""
|
| 745 |
+
# x1, y1, x2, y2 = map(int, bbox)
|
| 746 |
+
# h, w, _ = image.shape
|
| 747 |
+
|
| 748 |
+
# x1 = max(0, x1)
|
| 749 |
+
# y1 = max(0, y1)
|
| 750 |
+
# x2 = min(w, x2)
|
| 751 |
+
# y2 = min(h, y2)
|
| 752 |
+
|
| 753 |
+
# crop_np = image[y1:y2, x1:x2]
|
| 754 |
+
# crop_pil = Image.fromarray(cv2.cvtColor(crop_np, cv2.COLOR_BGR2RGB))
|
| 755 |
+
|
| 756 |
+
# return crop_pil
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
# def pil_to_base64(img: Image.Image) -> str:
|
| 760 |
+
# """Converts a PIL Image object to a Base64 encoded string (PNG format) for OCR input."""
|
| 761 |
+
# buffer = io.BytesIO()
|
| 762 |
+
# img.save(buffer, format="PNG")
|
| 763 |
+
# return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
# def get_latex_from_base64(base64_string: str) -> str:
|
| 767 |
+
# """Performs the OCR conversion using the globally loaded model."""
|
| 768 |
+
# if not OCR_MODEL_LOADED:
|
| 769 |
+
# return "[MODEL_ERROR: Model not loaded]"
|
| 770 |
+
|
| 771 |
+
# try:
|
| 772 |
+
# image_data = base64.b64decode(base64_string)
|
| 773 |
+
# image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
| 774 |
+
|
| 775 |
+
# pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
| 776 |
+
# generated_ids = ort_model.generate(pixel_values)
|
| 777 |
+
# raw_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 778 |
+
|
| 779 |
+
# if not raw_text:
|
| 780 |
+
# return "[OCR_WARNING: No formula found]"
|
| 781 |
+
|
| 782 |
+
# latex = raw_text[0]
|
| 783 |
+
# latex = re.sub(r'[\r\n]+', '', latex)
|
| 784 |
+
|
| 785 |
+
# return latex
|
| 786 |
+
|
| 787 |
+
# except Exception as e:
|
| 788 |
+
# return f"[TR_OCR_ERROR: {e}]"
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
# def run_yolo_detection_and_count(
|
| 792 |
+
# image: np.ndarray, model: YOLO, page_num: int,
|
| 793 |
+
# current_eq_count: int, current_fig_count: int
|
| 794 |
+
# ) -> Tuple[List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]], int, int]:
|
| 795 |
+
# """
|
| 796 |
+
# Performs YOLO detection and returns a list of detected item dictionaries
|
| 797 |
+
# and the updated total counters.
|
| 798 |
+
# """
|
| 799 |
+
|
| 800 |
+
# eq_counter = current_eq_count
|
| 801 |
+
# fig_counter = current_fig_count
|
| 802 |
+
|
| 803 |
+
# detected_items: List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]] = []
|
| 804 |
+
# yolo_detections = []
|
| 805 |
+
|
| 806 |
+
# try:
|
| 807 |
+
# results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
|
| 808 |
+
# if results and results[0].boxes:
|
| 809 |
+
# for box in results[0].boxes.data.tolist():
|
| 810 |
+
# x1, y1, x2, y2, conf, cls_id = box
|
| 811 |
+
# cls_name = model.names[int(cls_id)]
|
| 812 |
+
# if cls_name in TARGET_CLASSES:
|
| 813 |
+
# yolo_detections.append({
|
| 814 |
+
# 'coords': (x1, y1, x2, y2),
|
| 815 |
+
# 'class': cls_name,
|
| 816 |
+
# 'conf': conf
|
| 817 |
+
# })
|
| 818 |
+
# except Exception as e:
|
| 819 |
+
# logging.error(f"ERROR: YOLO inference failed on page {page_num}: {e}")
|
| 820 |
+
# return [], eq_counter, fig_counter
|
| 821 |
+
|
| 822 |
+
# merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
|
| 823 |
+
# final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
|
| 824 |
+
|
| 825 |
+
# # Note: final_detections is now sorted by (y1, x1) in reading order.
|
| 826 |
+
|
| 827 |
+
# for det in final_detections:
|
| 828 |
+
# bbox = det["coords"]
|
| 829 |
+
# crop_pil = crop_and_convert_to_pil(image, bbox)
|
| 830 |
+
|
| 831 |
+
# item = {
|
| 832 |
+
# "type": det["class"],
|
| 833 |
+
# "coords": bbox,
|
| 834 |
+
# "pil_image": crop_pil,
|
| 835 |
+
# }
|
| 836 |
+
|
| 837 |
+
# if det["class"] == "equation":
|
| 838 |
+
# eq_counter += 1
|
| 839 |
+
# item["id"] = f"EQUATION{eq_counter}"
|
| 840 |
+
# item["latex"] = ""
|
| 841 |
+
# elif det["class"] == "figure":
|
| 842 |
+
# fig_counter += 1
|
| 843 |
+
# item["id"] = f"FIGURE{fig_counter}"
|
| 844 |
+
# item["latex"] = "[FIGURE - No LaTeX]"
|
| 845 |
+
|
| 846 |
+
# detected_items.append(item)
|
| 847 |
+
|
| 848 |
+
# return detected_items, eq_counter, fig_counter
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
# # ============================================================================
|
| 852 |
+
# # --- MAIN DOCUMENT PROCESSING FUNCTION (Retained Logic) ---
|
| 853 |
+
# # ============================================================================
|
| 854 |
+
|
| 855 |
+
# def run_single_pdf_preprocessing(
|
| 856 |
+
# pdf_path: str
|
| 857 |
+
# ) -> Tuple[int, int, int, str, float, Dict[str, Union[int, str]], List[Tuple[Image.Image, str]]]:
|
| 858 |
+
# """
|
| 859 |
+
# Runs the pipeline, performs OCR, and returns final results.
|
| 860 |
+
# """
|
| 861 |
+
|
| 862 |
+
# log_stream.truncate(0)
|
| 863 |
+
# log_stream.seek(0)
|
| 864 |
+
|
| 865 |
+
# start_time = time.time()
|
| 866 |
+
|
| 867 |
+
# all_extracted_items: List[Dict[str, Union[Image.Image, str]]] = []
|
| 868 |
+
|
| 869 |
+
# total_figure_count = 0
|
| 870 |
+
# total_equation_count = 0
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
# # 1. Validation and Model Loading (YOLO)
|
| 874 |
+
# t0 = time.time()
|
| 875 |
+
# if not os.path.exists(pdf_path):
|
| 876 |
+
# report = f"β FATAL ERROR: Input PDF not found at {pdf_path}."
|
| 877 |
+
# return 0, 0, 0, report, time.time() - start_time, {}, []
|
| 878 |
+
|
| 879 |
+
# try:
|
| 880 |
+
# model = YOLO(WEIGHTS_PATH)
|
| 881 |
+
# logging.warning(f"INFO: Loaded YOLO model from: {WEIGHTS_PATH}")
|
| 882 |
+
# except Exception as e:
|
| 883 |
+
# report = f"β ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)"
|
| 884 |
+
# return 0, 0, 0, report, time.time() - start_time, {}, []
|
| 885 |
+
# t1 = time.time()
|
| 886 |
+
# logging.warning(f"INFO: Model Loading Time: {t1-t0:.4f}s")
|
| 887 |
+
|
| 888 |
+
# # 2. PDF Loading (fitz)
|
| 889 |
+
# t2 = time.time()
|
| 890 |
+
# try:
|
| 891 |
+
# doc = fitz.open(pdf_path)
|
| 892 |
+
# total_pages = doc.page_count
|
| 893 |
+
# logging.warning(f"INFO: Opened PDF with {doc.page_count} pages")
|
| 894 |
+
# except Exception as e:
|
| 895 |
+
# report = f"β ERROR loading PDF file: {e}"
|
| 896 |
+
# return 0, 0, 0, report, time.time() - start_time, {}, []
|
| 897 |
+
# t3 = time.time()
|
| 898 |
+
# logging.warning(f"INFO: PDF Initialization Time: {t3-t2:.4f}s")
|
| 899 |
+
|
| 900 |
+
# mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
|
| 901 |
+
|
| 902 |
+
# # 3. Page Processing, Detection, and OCR Loop
|
| 903 |
+
# t4 = time.time()
|
| 904 |
+
# for page_num_0_based in range(doc.page_count):
|
| 905 |
+
# page_start_time = time.time()
|
| 906 |
+
# fitz_page = doc.load_page(page_num_0_based)
|
| 907 |
+
# page_num = page_num_0_based + 1
|
| 908 |
+
|
| 909 |
+
# # Render page to image for YOLO
|
| 910 |
+
# try:
|
| 911 |
+
# pix_start = time.time()
|
| 912 |
+
# pix = fitz_page.get_pixmap(matrix=mat)
|
| 913 |
+
# original_img = pixmap_to_numpy(pix)
|
| 914 |
+
# pix_time = time.time() - pix_start
|
| 915 |
+
# except Exception as e:
|
| 916 |
+
# logging.error(f"ERROR: Error converting page {page_num} to image: {e}. Skipping.")
|
| 917 |
+
# continue
|
| 918 |
+
|
| 919 |
+
# # YOLO Detection
|
| 920 |
+
# detect_start = time.time()
|
| 921 |
+
# (
|
| 922 |
+
# page_extracted_items,
|
| 923 |
+
# total_equation_count,
|
| 924 |
+
# total_figure_count
|
| 925 |
+
# ) = run_yolo_detection_and_count(
|
| 926 |
+
# original_img,
|
| 927 |
+
# model,
|
| 928 |
+
# page_num,
|
| 929 |
+
# total_equation_count,
|
| 930 |
+
# total_figure_count
|
| 931 |
+
# )
|
| 932 |
+
# detect_time = time.time() - detect_start
|
| 933 |
+
|
| 934 |
+
# # --- OCR/LaTeX Conversion and Logging ---
|
| 935 |
+
# ocr_total_time = 0
|
| 936 |
+
# page_equations = 0
|
| 937 |
+
|
| 938 |
+
# for item in page_extracted_items:
|
| 939 |
+
# if item["type"] == "equation":
|
| 940 |
+
# page_equations += 1
|
| 941 |
+
# ocr_start = time.time()
|
| 942 |
+
|
| 943 |
+
# b64_string = pil_to_base64(item["pil_image"])
|
| 944 |
+
# item["latex"] = get_latex_from_base64(b64_string)
|
| 945 |
+
|
| 946 |
+
# ocr_time = time.time() - ocr_start
|
| 947 |
+
# ocr_total_time += ocr_time
|
| 948 |
+
|
| 949 |
+
# logging.warning(f"LATEX: Page {page_num}, ID {item['id']} -> Time: {ocr_time:.4f}s, Formula: {item['latex'][:50]}...")
|
| 950 |
+
|
| 951 |
+
# all_extracted_items.extend(page_extracted_items)
|
| 952 |
+
|
| 953 |
+
# page_figures = sum(1 for item in page_extracted_items if item["type"] == "figure")
|
| 954 |
+
|
| 955 |
+
# page_total_time = time.time() - page_start_time
|
| 956 |
+
# logging.warning(f"SUMMARY: Page {page_num}: EQs={page_equations}, Figs={page_figures} | Page Time: {page_total_time:.4f}s (Detect={detect_time:.4f}s, OCR Total={ocr_total_time:.4f}s)")
|
| 957 |
+
|
| 958 |
+
# doc.close()
|
| 959 |
+
# t5 = time.time()
|
| 960 |
+
# detection_loop_time = t5 - t4
|
| 961 |
+
# logging.warning(f"INFO: Total Detection and OCR Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")
|
| 962 |
+
|
| 963 |
+
# # 4. Final Report Generation and Gallery Formatting
|
| 964 |
+
|
| 965 |
+
# # Create the structured JSON output as requested by the user
|
| 966 |
+
# structured_latex_output = {
|
| 967 |
+
# "Total Pages": total_pages,
|
| 968 |
+
# "Total Equations": total_equation_count,
|
| 969 |
+
# }
|
| 970 |
+
# for item in all_extracted_items:
|
| 971 |
+
# if item["type"] == "equation":
|
| 972 |
+
# # Map EQUATION ID to LaTeX code
|
| 973 |
+
# structured_latex_output[item["id"]] = item["latex"]
|
| 974 |
+
|
| 975 |
+
|
| 976 |
+
# # Format the extracted items for the Gradio Gallery
|
| 977 |
+
# gallery_items: List[Tuple[Image.Image, str]] = []
|
| 978 |
+
|
| 979 |
+
# for item in all_extracted_items:
|
| 980 |
+
# image_label = item["id"]
|
| 981 |
+
# if item["type"] == "equation":
|
| 982 |
+
# image_label = f'{item["id"]}: {item["latex"]}'
|
| 983 |
+
|
| 984 |
+
# gallery_items.append((item["pil_image"], image_label))
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
# total_execution_time = t5 - start_time
|
| 988 |
+
|
| 989 |
+
# full_log = log_stream.getvalue()
|
| 990 |
+
|
| 991 |
+
# report = (
|
| 992 |
+
# f"✅ **YOLO Counting & OCR Complete!**\n\n"
|
| 993 |
+
# f"**1) Total Pages Detected in PDF:** **{total_pages}**\n"
|
| 994 |
+
# f"**2) Total Equations Detected:** **{total_equation_count}**\n"
|
| 995 |
+
# f"**3) Total Figures Detected:** **{total_figure_count}**\n"
|
| 996 |
+
# f"---\n"
|
| 997 |
+
# f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n"
|
| 998 |
+
# f"### Full Processing Log\n"
|
| 999 |
+
# f"```text\n"
|
| 1000 |
+
# f"{full_log}"
|
| 1001 |
+
# f"\n```"
|
| 1002 |
+
# )
|
| 1003 |
+
|
| 1004 |
+
# # Return the new structured_latex_output instead of the page counts
|
| 1005 |
+
# return total_pages, total_equation_count, total_figure_count, report, total_execution_time, structured_latex_output, gallery_items
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
# # ============================================================================
|
| 1009 |
+
# # --- GRADIO INTERFACE FUNCTION & DEFINITION (Retained) ---
|
| 1010 |
+
# # ============================================================================
|
| 1011 |
+
|
| 1012 |
+
# def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, Union[int, str]], List[Tuple[Image.Image, str]]]:
|
| 1013 |
+
# """Gradio wrapper function to handle file upload and return results."""
|
| 1014 |
+
# if pdf_file is None:
|
| 1015 |
+
# return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []
|
| 1016 |
+
|
| 1017 |
+
# pdf_path = pdf_file.name
|
| 1018 |
+
|
| 1019 |
+
# try:
|
| 1020 |
+
# (
|
| 1021 |
+
# num_pages,
|
| 1022 |
+
# num_equations,
|
| 1023 |
+
# num_figures,
|
| 1024 |
+
# report,
|
| 1025 |
+
# total_time,
|
| 1026 |
+
# structured_latex_output,
|
| 1027 |
+
# gallery_items
|
| 1028 |
+
# ) = run_single_pdf_preprocessing(pdf_path)
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
# return str(num_pages), str(num_equations), str(num_figures), report, structured_latex_output, gallery_items
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
# except Exception as e:
|
| 1035 |
+
# error_msg = f"An unexpected error occurred: {e}"
|
| 1036 |
+
# logging.error(f"FATAL: {error_msg}", exc_info=True)
|
| 1037 |
+
# full_log = log_stream.getvalue()
|
| 1038 |
+
# error_report = f"β CRITICAL ERROR:\n{error_msg}\n\n### Log up to Failure\n```text\n{full_log}\n```"
|
| 1039 |
+
# return "Error", "Error", "Error", error_report, {}, []
|
| 1040 |
+
|
| 1041 |
+
|
| 1042 |
+
# if __name__ == "__main__":
|
| 1043 |
+
|
| 1044 |
+
# if not os.path.exists(WEIGHTS_PATH):
|
| 1045 |
+
# logging.error(f"β FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.")
|
| 1046 |
+
|
| 1047 |
+
# input_file = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])
|
| 1048 |
+
|
| 1049 |
+
# output_pages = gr.Textbox(label="Total Pages in PDF", interactive=False)
|
| 1050 |
+
# output_equations = gr.Textbox(label="Total Equations Detected", interactive=False)
|
| 1051 |
+
# output_figures = gr.Textbox(label="Total Figures Detected", interactive=False)
|
| 1052 |
+
# output_report = gr.Markdown(label="Processing Summary and Full Log")
|
| 1053 |
+
|
| 1054 |
+
# output_structured_latex = gr.JSON(label="Structured LaTeX Output (EQUATIONx : <latex code>)")
|
| 1055 |
+
|
| 1056 |
+
# output_gallery = gr.Gallery(
|
| 1057 |
+
# label="Detected Items (with Extracted LaTeX)",
|
| 1058 |
+
# columns=3,
|
| 1059 |
+
# height="auto",
|
| 1060 |
+
# object_fit="contain",
|
| 1061 |
+
# allow_preview=False
|
| 1062 |
+
# )
|
| 1063 |
+
|
| 1064 |
+
# interface = gr.Interface(
|
| 1065 |
+
# fn=gradio_process_pdf,
|
| 1066 |
+
# inputs=input_file,
|
| 1067 |
+
# outputs=[
|
| 1068 |
+
# output_pages,
|
| 1069 |
+
# output_equations,
|
| 1070 |
+
# output_figures,
|
| 1071 |
+
# output_report,
|
| 1072 |
+
# output_structured_latex,
|
| 1073 |
+
# output_gallery
|
| 1074 |
+
# ],
|
| 1075 |
+
# title="π YOLO Detection & Math OCR Pipeline (Reading Order Fix)",
|
| 1076 |
+
# description=(
|
| 1077 |
+
# "Upload a PDF. YOLO detects equations/figures, and OCR converts equations to LaTeX. Now includes a fix for two-column reading order."
|
| 1078 |
+
# ),
|
| 1079 |
+
# )
|
| 1080 |
+
|
| 1081 |
+
# print("\nStarting Gradio application...")
|
| 1082 |
+
# interface.launch(inbrowser=True)
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
|
| 1094 |
+
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
|
| 1101 |
+
|
| 1102 |
import base64
|
| 1103 |
from PIL import Image
|
| 1104 |
import re
|
|
|
|
| 1117 |
import json
|
| 1118 |
|
| 1119 |
# ============================================================================
|
| 1120 |
+
# --- Global Setup and Configuration (Retained) ---
|
| 1121 |
# ============================================================================
|
| 1122 |
|
|
|
|
| 1123 |
log_stream = io.StringIO()
|
| 1124 |
logging.basicConfig(level=logging.WARNING, stream=log_stream, format='%(levelname)s:%(message)s')
|
| 1125 |
|
|
|
|
| 1126 |
_original_torch_load = torch.load
|
| 1127 |
def patched_torch_load(*args, **kwargs):
|
| 1128 |
kwargs["weights_only"] = False
|
|
|
|
| 1132 |
WEIGHTS_PATH = 'best.pt'
|
| 1133 |
SCALE_FACTOR = 2.0
|
| 1134 |
|
|
|
|
| 1135 |
from transformers import TrOCRProcessor
|
| 1136 |
from optimum.onnxruntime import ORTModelForVision2Seq
|
| 1137 |
|
|
|
|
| 1146 |
ort_model = None
|
| 1147 |
OCR_MODEL_LOADED = False
|
| 1148 |
|
|
|
|
| 1149 |
CONF_THRESHOLD = 0.2
|
| 1150 |
TARGET_CLASSES = ['figure', 'equation']
|
| 1151 |
IOU_MERGE_THRESHOLD = 0.4
|
| 1152 |
IOA_SUPPRESSION_THRESHOLD = 0.7
|
| 1153 |
|
| 1154 |
# ============================================================================
|
| 1155 |
+
# --- BOX COMBINATION LOGIC (PURE VERTICAL FIX) ---
|
| 1156 |
# ============================================================================
|
| 1157 |
|
| 1158 |
def calculate_iou(box1, box2):
|
|
|
|
| 1195 |
return [detections[i] for i in keep_indices]
|
| 1196 |
|
| 1197 |
|
| 1198 |
+
# --- UPDATED: page_width argument removed ---
|
| 1199 |
def merge_overlapping_boxes(detections, iou_threshold):
|
| 1200 |
if not detections: return []
|
|
|
|
| 1201 |
detections.sort(key=lambda d: d['conf'], reverse=True)
|
| 1202 |
merged_detections = []
|
| 1203 |
is_merged = [False] * len(detections)
|
|
|
|
| 1219 |
is_merged[j] = True
|
| 1220 |
merged_detections.append({
|
| 1221 |
'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
|
|
|
|
| 1222 |
'y1': merged_y1,
|
| 1223 |
'class': current_class,
|
| 1224 |
'conf': detections[i]['conf']
|
| 1225 |
})
|
| 1226 |
|
| 1227 |
+
# --- PURE VERTICAL FIX IMPLEMENTATION ---
|
| 1228 |
+
# Sort ONLY by the top y-coordinate (coords[1]).
|
| 1229 |
+
# This ignores horizontal position and any complex layout.
|
| 1230 |
+
merged_detections.sort(key=lambda d: d['coords'][1])
|
| 1231 |
|
| 1232 |
return merged_detections
|
| 1233 |
|
|
|
|
| 1295 |
return f"[TR_OCR_ERROR: {e}]"
|
| 1296 |
|
| 1297 |
|
| 1298 |
+
# --- UPDATED: page width argument removed from signature and call ---
|
| 1299 |
def run_yolo_detection_and_count(
|
| 1300 |
image: np.ndarray, model: YOLO, page_num: int,
|
| 1301 |
current_eq_count: int, current_fig_count: int
|
|
|
|
| 1310 |
|
| 1311 |
detected_items: List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]] = []
|
| 1312 |
yolo_detections = []
|
| 1313 |
+
|
| 1314 |
try:
|
| 1315 |
results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
|
| 1316 |
if results and results[0].boxes:
|
|
|
|
| 1327 |
logging.error(f"ERROR: YOLO inference failed on page {page_num}: {e}")
|
| 1328 |
return [], eq_counter, fig_counter
|
| 1329 |
|
| 1330 |
+
# Call merge_overlapping_boxes without page_width
|
| 1331 |
merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
|
| 1332 |
final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
|
| 1333 |
|
| 1334 |
+
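# NOTE(review): merge_overlapping_boxes sorts by y1, but filter_nested_boxes appears to re-sort by area (largest first) before returning, so final_detections is presumably NOT in y1 order here - confirm, and re-sort after filtering if top-to-bottom order is required.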
|
| 1335 |
|
| 1336 |
for det in final_detections:
|
| 1337 |
bbox = det["coords"]
|
|
|
|
| 1581 |
output_structured_latex,
|
| 1582 |
output_gallery
|
| 1583 |
],
|
| 1584 |
+
title="π YOLO Detection & Math OCR Pipeline (Pure Vertical Sort)",
|
| 1585 |
description=(
|
| 1586 |
+
"Upload a PDF. YOLO detects equations/figures, and OCR converts equations to LaTeX. The output is now strictly sorted by the top bounding box Y-coordinate."
|
| 1587 |
),
|
| 1588 |
)
|
| 1589 |
|