|
|
|
|
|
|
|
|
import fitz |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import torch |
|
|
import torch.serialization |
|
|
import os |
|
|
import time |
|
|
from typing import Optional, Tuple, List, Dict, Any |
|
|
from ultralytics import YOLO |
|
|
import logging |
|
|
import gradio as gr |
|
|
import shutil |
|
|
import tempfile |
|
|
import io |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Globally monkey-patch torch.load so every call runs with weights_only=False.
# NOTE(review): presumably needed because PyTorch >= 2.6 defaults to
# weights_only=True, which refuses the full pickled model objects stored in
# ultralytics checkpoints — confirm against the installed versions.
# SECURITY: weights_only=False unpickles arbitrary objects and can execute
# code embedded in the file; only load checkpoints from trusted sources.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Forward to the original ``torch.load`` with ``weights_only`` forced off.

    All positional and keyword arguments are passed through unchanged,
    except ``weights_only``, which is always set to ``False`` — even when
    the caller explicitly supplied a different value.
    """
    forwarded = {**kwargs, "weights_only": False}
    return _original_torch_load(*args, **forwarded)


torch.load = patched_torch_load
|
|
|
|
|
# Root logger at WARNING. Progress messages in this file (model/PDF load,
# per-page counts) are emitted with logging.warning so they stay visible.
logging.basicConfig(level=logging.WARNING)


# Path of the YOLO checkpoint loaded by run_single_pdf_preprocessing.
WEIGHTS_PATH = 'best.pt'


# Zoom applied to both axes when rendering pages (fitz.Matrix passed to
# get_pixmap); higher values give larger images for the detector.
SCALE_FACTOR = 2.0


# Minimum confidence passed to model.predict(conf=...).
CONF_THRESHOLD = 0.2


# Model class names that are counted; detections of any other class are ignored.
TARGET_CLASSES = ['figure', 'equation']


# Same-class boxes whose IoU with a higher-confidence seed box exceeds this
# value are merged into one union box (merge_overlapping_boxes).
IOU_MERGE_THRESHOLD = 0.4


# A box is suppressed when more than this fraction of its own area lies inside
# a larger kept box, regardless of class (filter_nested_boxes). Overrides that
# function's default of 0.80.
IOA_SUPPRESSION_THRESHOLD = 0.7


# Running totals for the PDF currently being processed: reset at the start of
# run_single_pdf_preprocessing and incremented by run_yolo_detection_and_count.
# NOTE(review): module-level mutable state — concurrent requests would share
# these counters; confirm the Gradio queue serialises calls.
GLOBAL_FIGURE_COUNT = 0


GLOBAL_EQUATION_COUNT = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both boxes are ``(x1, y1, x2, y2)`` 4-tuples of corner coordinates.
    The result is ``intersection / union``; ``0`` is returned when the union
    area is not positive (e.g. two degenerate, zero-area boxes).
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Overlap extents clamp at zero when the boxes do not intersect.
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0, min(ay2, by2) - max(ay1, by1))
    intersection = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - intersection)

    if union > 0:
        return intersection / union
    return 0
|
|
|
|
|
|
|
|
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Suppress detections that lie mostly inside a larger detection.

    Every dict in ``detections`` must have ``'coords' = (x1, y1, x2, y2)``.
    An ``'area'`` key is written onto each dict, and ``detections`` is sorted
    in place by that area, largest first. In that order, each detection that
    has not been suppressed is kept. Any later detection is suppressed when
    its intersection with the kept one covers more than ``ioa_threshold`` of
    its own area. Later detections always have a smaller or equal area.
    Class labels are not consulted.

    Args:
        detections: List of detection dicts; mutated as described above.
        ioa_threshold: Intersection-over-area ratio above which the smaller
            box is discarded (strict ``>`` comparison).

    Returns:
        The surviving dicts (the same objects), largest area first; an empty
        list for empty input.
    """
    if not detections:
        return []

    for det in detections:
        left, top, right, bottom = det['coords']
        det['area'] = (right - left) * (bottom - top)

    detections.sort(key=lambda d: d['area'], reverse=True)

    suppressed = set()
    survivors = []
    for outer_idx, outer in enumerate(detections):
        if outer_idx in suppressed:
            continue
        survivors.append(outer)
        outer_box = outer['coords']

        for inner_idx in range(outer_idx + 1, len(detections)):
            if inner_idx in suppressed:
                continue
            inner = detections[inner_idx]
            inner_box = inner['coords']

            overlap_w = max(0, min(outer_box[2], inner_box[2]) - max(outer_box[0], inner_box[0]))
            overlap_h = max(0, min(outer_box[3], inner_box[3]) - max(outer_box[1], inner_box[1]))
            overlap = overlap_w * overlap_h

            inner_area = inner['area']
            # Zero-area boxes are never suppressed (avoids division by zero).
            if inner_area > 0 and overlap / inner_area > ioa_threshold:
                suppressed.add(inner_idx)

    return survivors
|
|
|
|
|
|
|
|
def merge_overlapping_boxes(detections, iou_threshold):
    """Greedily merge same-class detections that overlap a higher-confidence seed.

    ``detections`` is sorted in place by ``'conf'``, highest first. Each
    detection not yet absorbed becomes a seed. Every later detection of the
    same ``'class'`` whose IoU with the seed's *original* box exceeds
    ``iou_threshold`` is absorbed. The seed's output box grows to the union
    bounding box of everything it absorbed.

    Args:
        detections: List of dicts with ``'coords'`` (x1, y1, x2, y2),
            ``'class'`` and ``'conf'`` keys.
        iou_threshold: Strict IoU threshold for absorbing a detection.

    Returns:
        New dicts with keys ``'coords'``, ``'y1'``, ``'class'`` and ``'conf'``,
        one per seed in descending confidence order. Each carries the seed's
        class and confidence. Returns an empty list for empty input.
    """
    if not detections:
        return []

    detections.sort(key=lambda d: d['conf'], reverse=True)

    absorbed = set()
    output = []
    for seed_idx, seed in enumerate(detections):
        if seed_idx in absorbed:
            continue
        seed_box = seed['coords']
        seed_class = seed['class']
        left, top, right, bottom = seed_box

        for cand_idx in range(seed_idx + 1, len(detections)):
            if cand_idx in absorbed or detections[cand_idx]['class'] != seed_class:
                continue
            cand_box = detections[cand_idx]['coords']
            # Overlap is measured against the seed's original box, not the
            # growing union box.
            if calculate_iou(seed_box, cand_box) > iou_threshold:
                left = min(left, cand_box[0])
                top = min(top, cand_box[1])
                right = max(right, cand_box[2])
                bottom = max(bottom, cand_box[3])
                absorbed.add(cand_idx)

        output.append({
            'coords': (left, top, right, bottom),
            'y1': top,
            'class': seed_class,
            'conf': seed['conf'],
        })

    return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pixmap_to_numpy(pix: "fitz.Pixmap") -> np.ndarray:
    """Convert a PyMuPDF Pixmap to a contiguous ``(h, w, 3)`` uint8 array.

    ``pix.samples`` is read as ``h`` rows of ``w`` pixels, each with ``pix.n``
    interleaved 8-bit components. The components are then normalised to
    three channels:

    * When ``pix.alpha`` is set, the trailing alpha component is discarded.
      This covers RGBA (``n == 4``) and gray+alpha (``n == 2``) pixmaps.
    * A single remaining channel (grayscale) is replicated to three channels.
    * Three channels are returned as-is.
    * Any other count keeps only the first three components. One example is
      four components without alpha, presumably CMYK.
      NOTE(review): that does not give correct colours; convert with
      ``fitz.Pixmap(fitz.csRGB, pix)`` first if such pixmaps occur.
    """
    img = np.frombuffer(pix.samples, dtype=np.uint8).reshape((pix.h, pix.w, pix.n))

    if pix.alpha:
        img = img[:, :, :-1]

    channels = img.shape[2]
    if channels == 1:
        img = np.repeat(img, 3, axis=2)
    elif channels != 3:
        img = img[:, :, :3]

    # Channel slicing yields non-contiguous views, and OpenCV-based consumers
    # can reject those. Return a C-contiguous buffer; this is a no-op for the
    # plain 3-channel case.
    return np.ascontiguousarray(img)
|
|
|
|
|
|
|
|
def run_yolo_detection_and_count(
    image: np.ndarray, model: YOLO, page_num: int
) -> Tuple[int, int]:
    """Detect, de-duplicate and count figures/equations on one page image.

    The pipeline has four steps:

    1. Run ``model.predict`` at CONF_THRESHOLD.
    2. Keep only classes listed in TARGET_CLASSES.
    3. Merge overlapping same-class boxes (merge_overlapping_boxes).
    4. Drop boxes nested inside larger ones (filter_nested_boxes).

    The surviving counts are added to the module-level GLOBAL_FIGURE_COUNT
    and GLOBAL_EQUATION_COUNT totals.

    Args:
        image: Page image passed unchanged to ``model.predict``.
        model: Loaded YOLO model; ``model.names`` maps class ids to names.
        page_num: Page number, used only in log messages.

    Returns:
        ``(page_equations, page_figures)`` for this page. If inference (or
        parsing its output) raises, the error is logged, ``(0, 0)`` is
        returned and the global totals are left untouched.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    candidates = []
    try:
        results = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
        if results and results[0].boxes:
            # Each row of boxes.data is [x1, y1, x2, y2, conf, class_id].
            for x1, y1, x2, y2, conf, cls_id in results[0].boxes.data.tolist():
                label = model.names[int(cls_id)]
                if label not in TARGET_CLASSES:
                    continue
                candidates.append({'coords': (x1, y1, x2, y2), 'class': label, 'conf': conf})
    except Exception as e:
        logging.error(f"YOLO inference failed on page {page_num}: {e}")
        return 0, 0

    merged = merge_overlapping_boxes(candidates, IOU_MERGE_THRESHOLD)
    survivors = filter_nested_boxes(merged, IOA_SUPPRESSION_THRESHOLD)

    labels = [det['class'] for det in survivors]
    page_figures = labels.count('figure')
    page_equations = labels.count('equation')

    GLOBAL_FIGURE_COUNT += page_figures
    GLOBAL_EQUATION_COUNT += page_equations

    logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
    return page_equations, page_figures
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_single_pdf_preprocessing(pdf_path: str) -> Tuple[int, int, int, str, float, Dict[str, int], List[str]]:
    """Count figures and equations on every page of one PDF using YOLO.

    The function first resets GLOBAL_FIGURE_COUNT and GLOBAL_EQUATION_COUNT
    and loads the model from WEIGHTS_PATH. It then renders each page at
    SCALE_FACTOR zoom and runs run_yolo_detection_and_count on it. Pages that
    fail to load or render are logged and skipped. The PDF is always closed,
    even if an exception escapes the page loop.

    Args:
        pdf_path: Path to the PDF file on disk.

    Returns:
        A 7-tuple ``(total_pages, total_equations, total_figures, report,
        total_execution_time, equation_counts_per_page, [])``, where:

        * ``report`` is a Markdown summary with per-step timings.
        * ``total_execution_time`` is in seconds.
        * ``equation_counts_per_page`` maps each 1-based page number (as a
          ``str``) to that page's equation count. Skipped pages are absent.
        * The trailing list is always empty.

        If the file is missing, or the model or PDF cannot be loaded, the
        counts are 0, the report describes the error and the dict is empty.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    start_time = time.time()
    log_messages: List[str] = []
    equation_counts_per_page: Dict[int, int] = {}

    # run_yolo_detection_and_count accumulates into these; start fresh per PDF.
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0

    # --- Input validation and model loading ---------------------------------
    t0 = time.time()
    if not os.path.exists(pdf_path):
        report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
        return 0, 0, 0, report, time.time() - start_time, {}, []

    try:
        model = YOLO(WEIGHTS_PATH)
        logging.warning(f"✅ Loaded YOLO model from: {WEIGHTS_PATH}")
    except Exception as e:
        report = f"❌ ERROR loading YOLO model: {e}\n(Ensure 'best.pt' is available and valid.)"
        return 0, 0, 0, report, time.time() - start_time, {}, []
    t1 = time.time()
    log_messages.append(f"Model Loading Time: {t1-t0:.4f}s")

    # --- Open the PDF ---------------------------------------------------------
    t2 = time.time()
    try:
        doc = fitz.open(pdf_path)
        total_pages = doc.page_count
        logging.warning(f"✅ Opened PDF with {total_pages} pages")
    except Exception as e:
        report = f"❌ ERROR loading PDF file: {e}"
        return 0, 0, 0, report, time.time() - start_time, {}, []
    t3 = time.time()
    log_messages.append(f"PDF Initialization Time: {t3-t2:.4f}s")

    mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)

    # --- Per-page render + detect ---------------------------------------------
    t4 = time.time()
    try:
        for page_num_0_based in range(total_pages):
            page_start_time = time.time()
            page_num = page_num_0_based + 1

            try:
                fitz_page = doc.load_page(page_num_0_based)
                pix_start = time.time()
                pix = fitz_page.get_pixmap(matrix=mat)
                rgb_img = pixmap_to_numpy(pix)
                # Ultralytics treats raw ndarray inputs as BGR (OpenCV channel
                # order), whereas pixmap_to_numpy produces RGB.
                bgr_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)
                pix_time = time.time() - pix_start
            except Exception as e:
                logging.error(f"Error converting page {page_num} to image: {e}. Skipping.")
                continue

            detect_start = time.time()
            page_equations, _ = run_yolo_detection_and_count(bgr_img, model, page_num)
            detect_time = time.time() - detect_start

            equation_counts_per_page[page_num] = page_equations

            page_total_time = time.time() - page_start_time
            log_messages.append(
                f"Page {page_num} Time: Total={page_total_time:.4f}s "
                f"(Render={pix_time:.4f}s, Detect={detect_time:.4f}s)"
            )
    finally:
        # Release the document even if the loop is interrupted by an exception.
        doc.close()

    t5 = time.time()
    detection_loop_time = t5 - t4
    log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")

    # String keys so the mapping renders as-is in the gr.JSON output.
    equation_counts_per_page_str_keys: Dict[str, int] = {
        str(k): v for k, v in equation_counts_per_page.items()
    }

    total_execution_time = t5 - start_time

    # Blank lines are required between items. In Markdown a single newline is
    # only a soft break, and a `---` directly under a text line turns that line
    # into a heading.
    report = (
        "✅ **YOLO Counting Complete!**\n\n"
        f"**1) Total Pages Detected in PDF:** **{total_pages}**\n\n"
        f"**2) Total Equations Detected:** **{GLOBAL_EQUATION_COUNT}**\n\n"
        f"**3) Total Figures Detected:** **{GLOBAL_FIGURE_COUNT}**\n\n"
        "---\n\n"
        f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n\n"
        "### Detailed Step Timing\n"
        "```\n"
        + "\n".join(log_messages)
        + "\n```"
    )

    return (
        total_pages,
        GLOBAL_EQUATION_COUNT,
        GLOBAL_FIGURE_COUNT,
        report,
        total_execution_time,
        equation_counts_per_page_str_keys,
        [],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[str]]:
    """Gradio callback: run the counting pipeline on an uploaded PDF.

    Args:
        pdf_file: Value produced by the ``gr.File`` input. With
            ``type="filepath"`` (as configured in this app) it is a path
            string. Other configurations may pass a tempfile-like object that
            exposes ``.name``. It is ``None`` when nothing was uploaded.

    Returns:
        Six values matching the interface outputs, in order:

        1. Page count, as a string.
        2. Equation count, as a string.
        3. Figure count, as a string.
        4. The Markdown report.
        5. The per-page equation-count dict.
        6. An always-empty gallery list.

        If no file is uploaded, placeholder ``"N/A"`` values are returned.
        If the pipeline raises, ``"Error"`` values are returned and the
        error message goes in the report slot.
    """
    if pdf_file is None:
        return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []

    try:
        # gr.File(type="filepath") passes a plain str, which has no `.name`;
        # accept both path-likes and file-like objects.
        if isinstance(pdf_file, (str, os.PathLike)):
            pdf_path = os.fspath(pdf_file)
        else:
            pdf_path = pdf_file.name

        num_pages, num_equations, num_figures, report, _total_time, equation_counts_per_page, _ = (
            run_single_pdf_preprocessing(pdf_path)
        )
        return str(num_pages), str(num_equations), str(num_figures), report, equation_counts_per_page, []

    except Exception as e:
        error_msg = f"An unexpected error occurred: {e}"
        logging.error(error_msg, exc_info=True)
        return "Error", "Error", "Error", error_msg, {}, []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build and launch the Gradio UI. A missing checkpoint is only reported
    # here; the app still starts. Requests will then presumably surface the
    # model-load error from run_single_pdf_preprocessing in their report.
    if not os.path.exists(WEIGHTS_PATH):
        logging.error(f"❌ FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.")

    # type="filepath" makes Gradio hand the callback the upload's path as a str.
    input_file = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])

    output_pages = gr.Textbox(label="Total Pages in PDF", interactive=False)
    output_equations = gr.Textbox(label="Total Equations Detected", interactive=False)
    output_figures = gr.Textbox(label="Total Figures Detected", interactive=False)
    output_report = gr.Markdown(label="Processing Summary and Timing")

    # Receives the {page_number_str: equation_count} dict.
    output_page_counts = gr.JSON(label="Equation Count Per Page (Dictionary)")

    # Kept so the outputs line up with gradio_process_pdf's 6-tuple; the
    # callback always returns an empty list for it.
    output_gallery = gr.Gallery(
        label="Detected Equations (Disabled for Speed)",
        columns=5,
        height="auto",
        object_fit="contain",
        allow_preview=False
    )

    interface = gr.Interface(
        fn=gradio_process_pdf,
        inputs=input_file,
        # Order must match the tuple returned by gradio_process_pdf.
        outputs=[
            output_pages,
            output_equations,
            output_figures,
            output_report,
            output_page_counts,
            output_gallery
        ],
        title="YOLO Counting with Per-Page Data & Timing",
        description=(
            "Upload a PDF to run YOLO detection. The results include total counts, a breakdown of "
            "equation counts per page (in JSON format), and detailed timing."
        ),
    )

    print("\nStarting Gradio application...")
    interface.launch(inbrowser=True)
|
|
|
|
|
|