heerjtdev committed on
Commit
3d0c98c
·
verified ·
1 Parent(s): d6a7780

Upload test_surya.py

Browse files
Files changed (1) hide show
  1. test_surya.py +188 -0
test_surya.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
# Standard environment setup (keep this)
# When the APP_PATH environment variable is set, switch the process into that
# directory and make sure the directory is listed in sys.path.
# NOTE(review): this listing has lost its indentation; the nesting below (both
# inner checks sitting directly under the APP_PATH check) is reconstructed —
# confirm against the real file.
if "APP_PATH" in os.environ:
    app_path = os.path.abspath(os.environ["APP_PATH"])
    if os.getcwd() != app_path:
        # fix sys.path for import
        os.chdir(app_path)
    # Appended (not inserted), so existing sys.path entries keep priority.
    if app_path not in sys.path:
        sys.path.append(app_path)
12
+
13
+ import io
14
+ import tempfile
15
+ from typing import List
16
+
17
+ import pypdfium2
18
+ import gradio as gr
19
+ import requests
20
+ from contextlib import suppress
21
+
22
+ from surya.common.surya.schema import TaskNames
23
+ from surya.models import load_predictors
24
+
25
+ from surya.debug.draw import draw_polys_on_image
26
+ from PIL import Image
27
+ from surya.layout import LayoutResult
28
+ from surya.settings import settings
29
+ from surya.common.util import rescale_bbox, expand_bbox
30
+
31
+
32
+ # --- Core Functions (Minimal changes required) ---
33
+
34
+ # Get page image from PDF (keep this)
35
def open_pdf(pdf_file):
    """Open ``pdf_file`` with pypdfium2 and hand back the document object.

    ``pdf_file`` is passed to ``pypdfium2.PdfDocument`` unchanged. Nothing in
    this function closes the document; that is left to the caller.
    """
    document = pypdfium2.PdfDocument(pdf_file)
    return document
37
+
38
def page_counter(pdf_file):
    """Return the number of pages in ``pdf_file``.

    The document is opened through ``open_pdf`` and is always closed again,
    including when taking its length raises, so a failing call does not leave
    the document open.
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        doc.close()
43
+
44
def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
    """Render a single PDF page and return it as an RGB PIL image.

    Args:
        pdf_file: Passed unchanged to ``open_pdf``.
        page_num: 1-based page number; ``page_num - 1`` is used as the
            pypdfium2 page index.
        dpi: Requested resolution. The render scale handed to pypdfium2 is
            ``dpi / 72``.

    Returns:
        The rendered page converted to mode ``"RGB"``.

    The document is closed in a ``finally`` clause, so it is released even
    when rendering or conversion fails.
    """
    doc = open_pdf(pdf_file)
    try:
        renderer = doc.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[page_num - 1],
            scale=dpi / 72,
        )
        # Only one index was requested, so the first rendered item is the page.
        page = list(renderer)[0]
        # The conversion runs inside the try block, i.e. while the document
        # is still open.
        return page.convert("RGB")
    finally:
        doc.close()
55
+
56
def get_uploaded_image(in_file):
    """Load an uploaded image and return it converted to mode ``"RGB"``.

    The image is opened in a ``with`` block so the underlying file handle is
    released as soon as the converted image has been produced, instead of
    being left to the garbage collector.
    """
    with Image.open(in_file) as uploaded:
        # convert() yields a separate image object, which is what the caller
        # receives after the source file has been closed.
        return uploaded.convert("RGB")
58
+
59
+ # Modified layout_detection to filter for Equation and Figure
60
def focused_layout_detection(img) -> (Image.Image, LayoutResult):
    """Run layout prediction on ``img`` and keep only Equation/Figure boxes.

    Returns a pair:
      * a copy of ``img`` with the kept polygons and their labels drawn on it;
      * the prediction object, whose ``bboxes`` attribute has been replaced
        by the filtered list.
    """
    wanted_labels = ["Equation", "Figure"]

    # The layout predictor is called with a one-element list; take its
    # single result.
    pred = predictors["layout"]([img])[0]

    kept = []
    for box in pred.bboxes:
        if box.label in wanted_labels:
            kept.append(box)

    # Overwrite in place so the returned result carries only the kept boxes.
    pred.bboxes = kept

    # Collect what the drawing helper needs: one polygon and one caption per
    # kept box. Caption format: "<label>-<position>-<score rounded to 2 dp>".
    polygons = []
    labels = []
    for p in kept:
        polygons.append(p.polygon)
        labels.append(f"{p.label}-{p.position}-{round(p.top_k[p.label], 2)}")

    # Draw on a copy so the caller's image is left untouched.
    annotated = draw_polys_on_image(
        polygons, img.copy(), labels=labels, label_font_size=18
    )
    return annotated, pred
86
+
87
+
88
# Load models (keep this)
# Created once when the module is imported; focused_layout_detection reads
# predictors["layout"] from this module-level object on every call.
predictors = load_predictors()
90
+
91
+
92
+ # --- Gradio Interface (Significantly simplified) ---
93
+
94
# Gradio UI: an upload/preview column, a results column, and the event
# handlers wiring them together. Everything is registered inside the Blocks
# context so the components and listeners belong to `demo`.
with gr.Blocks(title="Surya Equation/Figure Detector") as demo:
    gr.Markdown("""
    # Surya Equation and Figure Detection

    This application uses Surya OCR's layout analysis model to **specifically detect and locate Equations and Figures** within a document page.

    The output provides an image with bounding boxes drawn, and the raw JSON bounding box information for the detected elements.

    Find the original project [here](https://github.com/VikParuchuri/surya).
    """)

    with gr.Row():
        with gr.Column():
            in_file = gr.File(label="PDF file or image:", file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".webp"])
            in_num = gr.Slider(label="Page number", minimum=1, maximum=100, value=1, step=1)
            in_img = gr.Image(label="Select page of Image", type="pil", sources=None)

            # Keep only the essential button
            detection_btn = gr.Button("Run Equation and Figure Detection")

        with gr.Column():
            result_img = gr.Gallery(label="Result image: Detected Equations and Figures", show_label=True,
                                    elem_id="gallery", columns=[1], rows=[1], object_fit="contain", height="auto")

            gr.HTML("""
            <style>
            #gallery {
                height: auto !important;
                max-height: none !important;
                overflow: visible !important;
            }
            #gallery .gallery-item {
                flex-direction: column !important;
            }
            #gallery .gallery-item img {
                width: 100% !important;
                height: auto !important;
                object-fit: contain !important;
            }
            </style>
            """)
            result_json = gr.JSON(label="Result JSON (Bounding Box Data)")

    # Page Loading Logic (keep this)
    def show_image(file, num=1):
        """Load the preview image for the uploaded file.

        Args:
            file: Value of the ``in_file`` component; string methods are
                called on it, so presumably a filesystem path — verify
                against the installed Gradio version.
            num: Page number from the slider (1-based); only used for PDFs.

        Returns:
            Two updates, for ``in_num`` and ``in_img`` in that order.
        """
        if not file:
            # The slider can fire before anything has been uploaded; there is
            # nothing to render, so leave both components untouched.
            return [gr.update(), gr.update()]

        # Compare case-insensitively so "SCAN.PDF" is not sent to PIL.
        if file.lower().endswith(".pdf"):
            count = page_counter(file)
            # Keep the requested page inside this document's range; the
            # slider may still hold a value from a previous, longer PDF.
            page = max(1, min(int(num), count))
            img = get_page_image(file, page, settings.IMAGE_DPI)
            return [
                gr.update(visible=True, maximum=count),
                gr.update(value=img),
            ]

        # Plain image: the page slider is meaningless, so hide it.
        img = get_uploaded_image(file)
        return [
            gr.update(visible=False),
            gr.update(value=img),
        ]

    in_file.upload(
        fn=show_image,
        inputs=[in_file],
        outputs=[in_num, in_img],
    )
    in_num.change(
        fn=show_image,
        inputs=[in_file, in_num],
        outputs=[in_num, in_img],
    )

    # Run Focused Detection
    def run_focused_detection(pil_image):
        """Detect Equations/Figures on ``pil_image`` and build the UI updates.

        Args:
            pil_image: Value of the ``in_img`` component (declared with
                ``type="pil"``); ``None`` when no image is loaded.

        Returns:
            Two updates, for ``result_img`` and ``result_json`` in that order.

        Raises:
            gr.Error: If no image has been loaded yet.
        """
        if pil_image is None:
            # Fail with a readable message instead of an error from inside
            # the predictor.
            raise gr.Error("Please upload a PDF or image before running detection.")

        # update counter
        # Best-effort usage ping: any failure is deliberately ignored, and the
        # timeout keeps an unreachable counter service from stalling the
        # detection request.
        with suppress(Exception):
            requests.get(
                "https://counterapi.com/api/xiaoyao9184.github.com/view/docker-surya",
                timeout=5,
            )

        layout_img, pred = focused_layout_detection(pil_image)
        # Exclude the large segmentation map from the JSON output
        layout_json = pred.model_dump(exclude=["segmentation_map"])

        # Count the filtered results
        num_boxes = len(layout_json.get('bboxes', []))

        return (
            gr.update(label=f"Result image: {num_boxes} Equations/Figures detected", value=[layout_img], rows=[1], height=layout_img.height),
            gr.update(label=f"Result JSON: {num_boxes} Equations/Figures detected", value=layout_json)
        )

    detection_btn.click(
        fn=run_focused_detection,
        inputs=[in_img],
        outputs=[result_img, result_json]
    )
185
+
186
+
187
# Launch the Blocks app only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()