Update app.py
Browse files
app.py
CHANGED
|
@@ -578,8 +578,6 @@
|
|
| 578 |
# )
|
| 579 |
|
| 580 |
|
| 581 |
-
|
| 582 |
-
|
| 583 |
import base64
|
| 584 |
from PIL import Image
|
| 585 |
import re
|
|
@@ -594,14 +592,17 @@ from typing import Optional, Tuple, List, Dict, Any, Union
|
|
| 594 |
from ultralytics import YOLO
|
| 595 |
import logging
|
| 596 |
import gradio as gr
|
| 597 |
-
import shutil
|
| 598 |
-
import tempfile
|
| 599 |
import io
|
|
|
|
| 600 |
|
| 601 |
# ============================================================================
|
| 602 |
-
# --- Global
|
| 603 |
# ============================================================================
|
| 604 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
# Patch torch.load to prevent weights_only error with older models
|
| 606 |
_original_torch_load = torch.load
|
| 607 |
def patched_torch_load(*args, **kwargs):
|
|
@@ -609,12 +610,6 @@ def patched_torch_load(*args, **kwargs):
|
|
| 609 |
return _original_torch_load(*args, **kwargs)
|
| 610 |
torch.load = patched_torch_load
|
| 611 |
|
| 612 |
-
logging.basicConfig(level=logging.WARNING)
|
| 613 |
-
|
| 614 |
-
# ============================================================================
|
| 615 |
-
# --- CONFIGURATION AND CONSTANTS ---
|
| 616 |
-
# ============================================================================
|
| 617 |
-
|
| 618 |
WEIGHTS_PATH = 'best.pt'
|
| 619 |
SCALE_FACTOR = 2.0
|
| 620 |
|
|
@@ -628,7 +623,7 @@ try:
|
|
| 628 |
ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
|
| 629 |
OCR_MODEL_LOADED = True
|
| 630 |
except Exception as e:
|
| 631 |
-
logging.warning(f"OCR model loading failed
|
| 632 |
processor = None
|
| 633 |
ort_model = None
|
| 634 |
OCR_MODEL_LOADED = False
|
|
@@ -707,10 +702,12 @@ def merge_overlapping_boxes(detections, iou_threshold):
|
|
| 707 |
'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
|
| 708 |
'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf']
|
| 709 |
})
|
|
|
|
|
|
|
| 710 |
return merged_detections
|
| 711 |
|
| 712 |
# ============================================================================
|
| 713 |
-
# --- UTILITY FUNCTIONS
|
| 714 |
# ============================================================================
|
| 715 |
|
| 716 |
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
|
|
@@ -741,42 +738,59 @@ def crop_and_convert_to_pil(image: np.ndarray, bbox: Tuple[float, float, float,
|
|
| 741 |
return crop_pil
|
| 742 |
|
| 743 |
|
| 744 |
-
# --- NEW: Utility to convert PIL Image to Base64 (for OCR input) ---
|
| 745 |
def pil_to_base64(img: Image.Image) -> str:
|
| 746 |
-
"""Converts a PIL Image object to a Base64 encoded string (PNG format)."""
|
| 747 |
buffer = io.BytesIO()
|
| 748 |
img.save(buffer, format="PNG")
|
| 749 |
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 750 |
|
| 751 |
|
| 752 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 753 |
def run_yolo_detection_and_count(
|
| 754 |
image: np.ndarray, model: YOLO, page_num: int,
|
| 755 |
current_eq_count: int, current_fig_count: int
|
| 756 |
-
) -> Tuple[
|
| 757 |
"""
|
| 758 |
-
Performs YOLO detection and returns
|
| 759 |
-
|
| 760 |
"""
|
| 761 |
|
| 762 |
eq_counter = current_eq_count
|
| 763 |
fig_counter = current_fig_count
|
| 764 |
|
| 765 |
-
|
| 766 |
-
page_figures = 0
|
| 767 |
-
# Change: detected_items now holds dictionaries: {'type', 'id', 'pil_image'}
|
| 768 |
-
detected_items: List[Dict[str, Union[Image.Image, str]]] = []
|
| 769 |
yolo_detections = []
|
| 770 |
|
| 771 |
-
# ... (YOLO inference logic is the same)
|
| 772 |
try:
|
| 773 |
results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
|
| 774 |
-
|
| 775 |
if results and results[0].boxes:
|
| 776 |
for box in results[0].boxes.data.tolist():
|
| 777 |
x1, y1, x2, y2, conf, cls_id = box
|
| 778 |
cls_name = model.names[int(cls_id)]
|
| 779 |
-
|
| 780 |
if cls_name in TARGET_CLASSES:
|
| 781 |
yolo_detections.append({
|
| 782 |
'coords': (x1, y1, x2, y2),
|
|
@@ -784,108 +798,61 @@ def run_yolo_detection_and_count(
|
|
| 784 |
'conf': conf
|
| 785 |
})
|
| 786 |
except Exception as e:
|
| 787 |
-
logging.error(f"YOLO inference failed on page {page_num}: {e}")
|
| 788 |
-
return
|
| 789 |
|
| 790 |
merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
|
| 791 |
final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
|
| 792 |
|
| 793 |
for det in final_detections:
|
| 794 |
bbox = det["coords"]
|
| 795 |
-
|
| 796 |
crop_pil = crop_and_convert_to_pil(image, bbox)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 797 |
|
| 798 |
if det["class"] == "equation":
|
| 799 |
eq_counter += 1
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
"type": "equation",
|
| 803 |
-
"id": f"EQUATION{eq_counter}",
|
| 804 |
-
"pil_image": crop_pil,
|
| 805 |
-
"latex": "" # Placeholder for OCR result
|
| 806 |
-
})
|
| 807 |
-
|
| 808 |
elif det["class"] == "figure":
|
| 809 |
fig_counter += 1
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
"pil_image": crop_pil,
|
| 815 |
-
"latex": "[FIGURE - No LaTeX]" # Figures don't get OCR
|
| 816 |
-
})
|
| 817 |
-
|
| 818 |
-
logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
|
| 819 |
-
return page_equations, page_figures, detected_items, eq_counter, fig_counter
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
def get_latex_from_base64(base64_string: str) -> str:
|
| 823 |
-
"""
|
| 824 |
-
Performs the OCR conversion. Expects Base64 string input.
|
| 825 |
-
"""
|
| 826 |
-
if not OCR_MODEL_LOADED:
|
| 827 |
-
return "[MODEL_ERROR: Model not initialized or failed to load]"
|
| 828 |
-
|
| 829 |
-
try:
|
| 830 |
-
# OCR logic (unchanged)
|
| 831 |
-
image_data = base64.b64decode(base64_string)
|
| 832 |
-
image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
| 833 |
-
|
| 834 |
-
pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
| 835 |
-
generated_ids = ort_model.generate(pixel_values)
|
| 836 |
-
raw_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 837 |
-
|
| 838 |
-
if not raw_text:
|
| 839 |
-
return "[OCR_WARNING: No formula found]"
|
| 840 |
-
|
| 841 |
-
latex = raw_text[0]
|
| 842 |
-
latex = re.sub(r'[\r\n]+', '', latex)
|
| 843 |
-
|
| 844 |
-
return latex
|
| 845 |
-
|
| 846 |
-
except Exception as e:
|
| 847 |
-
return f"[TR_OCR_ERROR: {e}]"
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
# --- UNUSED ORIGINAL FUNCTIONS RETAINED FOR COMPLETENESS ---
|
| 851 |
|
| 852 |
-
|
| 853 |
-
# ... (body retained)
|
| 854 |
-
pass
|
| 855 |
-
|
| 856 |
-
def embed_images_as_base64_in_memory(structured_data, detected_items):
|
| 857 |
-
# ... (body retained)
|
| 858 |
-
pass
|
| 859 |
-
|
| 860 |
-
def crop_and_convert_to_base64(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> str:
|
| 861 |
-
# ... (body retained)
|
| 862 |
-
pass
|
| 863 |
|
| 864 |
|
| 865 |
# ============================================================================
|
| 866 |
-
# --- MAIN DOCUMENT PROCESSING FUNCTION (
|
| 867 |
# ============================================================================
|
| 868 |
|
|
|
|
| 869 |
def run_single_pdf_preprocessing(
|
| 870 |
pdf_path: str
|
| 871 |
-
) -> Tuple[int, int, int, str, float, Dict[str, int], List[Tuple[Image.Image, str]]]:
|
| 872 |
"""
|
| 873 |
Runs the pipeline, performs OCR, and returns final results.
|
| 874 |
"""
|
| 875 |
-
|
|
|
|
|
|
|
|
|
|
| 876 |
start_time = time.time()
|
| 877 |
-
log_messages = []
|
| 878 |
|
| 879 |
-
# This will store all final extracted item dicts (image, ID, type, LATEX)
|
| 880 |
all_extracted_items: List[Dict[str, Union[Image.Image, str]]] = []
|
| 881 |
|
| 882 |
-
equation_counts_per_page: Dict[int, int] = {}
|
| 883 |
-
|
| 884 |
total_figure_count = 0
|
| 885 |
total_equation_count = 0
|
| 886 |
|
| 887 |
|
| 888 |
# 1. Validation and Model Loading (YOLO)
|
|
|
|
| 889 |
t0 = time.time()
|
| 890 |
if not os.path.exists(pdf_path):
|
| 891 |
report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
|
|
@@ -893,24 +860,24 @@ def run_single_pdf_preprocessing(
|
|
| 893 |
|
| 894 |
try:
|
| 895 |
model = YOLO(WEIGHTS_PATH)
|
| 896 |
-
logging.warning(f"
|
| 897 |
except Exception as e:
|
| 898 |
report = f"❌ ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)"
|
| 899 |
return 0, 0, 0, report, time.time() - start_time, {}, []
|
| 900 |
t1 = time.time()
|
| 901 |
-
|
| 902 |
|
| 903 |
# 2. PDF Loading (fitz)
|
| 904 |
t2 = time.time()
|
| 905 |
try:
|
| 906 |
doc = fitz.open(pdf_path)
|
| 907 |
total_pages = doc.page_count
|
| 908 |
-
logging.warning(f"
|
| 909 |
except Exception as e:
|
| 910 |
report = f"❌ ERROR loading PDF file: {e}"
|
| 911 |
return 0, 0, 0, report, time.time() - start_time, {}, []
|
| 912 |
t3 = time.time()
|
| 913 |
-
|
| 914 |
|
| 915 |
mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
|
| 916 |
|
|
@@ -922,22 +889,19 @@ def run_single_pdf_preprocessing(
|
|
| 922 |
page_num = page_num_0_based + 1
|
| 923 |
|
| 924 |
# Render page to image for YOLO
|
| 925 |
-
# ... (image rendering logic retained)
|
| 926 |
try:
|
| 927 |
pix_start = time.time()
|
| 928 |
pix = fitz_page.get_pixmap(matrix=mat)
|
| 929 |
original_img = pixmap_to_numpy(pix)
|
| 930 |
pix_time = time.time() - pix_start
|
| 931 |
except Exception as e:
|
| 932 |
-
logging.error(f"Error converting page {page_num} to image: {e}. Skipping.")
|
| 933 |
continue
|
| 934 |
|
| 935 |
# YOLO Detection
|
| 936 |
detect_start = time.time()
|
| 937 |
(
|
| 938 |
-
|
| 939 |
-
page_figures,
|
| 940 |
-
page_extracted_items, # List of dicts: {'type', 'id', 'pil_image', 'latex'}
|
| 941 |
total_equation_count,
|
| 942 |
total_figure_count
|
| 943 |
) = run_yolo_detection_and_count(
|
|
@@ -947,44 +911,53 @@ def run_single_pdf_preprocessing(
|
|
| 947 |
total_equation_count,
|
| 948 |
total_figure_count
|
| 949 |
)
|
|
|
|
| 950 |
|
| 951 |
-
# ---
|
| 952 |
-
|
|
|
|
|
|
|
| 953 |
for item in page_extracted_items:
|
| 954 |
if item["type"] == "equation":
|
| 955 |
-
|
| 956 |
-
|
| 957 |
|
| 958 |
-
|
| 959 |
item["latex"] = get_latex_from_base64(b64_string)
|
| 960 |
|
| 961 |
-
|
| 962 |
-
|
| 963 |
-
|
|
|
|
| 964 |
|
| 965 |
-
# Append all extracted item dictionaries
|
| 966 |
all_extracted_items.extend(page_extracted_items)
|
| 967 |
-
|
| 968 |
-
detect_time = time.time() - detect_start
|
| 969 |
|
| 970 |
-
|
| 971 |
-
equation_counts_per_page[page_num] = page_equations
|
| 972 |
|
| 973 |
page_total_time = time.time() - page_start_time
|
| 974 |
-
|
| 975 |
|
| 976 |
doc.close()
|
| 977 |
t5 = time.time()
|
| 978 |
detection_loop_time = t5 - t4
|
| 979 |
-
|
| 980 |
|
| 981 |
# 4. Final Report Generation and Gallery Formatting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 982 |
|
| 983 |
# Format the extracted items for the Gradio Gallery
|
| 984 |
gallery_items: List[Tuple[Image.Image, str]] = []
|
| 985 |
|
| 986 |
-
# We will include the LATEX code as the image label in the gallery
|
| 987 |
-
# If the item is a Figure, the label is just the ID.
|
| 988 |
for item in all_extracted_items:
|
| 989 |
image_label = item["id"]
|
| 990 |
if item["type"] == "equation":
|
|
@@ -995,10 +968,7 @@ def run_single_pdf_preprocessing(
|
|
| 995 |
|
| 996 |
total_execution_time = t5 - start_time
|
| 997 |
|
| 998 |
-
|
| 999 |
-
equation_counts_per_page_str_keys: Dict[str, int] = {
|
| 1000 |
-
str(k): v for k, v in equation_counts_per_page.items()
|
| 1001 |
-
}
|
| 1002 |
|
| 1003 |
report = (
|
| 1004 |
f"✅ **YOLO Counting & OCR Complete!**\n\n"
|
|
@@ -1007,23 +977,22 @@ def run_single_pdf_preprocessing(
|
|
| 1007 |
f"**3) Total Figures Detected:** **{total_figure_count}**\n"
|
| 1008 |
f"---\n"
|
| 1009 |
f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n"
|
| 1010 |
-
f"###
|
| 1011 |
-
f"
|
| 1012 |
-
|
| 1013 |
f"\n```"
|
| 1014 |
)
|
| 1015 |
|
| 1016 |
-
|
|
|
|
| 1017 |
|
| 1018 |
|
| 1019 |
# ============================================================================
|
| 1020 |
-
# --- GRADIO INTERFACE FUNCTION & DEFINITION (
|
| 1021 |
# ============================================================================
|
| 1022 |
|
| 1023 |
-
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[Tuple[Image.Image, str]]]:
|
| 1024 |
-
"""
|
| 1025 |
-
Gradio wrapper function to handle file upload and return results.
|
| 1026 |
-
"""
|
| 1027 |
if pdf_file is None:
|
| 1028 |
return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []
|
| 1029 |
|
|
@@ -1036,18 +1005,20 @@ def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], Li
|
|
| 1036 |
num_figures,
|
| 1037 |
report,
|
| 1038 |
total_time,
|
| 1039 |
-
|
| 1040 |
gallery_items
|
| 1041 |
) = run_single_pdf_preprocessing(pdf_path)
|
| 1042 |
|
| 1043 |
|
| 1044 |
-
return str(num_pages), str(num_equations), str(num_figures), report,
|
| 1045 |
|
| 1046 |
|
| 1047 |
except Exception as e:
|
| 1048 |
error_msg = f"An unexpected error occurred: {e}"
|
| 1049 |
-
logging.error(error_msg, exc_info=True)
|
| 1050 |
-
|
|
|
|
|
|
|
| 1051 |
|
| 1052 |
|
| 1053 |
if __name__ == "__main__":
|
|
@@ -1057,16 +1028,14 @@ if __name__ == "__main__":
|
|
| 1057 |
|
| 1058 |
input_file = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])
|
| 1059 |
|
| 1060 |
-
# Outputs
|
| 1061 |
output_pages = gr.Textbox(label="Total Pages in PDF", interactive=False)
|
| 1062 |
output_equations = gr.Textbox(label="Total Equations Detected", interactive=False)
|
| 1063 |
output_figures = gr.Textbox(label="Total Figures Detected", interactive=False)
|
| 1064 |
-
output_report = gr.Markdown(label="Processing Summary and
|
| 1065 |
|
| 1066 |
-
#
|
| 1067 |
-
|
| 1068 |
|
| 1069 |
-
# Gradio Gallery now shows the LaTeX code as the label
|
| 1070 |
output_gallery = gr.Gallery(
|
| 1071 |
label="Detected Items (with Extracted LaTeX)",
|
| 1072 |
columns=3,
|
|
@@ -1083,12 +1052,12 @@ if __name__ == "__main__":
|
|
| 1083 |
output_equations,
|
| 1084 |
output_figures,
|
| 1085 |
output_report,
|
| 1086 |
-
|
| 1087 |
output_gallery
|
| 1088 |
],
|
| 1089 |
-
title="📊 YOLO Detection & Math OCR Pipeline",
|
| 1090 |
description=(
|
| 1091 |
-
"Upload a PDF. YOLO detects equations, and
|
| 1092 |
),
|
| 1093 |
)
|
| 1094 |
|
|
|
|
| 578 |
# )
|
| 579 |
|
| 580 |
|
|
|
|
|
|
|
| 581 |
import base64
|
| 582 |
from PIL import Image
|
| 583 |
import re
|
|
|
|
| 592 |
from ultralytics import YOLO
|
| 593 |
import logging
|
| 594 |
import gradio as gr
|
|
|
|
|
|
|
| 595 |
import io
|
| 596 |
+
import json
|
| 597 |
|
| 598 |
# ============================================================================
|
| 599 |
+
# --- Global Setup and Configuration ---
|
| 600 |
# ============================================================================
|
| 601 |
|
| 602 |
+
# Configure logging to write to a string buffer for display in the report
|
| 603 |
+
log_stream = io.StringIO()
|
| 604 |
+
logging.basicConfig(level=logging.WARNING, stream=log_stream, format='%(levelname)s:%(message)s')
|
| 605 |
+
|
| 606 |
# Patch torch.load to prevent weights_only error with older models
|
| 607 |
_original_torch_load = torch.load
|
| 608 |
def patched_torch_load(*args, **kwargs):
|
|
|
|
| 610 |
return _original_torch_load(*args, **kwargs)
|
| 611 |
torch.load = patched_torch_load
|
| 612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
WEIGHTS_PATH = 'best.pt'
|
| 614 |
SCALE_FACTOR = 2.0
|
| 615 |
|
|
|
|
| 623 |
ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
|
| 624 |
OCR_MODEL_LOADED = True
|
| 625 |
except Exception as e:
|
| 626 |
+
logging.warning(f"OCR model loading failed: {e}")
|
| 627 |
processor = None
|
| 628 |
ort_model = None
|
| 629 |
OCR_MODEL_LOADED = False
|
|
|
|
| 702 |
'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
|
| 703 |
'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf']
|
| 704 |
})
|
| 705 |
+
# This step ensures top-to-bottom reading order for sequential numbering (EQUATION1, EQUATION2, etc.)
|
| 706 |
+
merged_detections.sort(key=lambda d: d['y1'])
|
| 707 |
return merged_detections
|
| 708 |
|
| 709 |
# ============================================================================
|
| 710 |
+
# --- UTILITY FUNCTIONS ---
|
| 711 |
# ============================================================================
|
| 712 |
|
| 713 |
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
|
|
|
|
| 738 |
return crop_pil
|
| 739 |
|
| 740 |
|
|
|
|
| 741 |
def pil_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as PNG in memory and return it as Base64 text.

    Args:
        img: The PIL image to serialize.

    Returns:
        The PNG byte stream of ``img``, Base64-encoded and decoded to a
        UTF-8 ``str``.
    """
    # Render into an in-memory sink so nothing touches the filesystem.
    with io.BytesIO() as png_sink:
        img.save(png_sink, format="PNG")
        png_bytes = png_sink.getvalue()

    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
|
| 746 |
|
| 747 |
|
| 748 |
+
def get_latex_from_base64(base64_string: str) -> str:
    """Run the globally loaded OCR model on a Base64-encoded image.

    Args:
        base64_string: Base64 text of an encoded image.

    Returns:
        The first decoded sequence with all CR/LF characters removed, or one
        of these bracketed status markers instead of raising:

        * ``[MODEL_ERROR: Model not loaded]`` when ``OCR_MODEL_LOADED`` is
          false.
        * ``[OCR_WARNING: No formula found]`` when decoding produced no
          sequences, or only an empty/whitespace string.
        * ``[TR_OCR_ERROR: <exception>]`` when any step raised.
    """
    if not OCR_MODEL_LOADED:
        return "[MODEL_ERROR: Model not loaded]"

    try:
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data)).convert('RGB')

        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        generated_ids = ort_model.generate(pixel_values)
        raw_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

        if not raw_text:
            return "[OCR_WARNING: No formula found]"

        # Collapse line breaks so the formula is a single line.
        latex = re.sub(r'[\r\n]+', '', raw_text[0])

        # The list check above cannot catch an empty prediction that arrives
        # as a list holding one empty string (presumably the usual shape for
        # a single input image -- confirm against the processor in use), so
        # also test the decoded text itself.
        if not latex.strip():
            return "[OCR_WARNING: No formula found]"

        return latex

    except Exception as e:
        # Deliberate best-effort boundary: one failed crop must not abort the
        # whole document, so the error is returned as the item's text.
        return f"[TR_OCR_ERROR: {e}]"
|
| 771 |
+
|
| 772 |
+
|
| 773 |
def run_yolo_detection_and_count(
|
| 774 |
image: np.ndarray, model: YOLO, page_num: int,
|
| 775 |
current_eq_count: int, current_fig_count: int
|
| 776 |
+
) -> Tuple[List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]], int, int]:
|
| 777 |
"""
|
| 778 |
+
Performs YOLO detection and returns a list of detected item dictionaries
|
| 779 |
+
and the updated total counters.
|
| 780 |
"""
|
| 781 |
|
| 782 |
eq_counter = current_eq_count
|
| 783 |
fig_counter = current_fig_count
|
| 784 |
|
| 785 |
+
detected_items: List[Dict[str, Union[Image.Image, str, Tuple[float,...]]]] = []
|
|
|
|
|
|
|
|
|
|
| 786 |
yolo_detections = []
|
| 787 |
|
|
|
|
| 788 |
try:
|
| 789 |
results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
|
|
|
|
| 790 |
if results and results[0].boxes:
|
| 791 |
for box in results[0].boxes.data.tolist():
|
| 792 |
x1, y1, x2, y2, conf, cls_id = box
|
| 793 |
cls_name = model.names[int(cls_id)]
|
|
|
|
| 794 |
if cls_name in TARGET_CLASSES:
|
| 795 |
yolo_detections.append({
|
| 796 |
'coords': (x1, y1, x2, y2),
|
|
|
|
| 798 |
'conf': conf
|
| 799 |
})
|
| 800 |
except Exception as e:
|
| 801 |
+
logging.error(f"ERROR: YOLO inference failed on page {page_num}: {e}")
|
| 802 |
+
return [], eq_counter, fig_counter
|
| 803 |
|
| 804 |
merged_detections = merge_overlapping_boxes(yolo_detections, IOU_MERGE_THRESHOLD)
|
| 805 |
final_detections = filter_nested_boxes(merged_detections, IOA_SUPPRESSION_THRESHOLD)
|
| 806 |
|
| 807 |
for det in final_detections:
|
| 808 |
bbox = det["coords"]
|
|
|
|
| 809 |
crop_pil = crop_and_convert_to_pil(image, bbox)
|
| 810 |
+
|
| 811 |
+
item = {
|
| 812 |
+
"type": det["class"],
|
| 813 |
+
"coords": bbox,
|
| 814 |
+
"pil_image": crop_pil,
|
| 815 |
+
}
|
| 816 |
|
| 817 |
if det["class"] == "equation":
|
| 818 |
eq_counter += 1
|
| 819 |
+
item["id"] = f"EQUATION{eq_counter}"
|
| 820 |
+
item["latex"] = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
elif det["class"] == "figure":
|
| 822 |
fig_counter += 1
|
| 823 |
+
item["id"] = f"FIGURE{fig_counter}"
|
| 824 |
+
item["latex"] = "[FIGURE - No LaTeX]"
|
| 825 |
+
|
| 826 |
+
detected_items.append(item)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 827 |
|
| 828 |
+
return detected_items, eq_counter, fig_counter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 829 |
|
| 830 |
|
| 831 |
# ============================================================================
|
| 832 |
+
# --- MAIN DOCUMENT PROCESSING FUNCTION (MODIFIED OUTPUT) ---
|
| 833 |
# ============================================================================
|
| 834 |
|
| 835 |
+
# The return type is updated to reflect the new structured output dictionary
|
| 836 |
def run_single_pdf_preprocessing(
|
| 837 |
pdf_path: str
|
| 838 |
+
) -> Tuple[int, int, int, str, float, Dict[str, Union[int, str]], List[Tuple[Image.Image, str]]]:
|
| 839 |
"""
|
| 840 |
Runs the pipeline, performs OCR, and returns final results.
|
| 841 |
"""
|
| 842 |
+
|
| 843 |
+
log_stream.truncate(0)
|
| 844 |
+
log_stream.seek(0)
|
| 845 |
+
|
| 846 |
start_time = time.time()
|
|
|
|
| 847 |
|
|
|
|
| 848 |
all_extracted_items: List[Dict[str, Union[Image.Image, str]]] = []
|
| 849 |
|
|
|
|
|
|
|
| 850 |
total_figure_count = 0
|
| 851 |
total_equation_count = 0
|
| 852 |
|
| 853 |
|
| 854 |
# 1. Validation and Model Loading (YOLO)
|
| 855 |
+
# ... (Model loading logic retained)
|
| 856 |
t0 = time.time()
|
| 857 |
if not os.path.exists(pdf_path):
|
| 858 |
report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
|
|
|
|
| 860 |
|
| 861 |
try:
|
| 862 |
model = YOLO(WEIGHTS_PATH)
|
| 863 |
+
logging.warning(f"INFO: Loaded YOLO model from: {WEIGHTS_PATH}")
|
| 864 |
except Exception as e:
|
| 865 |
report = f"❌ ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)"
|
| 866 |
return 0, 0, 0, report, time.time() - start_time, {}, []
|
| 867 |
t1 = time.time()
|
| 868 |
+
logging.warning(f"INFO: Model Loading Time: {t1-t0:.4f}s")
|
| 869 |
|
| 870 |
# 2. PDF Loading (fitz)
|
| 871 |
t2 = time.time()
|
| 872 |
try:
|
| 873 |
doc = fitz.open(pdf_path)
|
| 874 |
total_pages = doc.page_count
|
| 875 |
+
logging.warning(f"INFO: Opened PDF with {doc.page_count} pages")
|
| 876 |
except Exception as e:
|
| 877 |
report = f"❌ ERROR loading PDF file: {e}"
|
| 878 |
return 0, 0, 0, report, time.time() - start_time, {}, []
|
| 879 |
t3 = time.time()
|
| 880 |
+
logging.warning(f"INFO: PDF Initialization Time: {t3-t2:.4f}s")
|
| 881 |
|
| 882 |
mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
|
| 883 |
|
|
|
|
| 889 |
page_num = page_num_0_based + 1
|
| 890 |
|
| 891 |
# Render page to image for YOLO
|
|
|
|
| 892 |
try:
|
| 893 |
pix_start = time.time()
|
| 894 |
pix = fitz_page.get_pixmap(matrix=mat)
|
| 895 |
original_img = pixmap_to_numpy(pix)
|
| 896 |
pix_time = time.time() - pix_start
|
| 897 |
except Exception as e:
|
| 898 |
+
logging.error(f"ERROR: Error converting page {page_num} to image: {e}. Skipping.")
|
| 899 |
continue
|
| 900 |
|
| 901 |
# YOLO Detection
|
| 902 |
detect_start = time.time()
|
| 903 |
(
|
| 904 |
+
page_extracted_items,
|
|
|
|
|
|
|
| 905 |
total_equation_count,
|
| 906 |
total_figure_count
|
| 907 |
) = run_yolo_detection_and_count(
|
|
|
|
| 911 |
total_equation_count,
|
| 912 |
total_figure_count
|
| 913 |
)
|
| 914 |
+
detect_time = time.time() - detect_start
|
| 915 |
|
| 916 |
+
# --- OCR/LaTeX Conversion and Logging ---
|
| 917 |
+
ocr_total_time = 0
|
| 918 |
+
page_equations = 0
|
| 919 |
+
|
| 920 |
for item in page_extracted_items:
|
| 921 |
if item["type"] == "equation":
|
| 922 |
+
page_equations += 1
|
| 923 |
+
ocr_start = time.time()
|
| 924 |
|
| 925 |
+
b64_string = pil_to_base64(item["pil_image"])
|
| 926 |
item["latex"] = get_latex_from_base64(b64_string)
|
| 927 |
|
| 928 |
+
ocr_time = time.time() - ocr_start
|
| 929 |
+
ocr_total_time += ocr_time
|
| 930 |
+
|
| 931 |
+
logging.warning(f"LATEX: Page {page_num}, ID {item['id']} -> Time: {ocr_time:.4f}s, Formula: {item['latex'][:50]}...")
|
| 932 |
|
|
|
|
| 933 |
all_extracted_items.extend(page_extracted_items)
|
|
|
|
|
|
|
| 934 |
|
| 935 |
+
page_figures = sum(1 for item in page_extracted_items if item["type"] == "figure")
|
|
|
|
| 936 |
|
| 937 |
page_total_time = time.time() - page_start_time
|
| 938 |
+
logging.warning(f"SUMMARY: Page {page_num}: EQs={page_equations}, Figs={page_figures} | Page Time: {page_total_time:.4f}s (Detect={detect_time:.4f}s, OCR Total={ocr_total_time:.4f}s)")
|
| 939 |
|
| 940 |
doc.close()
|
| 941 |
t5 = time.time()
|
| 942 |
detection_loop_time = t5 - t4
|
| 943 |
+
logging.warning(f"INFO: Total Detection and OCR Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")
|
| 944 |
|
| 945 |
# 4. Final Report Generation and Gallery Formatting
|
| 946 |
+
|
| 947 |
+
# --- NEW: Create the structured JSON output as requested by the user ---
|
| 948 |
+
structured_latex_output = {
|
| 949 |
+
"Total Pages": total_pages,
|
| 950 |
+
"Total Equations": total_equation_count,
|
| 951 |
+
}
|
| 952 |
+
for item in all_extracted_items:
|
| 953 |
+
if item["type"] == "equation":
|
| 954 |
+
# Map EQUATION ID to LaTeX code
|
| 955 |
+
structured_latex_output[item["id"]] = item["latex"]
|
| 956 |
+
|
| 957 |
|
| 958 |
# Format the extracted items for the Gradio Gallery
|
| 959 |
gallery_items: List[Tuple[Image.Image, str]] = []
|
| 960 |
|
|
|
|
|
|
|
| 961 |
for item in all_extracted_items:
|
| 962 |
image_label = item["id"]
|
| 963 |
if item["type"] == "equation":
|
|
|
|
| 968 |
|
| 969 |
total_execution_time = t5 - start_time
|
| 970 |
|
| 971 |
+
full_log = log_stream.getvalue()
|
|
|
|
|
|
|
|
|
|
| 972 |
|
| 973 |
report = (
|
| 974 |
f"✅ **YOLO Counting & OCR Complete!**\n\n"
|
|
|
|
| 977 |
f"**3) Total Figures Detected:** **{total_figure_count}**\n"
|
| 978 |
f"---\n"
|
| 979 |
f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n"
|
| 980 |
+
f"### Full Processing Log\n"
|
| 981 |
+
f"```text\n"
|
| 982 |
+
f"{full_log}"
|
| 983 |
f"\n```"
|
| 984 |
)
|
| 985 |
|
| 986 |
+
# Return the new structured_latex_output instead of the page counts
|
| 987 |
+
return total_pages, total_equation_count, total_figure_count, report, total_execution_time, structured_latex_output, gallery_items
|
| 988 |
|
| 989 |
|
| 990 |
# ============================================================================
|
| 991 |
+
# --- GRADIO INTERFACE FUNCTION & DEFINITION (MODIFIED OUTPUT) ---
|
| 992 |
# ============================================================================
|
| 993 |
|
| 994 |
+
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, Union[int, str]], List[Tuple[Image.Image, str]]]:
|
| 995 |
+
"""Gradio wrapper function to handle file upload and return results."""
|
|
|
|
|
|
|
| 996 |
if pdf_file is None:
|
| 997 |
return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []
|
| 998 |
|
|
|
|
| 1005 |
num_figures,
|
| 1006 |
report,
|
| 1007 |
total_time,
|
| 1008 |
+
structured_latex_output, # Variable name changed to match the new output
|
| 1009 |
gallery_items
|
| 1010 |
) = run_single_pdf_preprocessing(pdf_path)
|
| 1011 |
|
| 1012 |
|
| 1013 |
+
return str(num_pages), str(num_equations), str(num_figures), report, structured_latex_output, gallery_items
|
| 1014 |
|
| 1015 |
|
| 1016 |
except Exception as e:
|
| 1017 |
error_msg = f"An unexpected error occurred: {e}"
|
| 1018 |
+
logging.error(f"FATAL: {error_msg}", exc_info=True)
|
| 1019 |
+
full_log = log_stream.getvalue()
|
| 1020 |
+
error_report = f"❌ CRITICAL ERROR:\n{error_msg}\n\n### Log up to Failure\n```text\n{full_log}\n```"
|
| 1021 |
+
return "Error", "Error", "Error", error_report, {}, []
|
| 1022 |
|
| 1023 |
|
| 1024 |
if __name__ == "__main__":
|
|
|
|
| 1028 |
|
| 1029 |
input_file = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])
|
| 1030 |
|
|
|
|
| 1031 |
output_pages = gr.Textbox(label="Total Pages in PDF", interactive=False)
|
| 1032 |
output_equations = gr.Textbox(label="Total Equations Detected", interactive=False)
|
| 1033 |
output_figures = gr.Textbox(label="Total Figures Detected", interactive=False)
|
| 1034 |
+
output_report = gr.Markdown(label="Processing Summary and Full Log")
|
| 1035 |
|
| 1036 |
+
# This JSON component now displays the structured output requested by the user
|
| 1037 |
+
output_structured_latex = gr.JSON(label="Structured LaTeX Output (EQUATIONx : <latex code>)")
|
| 1038 |
|
|
|
|
| 1039 |
output_gallery = gr.Gallery(
|
| 1040 |
label="Detected Items (with Extracted LaTeX)",
|
| 1041 |
columns=3,
|
|
|
|
| 1052 |
output_equations,
|
| 1053 |
output_figures,
|
| 1054 |
output_report,
|
| 1055 |
+
output_structured_latex, # Updated component
|
| 1056 |
output_gallery
|
| 1057 |
],
|
| 1058 |
+
title="📊 YOLO Detection & Math OCR Pipeline (Structured Output)",
|
| 1059 |
description=(
|
| 1060 |
+
"Upload a PDF. YOLO detects equations/figures, and OCR converts equations to LaTeX. See the Structured LaTeX Output panel for the requested format."
|
| 1061 |
),
|
| 1062 |
)
|
| 1063 |
|