Update app.py
Browse files
app.py
CHANGED
|
@@ -580,7 +580,6 @@
|
|
| 580 |
|
| 581 |
|
| 582 |
|
| 583 |
-
|
| 584 |
import base64
|
| 585 |
from PIL import Image
|
| 586 |
import re
|
|
@@ -627,11 +626,12 @@ MODEL_NAME = 'breezedeus/pix2text-mfr-1.5'
|
|
| 627 |
try:
|
| 628 |
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
|
| 629 |
ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
|
|
|
|
| 630 |
except Exception as e:
|
| 631 |
-
# This warning is included to alert the user if the optional, unused dependencies fail
|
| 632 |
logging.warning(f"OCR model loading failed (expected if dependencies are missing): {e}")
|
| 633 |
processor = None
|
| 634 |
ort_model = None
|
|
|
|
| 635 |
|
| 636 |
# Detection parameters
|
| 637 |
CONF_THRESHOLD = 0.2
|
|
@@ -639,8 +639,6 @@ TARGET_CLASSES = ['figure', 'equation']
|
|
| 639 |
IOU_MERGE_THRESHOLD = 0.4
|
| 640 |
IOA_SUPPRESSION_THRESHOLD = 0.7
|
| 641 |
|
| 642 |
-
# Note: The original GLOBAL_COUNT variables have been removed to fix concurrency.
|
| 643 |
-
|
| 644 |
# ============================================================================
|
| 645 |
# --- BOX COMBINATION LOGIC (Retained) ---
|
| 646 |
# ============================================================================
|
|
@@ -712,7 +710,7 @@ def merge_overlapping_boxes(detections, iou_threshold):
|
|
| 712 |
return merged_detections
|
| 713 |
|
| 714 |
# ============================================================================
|
| 715 |
-
# --- UTILITY FUNCTIONS (
|
| 716 |
# ============================================================================
|
| 717 |
|
| 718 |
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
|
|
@@ -727,7 +725,6 @@ def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
|
|
| 727 |
return img
|
| 728 |
|
| 729 |
|
| 730 |
-
# --- REPLACED CROP_AND_CONVERT_TO_BASE64 ---
|
| 731 |
def crop_and_convert_to_pil(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> Image.Image:
|
| 732 |
"""Crops the numpy array and returns a PIL Image object."""
|
| 733 |
x1, y1, x2, y2 = map(int, bbox)
|
|
@@ -739,33 +736,39 @@ def crop_and_convert_to_pil(image: np.ndarray, bbox: Tuple[float, float, float,
|
|
| 739 |
y2 = min(h, y2)
|
| 740 |
|
| 741 |
crop_np = image[y1:y2, x1:x2]
|
| 742 |
-
# Convert OpenCV/BGR (if applicable) or RGB numpy array to PIL Image
|
| 743 |
-
# Using BGR2RGB conversion just in case OpenCV read the image in BGR format
|
| 744 |
crop_pil = Image.fromarray(cv2.cvtColor(crop_np, cv2.COLOR_BGR2RGB))
|
| 745 |
|
| 746 |
return crop_pil
|
| 747 |
|
| 748 |
|
| 749 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
def run_yolo_detection_and_count(
|
| 751 |
image: np.ndarray, model: YOLO, page_num: int,
|
| 752 |
current_eq_count: int, current_fig_count: int
|
| 753 |
-
) -> Tuple[int, int, List[
|
| 754 |
"""
|
| 755 |
-
Performs YOLO detection and returns page counts, detected items as
|
| 756 |
-
and the updated total counters.
|
| 757 |
"""
|
| 758 |
|
| 759 |
-
# Use the passed counters as starting points for this page
|
| 760 |
eq_counter = current_eq_count
|
| 761 |
fig_counter = current_fig_count
|
| 762 |
|
| 763 |
page_equations = 0
|
| 764 |
page_figures = 0
|
| 765 |
-
# Change: detected_items now holds
|
| 766 |
-
detected_items: List[
|
| 767 |
yolo_detections = []
|
| 768 |
|
|
|
|
| 769 |
try:
|
| 770 |
results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
|
| 771 |
|
|
@@ -790,33 +793,41 @@ def run_yolo_detection_and_count(
|
|
| 790 |
for det in final_detections:
|
| 791 |
bbox = det["coords"]
|
| 792 |
|
| 793 |
-
# --- NEW: Get PIL image directly ---
|
| 794 |
crop_pil = crop_and_convert_to_pil(image, bbox)
|
| 795 |
|
| 796 |
if det["class"] == "equation":
|
| 797 |
eq_counter += 1
|
| 798 |
page_equations += 1
|
| 799 |
-
|
| 800 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 801 |
|
| 802 |
elif det["class"] == "figure":
|
| 803 |
fig_counter += 1
|
| 804 |
page_figures += 1
|
| 805 |
-
|
| 806 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
|
| 808 |
logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
|
| 809 |
-
# Return page counts, detected items (as PIL tuples), and the UPDATED total counters
|
| 810 |
return page_equations, page_figures, detected_items, eq_counter, fig_counter
|
| 811 |
|
| 812 |
|
| 813 |
def get_latex_from_base64(base64_string: str) -> str:
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
|
|
|
| 818 |
|
| 819 |
try:
|
|
|
|
| 820 |
image_data = base64.b64decode(base64_string)
|
| 821 |
image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
| 822 |
|
|
@@ -836,120 +847,45 @@ def get_latex_from_base64(base64_string: str) -> str:
|
|
| 836 |
return f"[TR_OCR_ERROR: {e}]"
|
| 837 |
|
| 838 |
|
| 839 |
-
|
| 840 |
-
"""
|
| 841 |
-
Extract images from a page and return:
|
| 842 |
-
{ "EQUATION1": base64_string, "FIGURE1": base64_string }
|
| 843 |
-
(NOTE: This is unused dead code from the original script, retained as requested)
|
| 844 |
-
"""
|
| 845 |
-
image_map = {}
|
| 846 |
-
image_list = page.get_images(full=True)
|
| 847 |
-
|
| 848 |
-
for idx, img in enumerate(image_list, start=1):
|
| 849 |
-
xref = img[0]
|
| 850 |
-
base = page.parent.extract_image(xref)
|
| 851 |
-
image_bytes = base["image"]
|
| 852 |
-
|
| 853 |
-
base64_img = base64.b64encode(image_bytes).decode("utf-8")
|
| 854 |
-
|
| 855 |
-
# Convention: first image = FIGURE1, second image = EQUATION1 etc
|
| 856 |
-
# You can tune this if needed
|
| 857 |
-
image_map[f"FIGURE{idx}"] = base64_img
|
| 858 |
-
|
| 859 |
-
return image_map
|
| 860 |
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
def embed_images_as_base64_in_memory(structured_data, detected_items):
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
"""
|
| 866 |
-
tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
|
| 867 |
-
|
| 868 |
-
item_lookup = {d["id"]: d for d in detected_items}
|
| 869 |
-
final_data = []
|
| 870 |
-
|
| 871 |
-
for item in structured_data:
|
| 872 |
-
text_fields = [
|
| 873 |
-
item.get('question', ''),
|
| 874 |
-
item.get('passage', ''),
|
| 875 |
-
item.get('new_passage', '')
|
| 876 |
-
]
|
| 877 |
-
|
| 878 |
-
if 'options' in item:
|
| 879 |
-
text_fields.extend(item['options'].values())
|
| 880 |
-
|
| 881 |
-
used_tags = set()
|
| 882 |
-
|
| 883 |
-
for text in text_fields:
|
| 884 |
-
for m in tag_regex.finditer(text or ""):
|
| 885 |
-
used_tags.add(m.group(0).upper())
|
| 886 |
-
|
| 887 |
-
for tag in used_tags:
|
| 888 |
-
base_key = tag.lower().replace(" ", "")
|
| 889 |
-
|
| 890 |
-
if tag not in item_lookup:
|
| 891 |
-
item[base_key] = "[MISSING_IMAGE]"
|
| 892 |
-
continue
|
| 893 |
-
|
| 894 |
-
entry = item_lookup[tag]
|
| 895 |
-
# This logic assumes detected_items still contained the raw dicts,
|
| 896 |
-
# which is no longer true in the main flow.
|
| 897 |
-
# This section is functionally broken but left untouched as per request.
|
| 898 |
-
|
| 899 |
-
# if entry["type"] == "equation":
|
| 900 |
-
# item[base_key] = get_latex_from_base64(entry["base64"])
|
| 901 |
-
# else:
|
| 902 |
-
# item[base_key] = entry["base64"]
|
| 903 |
-
|
| 904 |
-
final_data.append(item)
|
| 905 |
-
|
| 906 |
-
return final_data
|
| 907 |
|
| 908 |
def crop_and_convert_to_base64(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> str:
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
"""
|
| 912 |
-
x1, y1, x2, y2 = map(int, bbox)
|
| 913 |
-
h, w, _ = image.shape
|
| 914 |
-
|
| 915 |
-
x1 = max(0, x1)
|
| 916 |
-
y1 = max(0, y1)
|
| 917 |
-
x2 = min(w, x2)
|
| 918 |
-
y2 = min(h, y2)
|
| 919 |
-
|
| 920 |
-
crop = image[y1:y2, x1:x2]
|
| 921 |
-
_, buffer = cv2.imencode(".png", crop)
|
| 922 |
-
|
| 923 |
-
return base64.b64encode(buffer).decode("utf-8")
|
| 924 |
|
| 925 |
|
| 926 |
# ============================================================================
|
| 927 |
-
# --- MAIN DOCUMENT PROCESSING FUNCTION (
|
| 928 |
# ============================================================================
|
| 929 |
|
| 930 |
def run_single_pdf_preprocessing(
|
| 931 |
pdf_path: str
|
| 932 |
) -> Tuple[int, int, int, str, float, Dict[str, int], List[Tuple[Image.Image, str]]]:
|
| 933 |
"""
|
| 934 |
-
Runs the pipeline,
|
| 935 |
-
and a list of (PIL.Image, label) for the Gradio gallery.
|
| 936 |
"""
|
| 937 |
|
| 938 |
start_time = time.time()
|
| 939 |
log_messages = []
|
| 940 |
|
| 941 |
-
# This
|
| 942 |
-
|
| 943 |
|
| 944 |
-
# Dictionary to store {page_number (int): equation_count (int)}
|
| 945 |
equation_counts_per_page: Dict[int, int] = {}
|
| 946 |
|
| 947 |
-
# Local counters for thread safety (Concurrency Fix)
|
| 948 |
total_figure_count = 0
|
| 949 |
total_equation_count = 0
|
| 950 |
|
| 951 |
|
| 952 |
-
# 1. Validation and Model Loading
|
| 953 |
t0 = time.time()
|
| 954 |
if not os.path.exists(pdf_path):
|
| 955 |
report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
|
|
@@ -964,7 +900,7 @@ def run_single_pdf_preprocessing(
|
|
| 964 |
t1 = time.time()
|
| 965 |
log_messages.append(f"Model Loading Time: {t1-t0:.4f}s")
|
| 966 |
|
| 967 |
-
# 2. PDF Loading
|
| 968 |
t2 = time.time()
|
| 969 |
try:
|
| 970 |
doc = fitz.open(pdf_path)
|
|
@@ -978,7 +914,7 @@ def run_single_pdf_preprocessing(
|
|
| 978 |
|
| 979 |
mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
|
| 980 |
|
| 981 |
-
# 3. Page Processing and
|
| 982 |
t4 = time.time()
|
| 983 |
for page_num_0_based in range(doc.page_count):
|
| 984 |
page_start_time = time.time()
|
|
@@ -986,6 +922,7 @@ def run_single_pdf_preprocessing(
|
|
| 986 |
page_num = page_num_0_based + 1
|
| 987 |
|
| 988 |
# Render page to image for YOLO
|
|
|
|
| 989 |
try:
|
| 990 |
pix_start = time.time()
|
| 991 |
pix = fitz_page.get_pixmap(matrix=mat)
|
|
@@ -994,13 +931,13 @@ def run_single_pdf_preprocessing(
|
|
| 994 |
except Exception as e:
|
| 995 |
logging.error(f"Error converting page {page_num} to image: {e}. Skipping.")
|
| 996 |
continue
|
| 997 |
-
|
| 998 |
-
#
|
| 999 |
detect_start = time.time()
|
| 1000 |
(
|
| 1001 |
page_equations,
|
| 1002 |
page_figures,
|
| 1003 |
-
|
| 1004 |
total_equation_count,
|
| 1005 |
total_figure_count
|
| 1006 |
) = run_yolo_detection_and_count(
|
|
@@ -1011,32 +948,60 @@ def run_single_pdf_preprocessing(
|
|
| 1011 |
total_figure_count
|
| 1012 |
)
|
| 1013 |
|
| 1014 |
-
#
|
| 1015 |
-
|
| 1016 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1017 |
detect_time = time.time() - detect_start
|
| 1018 |
|
| 1019 |
# Store the count in the dictionary (INT keys)
|
| 1020 |
equation_counts_per_page[page_num] = page_equations
|
| 1021 |
|
| 1022 |
page_total_time = time.time() - page_start_time
|
| 1023 |
-
log_messages.append(f"Page {page_num} Time: Total={page_total_time:.4f}s (Render={pix_time:.4f}s, Detect={detect_time:.4f}s)")
|
| 1024 |
|
| 1025 |
doc.close()
|
| 1026 |
t5 = time.time()
|
| 1027 |
detection_loop_time = t5 - t4
|
| 1028 |
log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1029 |
|
|
|
|
|
|
|
| 1030 |
# Convert integer keys to string keys for JSON serialization
|
| 1031 |
equation_counts_per_page_str_keys: Dict[str, int] = {
|
| 1032 |
str(k): v for k, v in equation_counts_per_page.items()
|
| 1033 |
}
|
| 1034 |
-
|
| 1035 |
-
# 4. Final Report Generation
|
| 1036 |
-
total_execution_time = t5 - start_time
|
| 1037 |
|
| 1038 |
report = (
|
| 1039 |
-
f"✅ **YOLO Counting Complete!**\n\n"
|
| 1040 |
f"**1) Total Pages Detected in PDF:** **{total_pages}**\n"
|
| 1041 |
f"**2) Total Equations Detected:** **{total_equation_count}**\n"
|
| 1042 |
f"**3) Total Figures Detected:** **{total_figure_count}**\n"
|
|
@@ -1048,15 +1013,13 @@ def run_single_pdf_preprocessing(
|
|
| 1048 |
f"\n```"
|
| 1049 |
)
|
| 1050 |
|
| 1051 |
-
|
| 1052 |
-
return total_pages, total_equation_count, total_figure_count, report, total_execution_time, equation_counts_per_page_str_keys, all_gradio_gallery_items
|
| 1053 |
|
| 1054 |
|
| 1055 |
# ============================================================================
|
| 1056 |
-
# --- GRADIO INTERFACE FUNCTION
|
| 1057 |
# ============================================================================
|
| 1058 |
|
| 1059 |
-
# The return type now uses PIL.Image for the gallery list
|
| 1060 |
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[Tuple[Image.Image, str]]]:
|
| 1061 |
"""
|
| 1062 |
Gradio wrapper function to handle file upload and return results.
|
|
@@ -1074,7 +1037,7 @@ def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], Li
|
|
| 1074 |
report,
|
| 1075 |
total_time,
|
| 1076 |
equation_counts_per_page,
|
| 1077 |
-
gallery_items
|
| 1078 |
) = run_single_pdf_preprocessing(pdf_path)
|
| 1079 |
|
| 1080 |
|
|
@@ -1087,10 +1050,6 @@ def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], Li
|
|
| 1087 |
return "Error", "Error", "Error", error_msg, {}, []
|
| 1088 |
|
| 1089 |
|
| 1090 |
-
# ============================================================================
|
| 1091 |
-
# --- GRADIO INTERFACE DEFINITION (Unchanged) ---
|
| 1092 |
-
# ============================================================================
|
| 1093 |
-
|
| 1094 |
if __name__ == "__main__":
|
| 1095 |
|
| 1096 |
if not os.path.exists(WEIGHTS_PATH):
|
|
@@ -1107,10 +1066,10 @@ if __name__ == "__main__":
|
|
| 1107 |
# NEW OUTPUT: JSON component for structured data
|
| 1108 |
output_page_counts = gr.JSON(label="Equation Count Per Page (Dictionary)")
|
| 1109 |
|
| 1110 |
-
# Gradio Gallery
|
| 1111 |
output_gallery = gr.Gallery(
|
| 1112 |
-
label="Detected Items (
|
| 1113 |
-
columns=
|
| 1114 |
height="auto",
|
| 1115 |
object_fit="contain",
|
| 1116 |
allow_preview=False
|
|
@@ -1127,9 +1086,9 @@ if __name__ == "__main__":
|
|
| 1127 |
output_page_counts,
|
| 1128 |
output_gallery
|
| 1129 |
],
|
| 1130 |
-
title="📊 YOLO
|
| 1131 |
description=(
|
| 1132 |
-
"Upload a PDF
|
| 1133 |
),
|
| 1134 |
)
|
| 1135 |
|
|
|
|
| 580 |
|
| 581 |
|
| 582 |
|
|
|
|
| 583 |
import base64
|
| 584 |
from PIL import Image
|
| 585 |
import re
|
|
|
|
| 626 |
try:
|
| 627 |
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
|
| 628 |
ort_model = ORTModelForVision2Seq.from_pretrained(MODEL_NAME, use_cache=False)
|
| 629 |
+
OCR_MODEL_LOADED = True
|
| 630 |
except Exception as e:
|
|
|
|
| 631 |
logging.warning(f"OCR model loading failed (expected if dependencies are missing): {e}")
|
| 632 |
processor = None
|
| 633 |
ort_model = None
|
| 634 |
+
OCR_MODEL_LOADED = False
|
| 635 |
|
| 636 |
# Detection parameters
|
| 637 |
CONF_THRESHOLD = 0.2
|
|
|
|
| 639 |
IOU_MERGE_THRESHOLD = 0.4
|
| 640 |
IOA_SUPPRESSION_THRESHOLD = 0.7
|
| 641 |
|
|
|
|
|
|
|
| 642 |
# ============================================================================
|
| 643 |
# --- BOX COMBINATION LOGIC (Retained) ---
|
| 644 |
# ============================================================================
|
|
|
|
| 710 |
return merged_detections
|
| 711 |
|
| 712 |
# ============================================================================
|
| 713 |
+
# --- UTILITY FUNCTIONS (UPDATED) ---
|
| 714 |
# ============================================================================
|
| 715 |
|
| 716 |
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
|
|
|
|
| 725 |
return img
|
| 726 |
|
| 727 |
|
|
|
|
| 728 |
def crop_and_convert_to_pil(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> Image.Image:
|
| 729 |
"""Crops the numpy array and returns a PIL Image object."""
|
| 730 |
x1, y1, x2, y2 = map(int, bbox)
|
|
|
|
| 736 |
y2 = min(h, y2)
|
| 737 |
|
| 738 |
crop_np = image[y1:y2, x1:x2]
|
|
|
|
|
|
|
| 739 |
crop_pil = Image.fromarray(cv2.cvtColor(crop_np, cv2.COLOR_BGR2RGB))
|
| 740 |
|
| 741 |
return crop_pil
|
| 742 |
|
| 743 |
|
| 744 |
+
# --- NEW: Utility to convert PIL Image to Base64 (for OCR input) ---
|
| 745 |
+
def pil_to_base64(img: Image.Image) -> str:
|
| 746 |
+
"""Converts a PIL Image object to a Base64 encoded string (PNG format)."""
|
| 747 |
+
buffer = io.BytesIO()
|
| 748 |
+
img.save(buffer, format="PNG")
|
| 749 |
+
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
# --- UPDATED: run_yolo_detection_and_count to return a list of dictionaries with PIL images ---
|
| 753 |
def run_yolo_detection_and_count(
|
| 754 |
image: np.ndarray, model: YOLO, page_num: int,
|
| 755 |
current_eq_count: int, current_fig_count: int
|
| 756 |
+
) -> Tuple[int, int, List[Dict[str, Union[Image.Image, str]]], int, int]:
|
| 757 |
"""
|
| 758 |
+
Performs YOLO detection and returns page counts, detected items (as dicts
|
| 759 |
+
containing the PIL Image), and the updated total counters.
|
| 760 |
"""
|
| 761 |
|
|
|
|
| 762 |
eq_counter = current_eq_count
|
| 763 |
fig_counter = current_fig_count
|
| 764 |
|
| 765 |
page_equations = 0
|
| 766 |
page_figures = 0
|
| 767 |
+
# Change: detected_items now holds dictionaries: {'type', 'id', 'pil_image'}
|
| 768 |
+
detected_items: List[Dict[str, Union[Image.Image, str]]] = []
|
| 769 |
yolo_detections = []
|
| 770 |
|
| 771 |
+
# ... (YOLO inference logic is the same)
|
| 772 |
try:
|
| 773 |
results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
|
| 774 |
|
|
|
|
| 793 |
for det in final_detections:
|
| 794 |
bbox = det["coords"]
|
| 795 |
|
|
|
|
| 796 |
crop_pil = crop_and_convert_to_pil(image, bbox)
|
| 797 |
|
| 798 |
if det["class"] == "equation":
|
| 799 |
eq_counter += 1
|
| 800 |
page_equations += 1
|
| 801 |
+
detected_items.append({
|
| 802 |
+
"type": "equation",
|
| 803 |
+
"id": f"EQUATION{eq_counter}",
|
| 804 |
+
"pil_image": crop_pil,
|
| 805 |
+
"latex": "" # Placeholder for OCR result
|
| 806 |
+
})
|
| 807 |
|
| 808 |
elif det["class"] == "figure":
|
| 809 |
fig_counter += 1
|
| 810 |
page_figures += 1
|
| 811 |
+
detected_items.append({
|
| 812 |
+
"type": "figure",
|
| 813 |
+
"id": f"FIGURE{fig_counter}",
|
| 814 |
+
"pil_image": crop_pil,
|
| 815 |
+
"latex": "[FIGURE - No LaTeX]" # Figures don't get OCR
|
| 816 |
+
})
|
| 817 |
|
| 818 |
logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
|
|
|
|
| 819 |
return page_equations, page_figures, detected_items, eq_counter, fig_counter
|
| 820 |
|
| 821 |
|
| 822 |
def get_latex_from_base64(base64_string: str) -> str:
|
| 823 |
+
"""
|
| 824 |
+
Performs the OCR conversion. Expects Base64 string input.
|
| 825 |
+
"""
|
| 826 |
+
if not OCR_MODEL_LOADED:
|
| 827 |
+
return "[MODEL_ERROR: Model not initialized or failed to load]"
|
| 828 |
|
| 829 |
try:
|
| 830 |
+
# OCR logic (unchanged)
|
| 831 |
image_data = base64.b64decode(base64_string)
|
| 832 |
image = Image.open(io.BytesIO(image_data)).convert('RGB')
|
| 833 |
|
|
|
|
| 847 |
return f"[TR_OCR_ERROR: {e}]"
|
| 848 |
|
| 849 |
|
| 850 |
+
# --- UNUSED ORIGINAL FUNCTIONS RETAINED FOR COMPLETENESS ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 851 |
|
| 852 |
+
def extract_images_from_page_in_memory(page) -> Dict[str, str]:
|
| 853 |
+
# ... (body retained)
|
| 854 |
+
pass
|
| 855 |
|
| 856 |
def embed_images_as_base64_in_memory(structured_data, detected_items):
|
| 857 |
+
# ... (body retained)
|
| 858 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 859 |
|
| 860 |
def crop_and_convert_to_base64(image: np.ndarray, bbox: Tuple[float, float, float, float]) -> str:
|
| 861 |
+
# ... (body retained)
|
| 862 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
|
| 864 |
|
| 865 |
# ============================================================================
|
| 866 |
+
# --- MAIN DOCUMENT PROCESSING FUNCTION (UPDATED for OCR) ---
|
| 867 |
# ============================================================================
|
| 868 |
|
| 869 |
def run_single_pdf_preprocessing(
|
| 870 |
pdf_path: str
|
| 871 |
) -> Tuple[int, int, int, str, float, Dict[str, int], List[Tuple[Image.Image, str]]]:
|
| 872 |
"""
|
| 873 |
+
Runs the pipeline, performs OCR, and returns final results.
|
|
|
|
| 874 |
"""
|
| 875 |
|
| 876 |
start_time = time.time()
|
| 877 |
log_messages = []
|
| 878 |
|
| 879 |
+
# This will store all final extracted item dicts (image, ID, type, LATEX)
|
| 880 |
+
all_extracted_items: List[Dict[str, Union[Image.Image, str]]] = []
|
| 881 |
|
|
|
|
| 882 |
equation_counts_per_page: Dict[int, int] = {}
|
| 883 |
|
|
|
|
| 884 |
total_figure_count = 0
|
| 885 |
total_equation_count = 0
|
| 886 |
|
| 887 |
|
| 888 |
+
# 1. Validation and Model Loading (YOLO)
|
| 889 |
t0 = time.time()
|
| 890 |
if not os.path.exists(pdf_path):
|
| 891 |
report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
|
|
|
|
| 900 |
t1 = time.time()
|
| 901 |
log_messages.append(f"Model Loading Time: {t1-t0:.4f}s")
|
| 902 |
|
| 903 |
+
# 2. PDF Loading (fitz)
|
| 904 |
t2 = time.time()
|
| 905 |
try:
|
| 906 |
doc = fitz.open(pdf_path)
|
|
|
|
| 914 |
|
| 915 |
mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
|
| 916 |
|
| 917 |
+
# 3. Page Processing, Detection, and OCR Loop
|
| 918 |
t4 = time.time()
|
| 919 |
for page_num_0_based in range(doc.page_count):
|
| 920 |
page_start_time = time.time()
|
|
|
|
| 922 |
page_num = page_num_0_based + 1
|
| 923 |
|
| 924 |
# Render page to image for YOLO
|
| 925 |
+
# ... (image rendering logic retained)
|
| 926 |
try:
|
| 927 |
pix_start = time.time()
|
| 928 |
pix = fitz_page.get_pixmap(matrix=mat)
|
|
|
|
| 931 |
except Exception as e:
|
| 932 |
logging.error(f"Error converting page {page_num} to image: {e}. Skipping.")
|
| 933 |
continue
|
| 934 |
+
|
| 935 |
+
# YOLO Detection
|
| 936 |
detect_start = time.time()
|
| 937 |
(
|
| 938 |
page_equations,
|
| 939 |
page_figures,
|
| 940 |
+
page_extracted_items, # List of dicts: {'type', 'id', 'pil_image', 'latex'}
|
| 941 |
total_equation_count,
|
| 942 |
total_figure_count
|
| 943 |
) = run_yolo_detection_and_count(
|
|
|
|
| 948 |
total_figure_count
|
| 949 |
)
|
| 950 |
|
| 951 |
+
# --- NEW: OCR/LaTeX Conversion for Equations ---
|
| 952 |
+
ocr_start = time.time()
|
| 953 |
+
for item in page_extracted_items:
|
| 954 |
+
if item["type"] == "equation":
|
| 955 |
+
# 1. Convert PIL image to Base64
|
| 956 |
+
b64_string = pil_to_base64(item["pil_image"])
|
| 957 |
+
|
| 958 |
+
# 2. Run OCR
|
| 959 |
+
item["latex"] = get_latex_from_base64(b64_string)
|
| 960 |
+
|
| 961 |
+
# OPTIONAL: Clean up large image data if memory is a concern
|
| 962 |
+
# del item["pil_image"]
|
| 963 |
+
ocr_time = time.time() - ocr_start
|
| 964 |
+
|
| 965 |
+
# Append all extracted item dictionaries
|
| 966 |
+
all_extracted_items.extend(page_extracted_items)
|
| 967 |
+
|
| 968 |
detect_time = time.time() - detect_start
|
| 969 |
|
| 970 |
# Store the count in the dictionary (INT keys)
|
| 971 |
equation_counts_per_page[page_num] = page_equations
|
| 972 |
|
| 973 |
page_total_time = time.time() - page_start_time
|
| 974 |
+
log_messages.append(f"Page {page_num} Time: Total={page_total_time:.4f}s (Render={pix_time:.4f}s, Detect={detect_time:.4f}s, OCR={ocr_time:.4f}s)")
|
| 975 |
|
| 976 |
doc.close()
|
| 977 |
t5 = time.time()
|
| 978 |
detection_loop_time = t5 - t4
|
| 979 |
log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")
|
| 980 |
+
|
| 981 |
+
# 4. Final Report Generation and Gallery Formatting
|
| 982 |
+
|
| 983 |
+
# Format the extracted items for the Gradio Gallery
|
| 984 |
+
gallery_items: List[Tuple[Image.Image, str]] = []
|
| 985 |
+
|
| 986 |
+
# We will include the LATEX code as the image label in the gallery
|
| 987 |
+
# If the item is a Figure, the label is just the ID.
|
| 988 |
+
for item in all_extracted_items:
|
| 989 |
+
image_label = item["id"]
|
| 990 |
+
if item["type"] == "equation":
|
| 991 |
+
image_label = f'{item["id"]}: {item["latex"]}'
|
| 992 |
+
|
| 993 |
+
gallery_items.append((item["pil_image"], image_label))
|
| 994 |
+
|
| 995 |
|
| 996 |
+
total_execution_time = t5 - start_time
|
| 997 |
+
|
| 998 |
# Convert integer keys to string keys for JSON serialization
|
| 999 |
equation_counts_per_page_str_keys: Dict[str, int] = {
|
| 1000 |
str(k): v for k, v in equation_counts_per_page.items()
|
| 1001 |
}
|
|
|
|
|
|
|
|
|
|
| 1002 |
|
| 1003 |
report = (
|
| 1004 |
+
f"✅ **YOLO Counting & OCR Complete!**\n\n"
|
| 1005 |
f"**1) Total Pages Detected in PDF:** **{total_pages}**\n"
|
| 1006 |
f"**2) Total Equations Detected:** **{total_equation_count}**\n"
|
| 1007 |
f"**3) Total Figures Detected:** **{total_figure_count}**\n"
|
|
|
|
| 1013 |
f"\n```"
|
| 1014 |
)
|
| 1015 |
|
| 1016 |
+
return total_pages, total_equation_count, total_figure_count, report, total_execution_time, equation_counts_per_page_str_keys, gallery_items
|
|
|
|
| 1017 |
|
| 1018 |
|
| 1019 |
# ============================================================================
|
| 1020 |
+
# --- GRADIO INTERFACE FUNCTION & DEFINITION (Retained) ---
|
| 1021 |
# ============================================================================
|
| 1022 |
|
|
|
|
| 1023 |
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[Tuple[Image.Image, str]]]:
|
| 1024 |
"""
|
| 1025 |
Gradio wrapper function to handle file upload and return results.
|
|
|
|
| 1037 |
report,
|
| 1038 |
total_time,
|
| 1039 |
equation_counts_per_page,
|
| 1040 |
+
gallery_items
|
| 1041 |
) = run_single_pdf_preprocessing(pdf_path)
|
| 1042 |
|
| 1043 |
|
|
|
|
| 1050 |
return "Error", "Error", "Error", error_msg, {}, []
|
| 1051 |
|
| 1052 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1053 |
if __name__ == "__main__":
|
| 1054 |
|
| 1055 |
if not os.path.exists(WEIGHTS_PATH):
|
|
|
|
| 1066 |
# NEW OUTPUT: JSON component for structured data
|
| 1067 |
output_page_counts = gr.JSON(label="Equation Count Per Page (Dictionary)")
|
| 1068 |
|
| 1069 |
+
# Gradio Gallery now shows the LaTeX code as the label
|
| 1070 |
output_gallery = gr.Gallery(
|
| 1071 |
+
label="Detected Items (with Extracted LaTeX)",
|
| 1072 |
+
columns=3,
|
| 1073 |
height="auto",
|
| 1074 |
object_fit="contain",
|
| 1075 |
allow_preview=False
|
|
|
|
| 1086 |
output_page_counts,
|
| 1087 |
output_gallery
|
| 1088 |
],
|
| 1089 |
+
title="📊 YOLO Detection & Math OCR Pipeline",
|
| 1090 |
description=(
|
| 1091 |
+
"Upload a PDF. YOLO detects equations, and the TrOCR model converts them to LaTeX."
|
| 1092 |
),
|
| 1093 |
)
|
| 1094 |
|