File size: 13,968 Bytes
8d96f17 4d661b4 bfff4aa 46d219d bfff4aa 46d219d bfff4aa 46d219d bfff4aa 46d219d bfff4aa 7fbfa32 bfff4aa 7fbfa32 bfff4aa 7fbfa32 bfff4aa 7fbfa32 bfff4aa 7fbfa32 bfff4aa 46d219d bfff4aa 46d219d bfff4aa 46d219d bfff4aa 46d219d bfff4aa 46d219d 4d661b4 bfff4aa 46d219d bfff4aa 4d661b4 bfff4aa 46d219d bfff4aa 4d661b4 bfff4aa 4d661b4 bfff4aa 46d219d bfff4aa 4d661b4 46d219d bfff4aa 4d661b4 bfff4aa 4d661b4 46d219d bfff4aa 4d661b4 bfff4aa 46d219d bfff4aa 46d219d bfff4aa 46d219d bfff4aa 46d219d 4d661b4 bfff4aa 46d219d bfff4aa 4d661b4 46d219d 4d661b4 bfff4aa 46d219d d126c17 bfff4aa 4d661b4 bfff4aa 4d661b4 bfff4aa 4d661b4 bfff4aa 4d661b4 bfff4aa 4d661b4 46d219d bfff4aa 46d219d bfff4aa 4d661b4 bfff4aa 46d219d bfff4aa 46d219d bfff4aa 7fbfa32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 |
import fitz # PyMuPDF
import numpy as np
import cv2
import torch
import torch.serialization
import os
import time
from typing import Optional, Tuple, List, Dict, Any
from ultralytics import YOLO
import logging
import gradio as gr
import shutil
import tempfile
import io
# ============================================================================
# --- Global Patches and Setup ---
# ============================================================================
# Patch torch.load so every call runs with weights_only=False. This prevents
# weights_only errors when loading older model checkpoints such as the YOLO
# weights used below.
# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects, which can execute code. Only ever load trusted weight files.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    Positional arguments and all other keyword arguments are forwarded
    unchanged. Any ``weights_only`` value supplied by the caller is overridden.
    """
    forwarded_kwargs = {**kwargs, "weights_only": False}
    return _original_torch_load(*args, **forwarded_kwargs)


torch.load = patched_torch_load
logging.basicConfig(level=logging.WARNING)
# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
# Relative path of the trained YOLO weights file.
WEIGHTS_PATH: str = 'best.pt'
# Zoom applied to both axes when rasterizing each PDF page (fitz.Matrix).
SCALE_FACTOR: float = 2.0

# --- Detection parameters ---
# Minimum confidence passed to model.predict(conf=...).
CONF_THRESHOLD: float = 0.2
# YOLO class names that are kept; every other class is ignored.
TARGET_CLASSES: List[str] = ['figure', 'equation']
# Same-class boxes with IoU strictly above this are fused into one box.
IOU_MERGE_THRESHOLD: float = 0.4
# A box covered by a larger kept box by more than this fraction of its own
# area (intersection / own area) is dropped as nested.
IOA_SUPPRESSION_THRESHOLD: float = 0.7

# --- Run-wide detection totals (reset at the start of every PDF run) ---
GLOBAL_FIGURE_COUNT: int = 0
GLOBAL_EQUATION_COUNT: int = 0
# ============================================================================
# --- BOX COMBINATION LOGIC (Retained for detection accuracy) ---
# ============================================================================
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Boxes are given as left, top, right, bottom coordinates. The result is a
    float in ``[0, 1]`` for well-formed boxes. If the union area is not
    positive (e.g. both boxes are degenerate), ``0`` is returned.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Overlap extent along each axis, clamped at zero for disjoint boxes.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - inter)

    if union > 0:
        return inter / union
    return 0
def filter_nested_boxes(detections, ioa_threshold=0.80):
    """Drop detections that lie (mostly) inside a larger kept detection.

    Every detection dict gets an ``'area'`` key added, and ``detections`` is
    sorted in place from largest to smallest area (stable for ties). Walking
    that order, a box is kept unless some already-kept box covers more than
    ``ioa_threshold`` of the box's own area (intersection over the candidate's
    area). Class labels are not compared, so a box of one class can suppress a
    box of another class.

    Args:
        detections: List of dicts with a ``'coords'`` ``(x1, y1, x2, y2)``
            entry. Mutated in place (``'area'`` added, list re-ordered).
        ioa_threshold: Intersection-over-area ratio strictly above which a
            candidate box is dropped.

    Returns:
        The surviving detection dicts (the same objects), largest area first.
        Boxes with non-positive area are never suppressed.
    """
    if not detections:
        return []

    for det in detections:
        left, top, right, bottom = det['coords']
        det['area'] = (right - left) * (bottom - top)
    detections.sort(key=lambda det: det['area'], reverse=True)

    survivors = []
    for candidate in detections:
        cx1, cy1, cx2, cy2 = candidate['coords']
        cand_area = candidate['area']
        covered = False
        # Only boxes that were themselves kept may suppress later (smaller) ones.
        for keeper in survivors:
            kx1, ky1, kx2, ky2 = keeper['coords']
            overlap = (max(0, min(kx2, cx2) - max(kx1, cx1))
                       * max(0, min(ky2, cy2) - max(ky1, cy1)))
            if cand_area > 0 and overlap / cand_area > ioa_threshold:
                covered = True
                break
        if not covered:
            survivors.append(candidate)
    return survivors
def merge_overlapping_boxes(detections, iou_threshold):
    """Greedily fuse same-class detections that overlap an anchor detection.

    ``detections`` is sorted in place by descending ``'conf'``. Each detection
    not yet absorbed becomes an anchor. Every later, unabsorbed detection of
    the same ``'class'`` whose IoU with the anchor's *original* box is strictly
    above ``iou_threshold`` is absorbed. The anchor's output box grows to the
    bounding box of everything it absorbed.

    Args:
        detections: List of dicts with ``'coords'`` ``(x1, y1, x2, y2)``,
            ``'class'`` and ``'conf'`` keys. Sorted in place.
        iou_threshold: IoU strictly above which two boxes are merged.

    Returns:
        New dicts with ``'coords'``, ``'y1'`` (top edge of the merged box),
        ``'class'`` and ``'conf'`` (the anchor's confidence), in descending
        anchor-confidence order. Empty input yields ``[]``.
    """
    if not detections:
        return []

    detections.sort(key=lambda det: det['conf'], reverse=True)
    absorbed = [False] * len(detections)
    fused = []

    for anchor_idx, anchor in enumerate(detections):
        if absorbed[anchor_idx]:
            continue
        anchor_box = anchor['coords']
        anchor_class = anchor['class']
        left, top, right, bottom = anchor_box

        for other_idx in range(anchor_idx + 1, len(detections)):
            if absorbed[other_idx]:
                continue
            other = detections[other_idx]
            if other['class'] != anchor_class:
                continue
            other_box = other['coords']
            # IoU is measured against the anchor's original box, not the
            # growing merged box.
            if calculate_iou(anchor_box, other_box) > iou_threshold:
                left = min(left, other_box[0])
                top = min(top, other_box[1])
                right = max(right, other_box[2])
                bottom = max(bottom, other_box[3])
                absorbed[other_idx] = True

        fused.append({
            'coords': (left, top, right, bottom),
            'y1': top,
            'class': anchor_class,
            'conf': anchor['conf'],
        })

    return fused
# ============================================================================
# --- UTILITY FUNCTIONS ---
# ============================================================================
def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray:
    """Expose a PyMuPDF Pixmap's samples as a uint8 NumPy image for YOLO.

    The raw sample bytes are viewed as an ``(pix.h, pix.w, pix.n)`` uint8
    array. Conversions:

    - 4-channel pixmaps are treated as RGBA and have their alpha channel
      dropped.
    - 1-channel pixmaps are treated as grayscale and expanded to RGB.
    - Any other channel count is returned as-is. This is the
      ``np.frombuffer`` view, which is read-only when ``pix.samples`` is
      ``bytes``.

    NOTE(review): assumes rows are tightly packed (stride == w * n), and that
    4 channels means RGBA rather than e.g. CMYK — confirm if the colorspace or
    alpha settings of get_pixmap change.
    """
    height, width, channels = pix.h, pix.w, pix.n
    pixels = np.frombuffer(pix.samples, dtype=np.uint8).reshape((height, width, channels))
    if channels == 4:
        return cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)
    if channels == 1:
        return cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
    return pixels
def run_yolo_detection_and_count(
    image: np.ndarray, model: YOLO, page_num: int
) -> Tuple[int, int]:
    """Detect figures/equations on one page image and add them to the totals.

    Steps:

    1. Run ``model.predict`` at ``CONF_THRESHOLD`` and keep boxes whose class
       name is in ``TARGET_CLASSES``.
    2. Fuse same-class overlaps with ``merge_overlapping_boxes``.
    3. Drop nested boxes with ``filter_nested_boxes``.
    4. Add the survivors to the module-level ``GLOBAL_FIGURE_COUNT`` /
       ``GLOBAL_EQUATION_COUNT``.

    Args:
        image: Page image passed straight to ``model.predict``.
        model: Loaded ultralytics YOLO model (``model.names`` maps class ids
            to names).
        page_num: Page number, used only in log messages.

    Returns:
        ``(page_equations, page_figures)`` for this page. If inference or
        result parsing raises, the error is logged, the global counters are
        left untouched and ``(0, 0)`` is returned.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT

    candidates = []
    try:
        predictions = model.predict(image, conf=CONF_THRESHOLD, verbose=False)
        rows = (
            predictions[0].boxes.data.tolist()
            if predictions and predictions[0].boxes
            else []
        )
        # Each row is [x1, y1, x2, y2, confidence, class_id].
        for left, top, right, bottom, score, class_idx in rows:
            label = model.names[int(class_idx)]
            if label in TARGET_CLASSES:
                candidates.append({
                    'coords': (left, top, right, bottom),
                    'class': label,
                    'conf': score,
                })
    except Exception as e:
        logging.error(f"YOLO inference failed on page {page_num}: {e}")
        return 0, 0

    survivors = filter_nested_boxes(
        merge_overlapping_boxes(candidates, IOU_MERGE_THRESHOLD),
        IOA_SUPPRESSION_THRESHOLD,
    )

    page_figures = sum(det['class'] == 'figure' for det in survivors)
    page_equations = sum(det['class'] == 'equation' for det in survivors)
    GLOBAL_FIGURE_COUNT += page_figures
    GLOBAL_EQUATION_COUNT += page_equations

    logging.warning(f" -> Page {page_num}: EQs={page_equations}, Figs={page_figures}")
    return page_equations, page_figures
# ============================================================================
# --- MAIN DOCUMENT PROCESSING FUNCTION (Fixed for JSON serialization) ---
# ============================================================================
def _detect_equations_per_page(doc, model) -> Tuple[Dict[int, int], List[str]]:
    """Render every page of ``doc``, run detection, and collect per-page results.

    Args:
        doc: An open PyMuPDF document (not closed here).
        model: Loaded YOLO model passed to run_yolo_detection_and_count.

    Returns:
        ``({page_number: equation_count}, timing_lines)``. Page numbers are
        1-based ints. Each successfully processed page adds one timing line.
        Pages that fail to load or render are logged and skipped, so they
        are absent from both.
    """
    mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
    counts: Dict[int, int] = {}
    timing_lines: List[str] = []
    for page_num_0_based in range(doc.page_count):
        page_num = page_num_0_based + 1
        page_start_time = time.time()
        # Load and render the page for YOLO. A broken page is skipped rather
        # than aborting the whole document.
        try:
            fitz_page = doc.load_page(page_num_0_based)
            pix_start = time.time()
            pix = fitz_page.get_pixmap(matrix=mat)
            original_img = pixmap_to_numpy(pix)
            pix_time = time.time() - pix_start
        except Exception as e:
            logging.error(f"Error converting page {page_num} to image: {e}. Skipping.")
            continue

        # Core detection (also updates the run-wide global counters).
        detect_start = time.time()
        page_equations, _ = run_yolo_detection_and_count(original_img, model, page_num)
        detect_time = time.time() - detect_start

        counts[page_num] = page_equations
        page_total_time = time.time() - page_start_time
        timing_lines.append(
            f"Page {page_num} Time: Total={page_total_time:.4f}s "
            f"(Render={pix_time:.4f}s, Detect={detect_time:.4f}s)"
        )
    return counts, timing_lines


# NOTE: The return signature uses Dict[str, int] for the per-page equation
# counts so the dictionary is JSON-serializable for the Gradio JSON output.
def run_single_pdf_preprocessing(pdf_path: str) -> Tuple[int, int, int, str, float, Dict[str, int], List[str]]:
    """Run YOLO figure/equation counting over every page of one PDF.

    Resets GLOBAL_FIGURE_COUNT / GLOBAL_EQUATION_COUNT, loads the YOLO
    weights from WEIGHTS_PATH, renders each page at SCALE_FACTOR zoom and
    counts detections page by page.

    Args:
        pdf_path: Filesystem path of the PDF to process.

    Returns:
        A 7-tuple ``(total_pages, total_equations, total_figures, report,
        total_execution_time, equation_counts_per_page, [])``:

        - ``report`` is a Markdown summary including step timings.
        - ``equation_counts_per_page`` maps the 1-based page number, as a
          string, to that page's equation count.
        - The last element is always an empty list (the gallery output is
          disabled).

        If the PDF is missing, or the model or PDF fails to load, the counts
        are 0, ``report`` describes the error and the dict is empty.
    """
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    start_time = time.time()
    log_messages: List[str] = []

    # Reset the run-wide counters that run_yolo_detection_and_count increments.
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0

    # 1. Validation and model loading
    t0 = time.time()
    if not os.path.exists(pdf_path):
        report = f"❌ FATAL ERROR: Input PDF not found at {pdf_path}."
        return 0, 0, 0, report, time.time() - start_time, {}, []
    try:
        model = YOLO(WEIGHTS_PATH)
        logging.warning(f"✅ Loaded YOLO model from: {WEIGHTS_PATH}")
    except Exception as e:
        report = f"❌ ERROR loading YOLO model: {e}\n(Ensure '{WEIGHTS_PATH}' is available and valid.)"
        return 0, 0, 0, report, time.time() - start_time, {}, []
    t1 = time.time()
    log_messages.append(f"Model Loading Time: {t1-t0:.4f}s")

    # 2. PDF loading
    t2 = time.time()
    try:
        doc = fitz.open(pdf_path)
        total_pages = doc.page_count
        logging.warning(f"✅ Opened PDF with {total_pages} pages")
    except Exception as e:
        report = f"❌ ERROR loading PDF file: {e}"
        return 0, 0, 0, report, time.time() - start_time, {}, []
    t3 = time.time()
    log_messages.append(f"PDF Initialization Time: {t3-t2:.4f}s")

    # 3. Page processing and detection loop
    t4 = time.time()
    try:
        equation_counts_per_page, page_logs = _detect_equations_per_page(doc, model)
    finally:
        # Always release the document, even if a page raises unexpectedly.
        doc.close()
    log_messages.extend(page_logs)
    t5 = time.time()
    detection_loop_time = t5 - t4
    log_messages.append(f"Total Detection Loop Time ({total_pages} pages): {detection_loop_time:.4f}s")

    # JSON object keys must be strings, so convert the int page numbers.
    equation_counts_per_page_str_keys: Dict[str, int] = {
        str(k): v for k, v in equation_counts_per_page.items()
    }

    # 4. Final report generation
    total_execution_time = t5 - start_time
    report = (
        f"✅ **YOLO Counting Complete!**\n\n"
        f"**1) Total Pages Detected in PDF:** **{total_pages}**\n"
        f"**2) Total Equations Detected:** **{GLOBAL_EQUATION_COUNT}**\n"
        f"**3) Total Figures Detected:** **{GLOBAL_FIGURE_COUNT}**\n"
        f"---\n"
        f"**4) Total Execution Time:** **{total_execution_time:.4f}s**\n"
        f"### Detailed Step Timing\n"
        f"```\n"
        + "\n".join(log_messages)
        + "\n```"
    )
    return (
        total_pages,
        GLOBAL_EQUATION_COUNT,
        GLOBAL_FIGURE_COUNT,
        report,
        total_execution_time,
        equation_counts_per_page_str_keys,
        [],
    )
# ============================================================================
# --- GRADIO INTERFACE FUNCTION (Updated) ---
# ============================================================================
def gradio_process_pdf(pdf_file) -> Tuple[str, str, str, str, Dict[str, int], List[str]]:
    """Gradio callback: run the PDF counting pipeline on an uploaded file.

    Args:
        pdf_file: The value from ``gr.File``. With ``type="filepath"`` this is
            a path string. Path-like objects and objects exposing a ``.name``
            path (e.g. tempfile wrappers) are also accepted. ``None`` means
            nothing was uploaded.

    Returns:
        ``(pages, equations, figures, report_markdown, per_page_counts, [])``.
        The three counts are strings, ``per_page_counts`` has string page
        keys, and the final empty list feeds the disabled gallery.

        If no file was uploaded, returns "N/A" placeholders. On any exception,
        returns "Error" placeholders with the message as the report.
    """
    if pdf_file is None:
        return "N/A", "N/A", "N/A", "Please upload a PDF file.", {}, []
    try:
        # gr.File(type="filepath") delivers a plain str, which has no `.name`;
        # accept paths directly and fall back to `.name` for file objects.
        if isinstance(pdf_file, (str, os.PathLike)):
            pdf_path = os.fspath(pdf_file)
        else:
            pdf_path = pdf_file.name

        (num_pages, num_equations, num_figures, report,
         _total_time, equation_counts_per_page, _) = run_single_pdf_preprocessing(pdf_path)

        return str(num_pages), str(num_equations), str(num_figures), report, equation_counts_per_page, []
    except Exception as e:
        error_msg = f"An unexpected error occurred: {e}"
        logging.error(error_msg, exc_info=True)
        return "Error", "Error", "Error", error_msg, {}, []
# ============================================================================
# --- GRADIO INTERFACE DEFINITION (Updated) ---
# ============================================================================
if __name__ == "__main__":
    # Warn up front when the weights file is missing. The UI still launches;
    # runs will presumably fail at model loading and report it in the summary.
    if not os.path.exists(WEIGHTS_PATH):
        logging.error(f"❌ FATAL ERROR: YOLO weight file '{WEIGHTS_PATH}' not found. Cannot run live inference.")

    # Input: type="filepath" makes Gradio pass the upload to the callback as a
    # path string.
    input_file = gr.File(label="Upload PDF Document", type="filepath", file_types=[".pdf"])

    # Outputs, in the same order as the tuple returned by gradio_process_pdf.
    output_pages = gr.Textbox(label="Total Pages in PDF", interactive=False)
    output_equations = gr.Textbox(label="Total Equations Detected", interactive=False)
    output_figures = gr.Textbox(label="Total Figures Detected", interactive=False)
    output_report = gr.Markdown(label="Processing Summary and Timing")
    # JSON component for the per-page equation counts (string page keys).
    output_page_counts = gr.JSON(label="Equation Count Per Page (Dictionary)")
    # Gallery is kept in the layout but always receives an empty list [].
    output_gallery = gr.Gallery(
        label="Detected Equations (Disabled for Speed)",
        columns=5,
        height="auto",
        object_fit="contain",
        allow_preview=False,
    )

    interface = gr.Interface(
        fn=gradio_process_pdf,
        inputs=input_file,
        outputs=[
            output_pages,
            output_equations,
            output_figures,
            output_report,
            output_page_counts,
            output_gallery,
        ],
        title="YOLO Counting with Per-Page Data & Timing",
        description=(
            "Upload a PDF to run YOLO detection. The results include total counts, a breakdown of "
            "equation counts per page (in JSON format), and detailed timing."
        ),
    )
    print("\nStarting Gradio application...")
    interface.launch(inbrowser=True)
|