File size: 14,971 Bytes
9e9a60d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
import cv2
import streamlit as st
from table2html import Table2HTML
from table2html.source import visualize_boxes, crop_image
import numpy as np
import time
import os
import tempfile
import fitz  # PyMuPDF
from PIL import Image


# Settings used until a configuration is saved from the Configuration page.
# Each detector entry is passed as-is to Table2HTML as that detector's config.
default_configs = {
    "table_detection": {
        "model_path": "models/table_detection.pt",
        "confidence_threshold": 0.25,
        "iou_threshold": 0.7,
    },
    "column_detection": {
        "model_path": "models/column_detection.pt",
        "confidence_threshold": 0.25,
        "iou_threshold": 0.7,
        "task": "detect",
    },
    "row_detection": {
        "model_path": "models/row_detection.pt",
        "confidence_threshold": 0.25,
        "iou_threshold": 0.7,
        "task": "detect",
    },
    # Padding value handed to Table2HTML and crop_image when cropping tables.
    "table_crop_padding": 15,
}

# Number of columns in the PDF page-thumbnail grid.
thumbnail_columns = 5


def initialize_session_state():
    """Seed ``st.session_state`` with every key the app reads.

    Only missing keys are created; keys that already exist are left untouched.
    """
    state = st.session_state

    # Factories rather than values, so each list-valued key gets its own list.
    initial_values = {
        'table_detections': list,
        'structure_detections': list,
        'cropped_tables': list,
        'html_tables': list,
        'detection_data': list,
        'current_image': lambda: None,
        'configs': lambda: default_configs,
    }

    for key, make_initial in initial_values.items():
        if key not in state:
            state[key] = make_initial()


def clear_results():
    """Reset every per-run result stored in ``st.session_state``.

    ``detection_data`` is reset together with the derived lists so that a run
    which finds no tables cannot leave bounding boxes from an earlier image
    in session state. ``current_image`` and ``configs`` are not touched.
    """
    st.session_state.table_detections = []
    st.session_state.structure_detections = []
    st.session_state.cropped_tables = []
    st.session_state.html_tables = []
    st.session_state.detection_data = []


def detect_update_results(image, configs):
    """Run Table2HTML on ``image`` and publish the results to session state.

    Args:
        image: The page image to analyse. It is copied with ``.copy()`` before
            drawing, so it is presumably a NumPy image array - TODO confirm.
        configs: Mapping with the keys ``"table_detection"``,
            ``"row_detection"``, ``"column_detection"`` and
            ``"table_crop_padding"``.

    Side effects:
        When nothing is detected, a warning is shown and session state is left
        untouched. Otherwise ``detection_data``, ``table_detections``,
        ``cropped_tables``, ``structure_detections`` and ``html_tables`` in
        ``st.session_state`` are replaced with the results for this image.

    Raises:
        Whatever ``Table2HTML``, ``visualize_boxes`` or ``crop_image`` raise.
        In that case session state is not modified.
    """
    table2html = Table2HTML(
        table_detection_config=configs["table_detection"],
        row_detection_config=configs["row_detection"],
        column_detection_config=configs["column_detection"]
    )
    padding = configs["table_crop_padding"]
    detection_data = table2html(image, padding)

    if len(detection_data) == 0:
        st.warning("No tables detected on this page.")
        return

    # Collect results locally first. The results view indexes these lists in
    # parallel, so they must never be left with different lengths if one of
    # the calls below raises part-way through.
    table_detections = []
    cropped_tables = []
    structure_detections = []
    html_tables = []

    for data in detection_data:
        # Full image with only this table's bounding box drawn on it.
        table_detections.append(
            visualize_boxes(
                image.copy(),
                [data["table_bbox"]],
                color=(0, 0, 255),
                thickness=2
            )
        )

        # The table region cut out of the full image.
        cropped_table = crop_image(image, data["table_bbox"], padding)
        cropped_tables.append(cropped_table)

        # Cropped table with every detected cell box drawn on it.
        structure_detections.append(
            visualize_boxes(
                cropped_table.copy(),
                [cell['box'] for cell in data['cells']],
                color=(0, 255, 0),
                thickness=1
            )
        )

        html_tables.append(data["html"])

    # Publish everything in one step now that every table was processed.
    st.session_state.detection_data = detection_data
    st.session_state.table_detections = table_detections
    st.session_state.cropped_tables = cropped_tables
    st.session_state.structure_detections = structure_detections
    st.session_state.html_tables = html_tables


def inference_one_image(image, configs):
    """Clear old results, run detection on ``image`` and report the outcome.

    Any exception raised during detection is shown with ``st.error`` instead
    of propagating. On success the elapsed time is shown with ``st.success``.
    Temporary custom-model files are removed afterwards in both cases.

    Args:
        image: Image forwarded unchanged to ``detect_update_results``.
        configs: Configuration mapping forwarded to ``detect_update_results``
            and scanned for temporary model files to delete.

    NOTE(review): once a temporary model file is removed, ``configs`` still
    holds its path, so a second run with the same custom model would fail
    until the model is uploaded again - confirm this is intended.
    """
    clear_results()
    with st.spinner("Processing..."):
        start_time = time.perf_counter()
        try:
            detect_update_results(image, configs)
        except Exception as e:
            # Top-level UI boundary: report the failure rather than crash.
            st.error(f"Error processing image: {str(e)}")
        else:
            execution_time = time.perf_counter() - start_time
            st.success(
                f"Processing completed in {execution_time:.2f} seconds")
        finally:
            # One cleanup path for both success and failure.
            _remove_custom_model_files(configs)


def _remove_custom_model_files(configs):
    """Best-effort deletion of model files for detectors set to "custom"."""
    for model_type, config in configs.items():
        if st.session_state.get(f"{model_type}_option") != "custom":
            continue
        # Non-dict entries such as the crop padding carry no model path.
        model_path = config.get("model_path") if isinstance(config, dict) else None
        if not model_path:
            continue
        try:
            os.unlink(model_path)
        except OSError:
            # The file is already gone or cannot be removed; cleanup must not
            # turn a finished run into an error.
            pass


def main():
    """Streamlit entry point: set up state and render the selected page."""
    initialize_session_state()
    st.set_page_config(layout="wide")

    page = st.sidebar.radio("Select Page", ["Inference", "Configuration"])

    if page == "Inference":
        _render_inference_page()
    else:
        _render_configuration_page()


def _models_ready(configs):
    """Return True when all three detector configs have a usable model path."""
    model_keys = ("table_detection", "column_detection", "row_detection")
    for key in model_keys:
        model_config = configs.get(key)
        if not isinstance(model_config, dict) or not model_config.get("model_path"):
            return False
    return True


def _pdf_to_images(pdf_bytes):
    """Render every page of a PDF (given as bytes) to an RGB NumPy array."""
    images = []
    # The context manager closes the document once all pages are rendered.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page_num in range(doc.page_count):
            pdf_page = doc[page_num]
            pix = pdf_page.get_pixmap(dpi=200)
            pil_image = Image.frombytes(
                "RGB", (pix.width, pix.height), pix.samples)
            # np.array copies the pixels, so they outlive the document.
            images.append(np.array(pil_image))
    return images


def _handle_pdf_upload(uploaded_file, configs):
    """Show one thumbnail per PDF page, each with a button that processes it."""
    pdf_images = _pdf_to_images(uploaded_file.getvalue())

    st.write("Select a page to process:")
    cols = st.columns(thumbnail_columns)
    for idx, img in enumerate(pdf_images):
        with cols[idx % thumbnail_columns]:
            st.image(img, width=150, use_container_width=True)
            if st.button(f"Process Page {idx+1}"):
                st.session_state.current_image = img
                inference_one_image(img, configs)


def _handle_image_upload(uploaded_file, configs):
    """Decode an uploaded image and offer a button that processes it."""
    file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
    # NOTE(review): cv2.imdecode yields BGR channel order, while the PDF path
    # yields RGB and st.image assumes RGB by default - confirm which order
    # Table2HTML expects and convert one of the two paths accordingly.
    current_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if current_image is None:
        # imdecode signals an unreadable or corrupt file by returning None.
        st.error("Could not decode the uploaded file as an image.")
        return

    st.session_state.current_image = current_image

    if st.button("Process Image"):
        inference_one_image(current_image, configs)


def _render_inference_page():
    """Render the upload controls and, when available, the results."""
    st.title("Table Detection and Recognition")

    st.subheader("Image Upload")
    uploaded_file = st.file_uploader(
        "Choose an image or PDF file",
        type=['jpg', 'jpeg', 'png', 'pdf']
    )

    configs = st.session_state.get('configs', default_configs)

    if uploaded_file is not None:
        if not _models_ready(configs):
            # A detector set to "custom" without an uploaded file has no path.
            st.warning(
                "One or more detection models are not configured. "
                "Open the Configuration page and select or upload the "
                "missing model(s).")
        elif uploaded_file.type == "application/pdf":
            _handle_pdf_upload(uploaded_file, configs)
        else:
            _handle_image_upload(uploaded_file, configs)

    if len(st.session_state.cropped_tables) > 0:
        _render_results()


def _render_results():
    """Render the combined results and the optional per-table details."""
    state = st.session_state
    st.header("Results")

    # General results: full image on the left, all HTML tables on the right.
    st.subheader("General Results")
    gen_img_col, gen_html_col = st.columns([1, 1])

    with gen_img_col:
        show_all_detections = st.toggle(
            "Show Table Detections",
            value=False,
            key="show_all_detections"
        )

        if show_all_detections and len(state.detection_data) > 0:
            all_tables_viz = visualize_boxes(
                state.current_image.copy(),
                [data["table_bbox"] for data in state.detection_data],
                color=(0, 0, 255),
                thickness=2
            )
            st.image(
                all_tables_viz,
                caption="All Table Detections",
                use_container_width=True
            )
        else:
            st.image(
                state.current_image,
                caption="Original Image",
                use_container_width=True
            )

    with gen_html_col:
        st.markdown("### All HTML Tables:")
        all_html = "\n".join(state.html_tables)
        st.markdown(all_html, unsafe_allow_html=True)

        combined_html = "<!DOCTYPE html><html><body>\n" + all_html + "\n</body></html>"
        st.download_button(
            label="Download All Tables HTML",
            data=combined_html,
            file_name="all_tables.html",
            mime="text/html",
            key="download_all_btn"
        )

    st.divider()

    show_details = st.toggle("Show Detailed Results", value=False)
    if not show_details:
        return

    st.subheader("Detailed Results")
    for idx, cropped_table in enumerate(state.cropped_tables):
        st.subheader(f"Table {idx + 1}")

        # Per-table toggles choosing which visualization to show.
        control_col1, control_col2 = st.columns([1, 1])
        with control_col1:
            show_table_detection = st.toggle(
                f"Show Table Detection for Table {idx + 1}",
                value=False,
                key=f"table_detection_{idx}"
            )
        with control_col2:
            show_structure_detection = st.toggle(
                f"Show Structure Detection for Table {idx + 1}",
                value=False,
                key=f"structure_detection_{idx}"
            )

        img_col, html_col = st.columns([1, 1])

        with img_col:
            if show_table_detection:
                st.image(
                    state.table_detections[idx],
                    caption="Table Detection",
                    use_container_width=True
                )
            if show_structure_detection:
                st.image(
                    state.structure_detections[idx],
                    caption="Structure Detection",
                    use_container_width=True
                )
            # With both toggles off, fall back to the plain cropped table.
            if not show_table_detection and not show_structure_detection:
                st.image(
                    cropped_table,
                    caption="Cropped Table",
                    use_container_width=True
                )

        with html_col:
            st.markdown("### HTML Output:")
            st.markdown(
                state.html_tables[idx],
                unsafe_allow_html=True
            )
            st.download_button(
                label=f"Download Table {idx + 1} HTML",
                data=state.html_tables[idx],
                file_name=f"table_{idx + 1}.html",
                mime="text/html",
                key=f"download_btn_{idx}"
            )

        st.divider()


def _render_configuration_page():
    """Render the model and threshold controls and save them on request."""
    st.title("Model Configuration")

    model_types = ["Table Detection", "Column Detection", "Row Detection"]
    configs = {}  # Model paths, thresholds and the crop padding

    for model_type in model_types:
        st.markdown(f"### {model_type}")
        key_prefix = model_type.lower().replace(" ", "_")

        model_option = st.radio(
            f"Choose {model_type} Model",
            options=["default", "custom"],
            horizontal=True,
            key=f"{key_prefix}_option"
        )

        if model_option == "default":
            default_path = f"models/{key_prefix}.pt"
            configs[key_prefix] = {"model_path": default_path}
            st.info(f"Using default model: {default_path}")
        else:
            model_upload = st.file_uploader(
                f"Choose {model_type} Model File (.pt)",
                type=['pt'],
                key=f"{key_prefix}_upload"
            )
            if model_upload:
                # NOTE(review): this writes a new temporary file every time
                # this branch runs with an upload present - consider reusing
                # one file per upload.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp_file:
                    tmp_file.write(model_upload.getvalue())
                    configs[key_prefix] = {
                        "model_path": tmp_file.name}
            else:
                configs[key_prefix] = {"model_path": None}
                st.warning(
                    f"Please upload a {model_type.lower()} model file")

        thresh_col1, thresh_col2 = st.columns(2)
        with thresh_col1:
            conf_threshold = st.slider(
                f"{model_type} Confidence Threshold",
                min_value=0.0,
                max_value=1.0,
                value=0.25,
                step=0.05,
                key=f"{key_prefix}_conf_threshold"
            )
        with thresh_col2:
            iou_threshold = st.slider(
                f"{model_type} IOU Threshold",
                min_value=0.0,
                max_value=1.0,
                value=0.7,
                step=0.05,
                key=f"{key_prefix}_iou_threshold"
            )

        # Thresholds are only attached when a model path is available.
        if configs[key_prefix]["model_path"]:
            configs[key_prefix].update({
                "confidence_threshold": conf_threshold,
                "iou_threshold": iou_threshold
            })
            if key_prefix in ["column_detection", "row_detection"]:
                configs[key_prefix]["task"] = "detect"

        st.divider()

    table_crop_padding = st.number_input(
        "Table Crop Padding",
        value=15,
        min_value=0,
        max_value=100
    )
    # Inference reads configs["table_crop_padding"], so it must be saved
    # together with the model settings.
    configs["table_crop_padding"] = int(table_crop_padding)

    if st.button("Save Configuration"):
        st.session_state.configs = configs
        st.success("Configuration saved successfully!")


# Run the app only when this file is executed as the main script, not when it
# is imported.
if __name__ == "__main__":
    main()