"""Streamlit front end for table detection and HTML extraction.

The app has two pages selected from the sidebar:

* **Inference** - upload an image or a PDF, run ``Table2HTML`` on it and
  browse / download the extracted HTML tables.
* **Configuration** - choose default or custom model weights, detection
  thresholds and the table crop padding.

All images handled by the app are kept in BGR channel order (the order
``cv2.imdecode`` produces) and are displayed with ``channels="BGR"``.
"""
import os
import tempfile
import time

import cv2
import fitz  # PyMuPDF
import numpy as np
import streamlit as st
from PIL import Image
from table2html import Table2HTML
from table2html.source import visualize_boxes, crop_image

default_configs = {
    'table_detection': {
        'model_path': 'models/table_detection.pt',
        'confidence_threshold': 0.25,
        'iou_threshold': 0.7
    },
    'column_detection': {
        'model_path': 'models/column_detection.pt',
        'confidence_threshold': 0.25,
        'iou_threshold': 0.7,
        'task': 'detect'
    },
    'row_detection': {
        'model_path': 'models/row_detection.pt',
        'confidence_threshold': 0.25,
        'iou_threshold': 0.7,
        'task': 'detect'
    },
    'table_crop_padding': 15
}

# Number of thumbnail columns used when previewing PDF pages.
thumbnail_columns = 5

# Keys of ``configs`` that hold a per-model configuration dict.
MODEL_CONFIG_KEYS = ('table_detection', 'column_detection', 'row_detection')


def initialize_session_state():
    """Create every session-state entry the app relies on, if missing."""
    # Factories (not shared instances) so each list entry is a fresh object.
    default_factories = {
        'table_detections': list,
        'structure_detections': list,
        'cropped_tables': list,
        'html_tables': list,
        'detection_data': list,
        'current_image': lambda: None,
        'configs': lambda: default_configs,
    }
    for key, factory in default_factories.items():
        if key not in st.session_state:
            st.session_state[key] = factory()


def clear_results():
    """Reset all per-run detection results stored in session state."""
    st.session_state.table_detections = []
    st.session_state.structure_detections = []
    st.session_state.cropped_tables = []
    st.session_state.html_tables = []
    # Also drop the raw detections so they never outlive the derived results.
    st.session_state.detection_data = []


def _configs_ready(configs):
    """Return True when every model config has a usable ``model_path``."""
    return all(
        isinstance(configs.get(key), dict)
        and bool(configs[key].get("model_path"))
        for key in MODEL_CONFIG_KEYS
    )


def detect_update_results(image, configs):
    """Run Table2HTML on ``image`` and store the results in session state.

    Args:
        image: Page image passed to ``Table2HTML`` and to the visualisation
            helpers.
        configs: Mapping with the ``table_detection``, ``row_detection`` and
            ``column_detection`` model configs and, optionally,
            ``table_crop_padding`` (falls back to the default padding).
    """
    # Configurations saved without a padding entry fall back to the default.
    padding = configs.get(
        "table_crop_padding", default_configs["table_crop_padding"])

    table2html = Table2HTML(
        table_detection_config=configs["table_detection"],
        row_detection_config=configs["row_detection"],
        column_detection_config=configs["column_detection"]
    )
    detection_data = table2html(image, padding)
    st.session_state.detection_data = detection_data

    if len(detection_data) == 0:
        st.warning("No tables detected on this page.")
        return

    for data in detection_data:
        # Full-page view with this table's bounding box drawn on it.
        table_detection = visualize_boxes(
            image.copy(),
            [data["table_bbox"]],
            color=(0, 0, 255),
            thickness=2
        )
        st.session_state.table_detections.append(table_detection)

        # Cropped table region (with padding).
        cropped_table = crop_image(image, data["table_bbox"], padding)
        st.session_state.cropped_tables.append(cropped_table)

        # Cropped table with the detected cell boxes drawn on it.
        structure_detection = visualize_boxes(
            cropped_table.copy(),
            [cell['box'] for cell in data['cells']],
            color=(0, 255, 0),
            thickness=1
        )
        st.session_state.structure_detections.append(structure_detection)

        st.session_state.html_tables.append(data["html"])


def _remove_custom_model_files(configs):
    """Best-effort removal of temporary files backing custom models.

    A model file is removed only when the matching ``<model>_option`` entry
    in session state is ``"custom"``.
    """
    # NOTE(review): once removed, the saved configuration still points at the
    # deleted path, so a custom configuration presumably has to be saved again
    # before the next run - confirm this lifecycle is intended.
    for model_type, config in configs.items():
        if not isinstance(config, dict):
            continue  # e.g. the integer ``table_crop_padding`` entry
        if st.session_state.get(f"{model_type}_option") != "custom":
            continue
        model_path = config.get("model_path")
        if not model_path:
            continue
        try:
            os.unlink(model_path)
        except FileNotFoundError:
            pass  # already removed by an earlier run
        except OSError as e:
            st.warning(f"Could not remove temporary model file: {e}")


def inference_one_image(image, configs):
    """Clear old results, run detection on ``image`` and report the outcome.

    Errors raised by the detection pipeline are shown to the user instead of
    aborting the script run. Temporary custom-model files are cleaned up
    whether or not detection succeeded.
    """
    clear_results()
    with st.spinner("Processing..."):
        start_time = time.perf_counter()
        try:
            detect_update_results(image, configs)
        except Exception as e:  # UI boundary: surface any pipeline failure
            st.error(f"Error processing image: {str(e)}")
        else:
            execution_time = time.perf_counter() - start_time
            st.success(
                f"Processing completed in {execution_time:.2f} seconds")
        finally:
            _remove_custom_model_files(configs)


@st.cache_data(show_spinner=False)
def _render_pdf_pages(pdf_bytes, dpi=200):
    """Rasterise every page of a PDF into a list of BGR ``numpy`` arrays.

    Cached on the PDF bytes so widget interactions do not re-render the
    whole document on every script rerun.
    """
    pages = []
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        for page_num in range(doc.page_count):
            pix = doc[page_num].get_pixmap(dpi=dpi)
            pil_image = Image.frombytes(
                "RGB", [pix.width, pix.height], pix.samples)
            # Convert to BGR so PDF pages match cv2-decoded uploads.
            pages.append(
                cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR))
    return pages


def _handle_pdf_upload(uploaded_file, configs):
    """Show page thumbnails for a PDF and process the page the user picks."""
    pdf_images = _render_pdf_pages(uploaded_file.getvalue())

    st.write("Select a page to process:")
    cols = st.columns(thumbnail_columns)
    for idx, img in enumerate(pdf_images):
        with cols[idx % thumbnail_columns]:
            st.image(img, width=150, use_container_width=True,
                     channels="BGR")
            if st.button(f"Process Page {idx+1}"):
                st.session_state.current_image = img
                inference_one_image(img, configs)


def _handle_image_upload(uploaded_file, configs):
    """Decode an uploaded image and process it when the user asks to."""
    file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
    current_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if current_image is None:
        st.error("Could not decode the uploaded file as an image.")
        return

    if st.button("Process Image"):
        # Remember the image only once it is processed, so the results shown
        # below always belong to the image they are drawn on.
        st.session_state.current_image = current_image
        inference_one_image(current_image, configs)


def _render_general_results():
    """Render the page-level image view and the combined HTML output."""
    st.subheader("General Results")
    gen_img_col, gen_html_col = st.columns([1, 1])

    with gen_img_col:
        show_all_detections = st.toggle(
            "Show Table Detections",
            value=False,
            key="show_all_detections"
        )
        if show_all_detections and len(st.session_state.detection_data) > 0:
            all_tables_viz = visualize_boxes(
                st.session_state.current_image.copy(),
                [data["table_bbox"]
                 for data in st.session_state.detection_data],
                color=(0, 0, 255),
                thickness=2
            )
            st.image(
                all_tables_viz,
                caption="All Table Detections",
                use_container_width=True,
                channels="BGR"
            )
        else:
            st.image(
                st.session_state.current_image,
                caption="Original Image",
                use_container_width=True,
                channels="BGR"
            )

    with gen_html_col:
        st.markdown("### All HTML Tables:")
        all_html = "\n".join(st.session_state.html_tables)
        st.markdown(all_html, unsafe_allow_html=True)

        # Wrap the fragments in a minimal document so the download opens
        # as a standalone HTML page.
        combined_html = "<html><body>\n" + all_html + "\n</body></html>"
        st.download_button(
            label="Download All Tables HTML",
            data=combined_html,
            file_name="all_tables.html",
            mime="text/html",
            key="download_all_btn"
        )


def _render_table_details(idx):
    """Render the toggles, images and HTML output for table number ``idx``."""
    st.subheader(f"Table {idx + 1}")

    control_col1, control_col2 = st.columns([1, 1])
    with control_col1:
        show_table_detection = st.toggle(
            f"Show Table Detection for Table {idx + 1}",
            value=False,
            key=f"table_detection_{idx}"
        )
    with control_col2:
        show_structure_detection = st.toggle(
            f"Show Structure Detection for Table {idx + 1}",
            value=False,
            key=f"structure_detection_{idx}"
        )

    img_col, html_col = st.columns([1, 1])
    with img_col:
        if show_table_detection:
            st.image(
                st.session_state.table_detections[idx],
                caption="Table Detection",
                use_container_width=True,
                channels="BGR"
            )
        if show_structure_detection:
            st.image(
                st.session_state.structure_detections[idx],
                caption="Structure Detection",
                use_container_width=True,
                channels="BGR"
            )
        # With both toggles off, fall back to the plain cropped table.
        if not show_table_detection and not show_structure_detection:
            st.image(
                st.session_state.cropped_tables[idx],
                caption="Cropped Table",
                use_container_width=True,
                channels="BGR"
            )

    with html_col:
        st.markdown("### HTML Output:")
        st.markdown(
            st.session_state.html_tables[idx],
            unsafe_allow_html=True
        )
        st.download_button(
            label=f"Download Table {idx + 1} HTML",
            data=st.session_state.html_tables[idx],
            file_name=f"table_{idx + 1}.html",
            mime="text/html",
            key=f"download_btn_{idx}"
        )

    st.divider()


def _render_results():
    """Render the results stored in session state by the last run."""
    st.header("Results")
    _render_general_results()
    st.divider()

    show_details = st.toggle("Show Detailed Results", value=False)
    if show_details:
        st.subheader("Detailed Results")
        for idx, _ in enumerate(st.session_state.cropped_tables):
            _render_table_details(idx)


def _render_inference_page():
    """Render the upload widgets, trigger inference and show the results."""
    st.title("Table Detection and Recognition")

    st.subheader("Image Upload")
    uploaded_file = st.file_uploader(
        "Choose an image or PDF file",
        type=['jpg', 'jpeg', 'png', 'pdf']
    )

    configs = st.session_state.get('configs', default_configs)

    if uploaded_file is not None:
        if not _configs_ready(configs):
            st.warning(
                "The model configuration is incomplete. Select or upload "
                "every model on the Configuration page and save it.")
        elif uploaded_file.type == "application/pdf":
            _handle_pdf_upload(uploaded_file, configs)
        else:
            _handle_image_upload(uploaded_file, configs)

    if len(st.session_state.cropped_tables) > 0:
        _render_results()


def _render_configuration_page():
    """Render the model / threshold / padding form and save it on request."""
    st.title("Model Configuration")

    model_types = ["Table Detection", "Column Detection", "Row Detection"]
    configs = {}
    # Custom uploads are written to disk only when the configuration is
    # saved; writing them while rendering would create a new temporary file
    # on every script rerun.
    pending_uploads = {}

    for model_type in model_types:
        st.markdown(f"### {model_type}")
        key_prefix = model_type.lower().replace(" ", "_")

        model_option = st.radio(
            f"Choose {model_type} Model",
            options=["default", "custom"],
            horizontal=True,
            key=f"{key_prefix}_option"
        )

        has_model = True
        if model_option == "default":
            default_path = f"models/{key_prefix}.pt"
            configs[key_prefix] = {"model_path": default_path}
            st.info(f"Using default model: {default_path}")
        else:
            model_upload = st.file_uploader(
                f"Choose {model_type} Model File (.pt)",
                type=['pt'],
                key=f"{key_prefix}_upload"
            )
            configs[key_prefix] = {"model_path": None}
            if model_upload is not None:
                pending_uploads[key_prefix] = model_upload
            else:
                has_model = False
                st.warning(
                    f"Please upload a {model_type.lower()} model file")

        thresh_col1, thresh_col2 = st.columns(2)
        with thresh_col1:
            conf_threshold = st.slider(
                f"{model_type} Confidence Threshold",
                min_value=0.0,
                max_value=1.0,
                value=0.25,
                step=0.05,
                key=f"{key_prefix}_conf_threshold"
            )
        with thresh_col2:
            iou_threshold = st.slider(
                f"{model_type} IOU Threshold",
                min_value=0.0,
                max_value=1.0,
                value=0.7,
                step=0.05,
                key=f"{key_prefix}_iou_threshold"
            )

        if has_model:
            configs[key_prefix].update({
                "confidence_threshold": conf_threshold,
                "iou_threshold": iou_threshold
            })
            # Row and column detectors additionally carry a task field.
            if key_prefix in ["column_detection", "row_detection"]:
                configs[key_prefix]["task"] = "detect"

        st.divider()

    table_crop_padding = st.number_input(
        "Table Crop Padding",
        value=15,
        min_value=0,
        max_value=100
    )
    # The inference page reads the padding from the saved configuration.
    configs["table_crop_padding"] = int(table_crop_padding)

    if st.button("Save Configuration"):
        for key_prefix, model_upload in pending_uploads.items():
            with tempfile.NamedTemporaryFile(
                    delete=False, suffix='.pt') as tmp_file:
                tmp_file.write(model_upload.getvalue())
            configs[key_prefix]["model_path"] = tmp_file.name

        st.session_state.configs = configs
        st.success("Configuration saved successfully!")
        if not _configs_ready(configs):
            st.warning(
                "At least one model file is missing; inference stays "
                "disabled until every model is selected or uploaded.")


def main():
    """Entry point: configure the page and dispatch to the selected view."""
    st.set_page_config(layout="wide")
    initialize_session_state()

    page = st.sidebar.radio("Select Page", ["Inference", "Configuration"])
    if page == "Inference":
        _render_inference_page()
    else:
        _render_configuration_page()


if __name__ == "__main__":
    main()