table2html / app.py
github-actions[bot]
Deploy to Space
9e9a60d
import cv2
import streamlit as st
from table2html import Table2HTML
from table2html.source import visualize_boxes, crop_image
import numpy as np
import time
import os
import tempfile
import fitz # PyMuPDF
from PIL import Image
default_configs = {
'table_detection': {
'model_path': 'models/table_detection.pt',
'confidence_threshold': 0.25,
'iou_threshold': 0.7
},
'column_detection': {
'model_path': 'models/column_detection.pt',
'confidence_threshold': 0.25,
'iou_threshold': 0.7,
'task': 'detect'
},
'row_detection': {
'model_path': 'models/row_detection.pt',
'confidence_threshold': 0.25,
'iou_threshold': 0.7,
'task': 'detect'
},
'table_crop_padding': 15
}
thumbnail_columns = 5
def initialize_session_state():
if 'table_detections' not in st.session_state:
st.session_state.table_detections = []
if 'structure_detections' not in st.session_state:
st.session_state.structure_detections = []
if 'cropped_tables' not in st.session_state:
st.session_state.cropped_tables = []
if 'html_tables' not in st.session_state:
st.session_state.html_tables = []
if 'detection_data' not in st.session_state:
st.session_state.detection_data = []
if 'current_image' not in st.session_state:
st.session_state.current_image = None
if 'configs' not in st.session_state:
st.session_state.configs = default_configs
def clear_results():
st.session_state.table_detections = []
st.session_state.structure_detections = []
st.session_state.cropped_tables = []
st.session_state.html_tables = []
def detect_update_results(image, configs):
table2html = Table2HTML(
table_detection_config=configs["table_detection"],
row_detection_config=configs["row_detection"],
column_detection_config=configs["column_detection"]
)
detection_data = table2html(image, configs["table_crop_padding"])
if len(detection_data) == 0:
st.warning("No tables detected on this page.")
return
# Clear previous results
st.session_state.detection_data = detection_data
for data in detection_data:
# Store table detection visualization
table_detection = visualize_boxes(
image.copy(),
[data["table_bbox"]],
color=(0, 0, 255),
thickness=2
)
st.session_state.table_detections.append(table_detection)
# Store cropped table
cropped_table = crop_image(
image, data["table_bbox"], configs["table_crop_padding"])
st.session_state.cropped_tables.append(cropped_table)
# Store structure detection visualization
structure_detection = visualize_boxes(
cropped_table.copy(),
[cell['box'] for cell in data['cells']],
color=(0, 255, 0),
thickness=1
)
st.session_state.structure_detections.append(structure_detection)
# Store HTML
st.session_state.html_tables.append(data["html"])
def inference_one_image(image, configs):
clear_results()
with st.spinner("Processing..."):
start_time = time.time()
try:
# Update process_image call to include all model paths
detect_update_results(image, configs)
# Clean up temporary files if using custom models
for model_type, config in configs.items():
if f"{model_type}_option" in st.session_state and \
st.session_state[f"{model_type}_option"] == "custom":
os.unlink(config["model_path"])
execution_time = time.time() - start_time
st.success(
f"Processing completed in {execution_time:.2f} seconds")
except Exception as e:
st.error(f"Error processing image: {str(e)}")
# Clean up temporary files on error
for model_type, config in configs.items():
if f"{model_type}_option" in st.session_state and \
st.session_state[f"{model_type}_option"] == "custom":
os.unlink(config["model_path"])
def main():
initialize_session_state()
st.set_page_config(layout="wide")
# Add page selection
page = st.sidebar.radio("Select Page", ["Inference", "Configuration"])
if page == "Inference":
st.title("Table Detection and Recognition")
# Image Upload Section
st.subheader("Image Upload")
uploaded_file = st.file_uploader(
"Choose an image or PDF file",
type=['jpg', 'jpeg', 'png', 'pdf']
)
# Get configurations from session state
configs = st.session_state.get('configs', default_configs)
current_image = None
if uploaded_file is not None and all(configs.values()):
if uploaded_file.type == "application/pdf":
# Convert PDF to images
pdf_bytes = uploaded_file.read()
pdf_images = []
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
for page_num in range(doc.page_count):
page = doc[page_num]
pix = page.get_pixmap(dpi=200)
pil_image = Image.frombytes(
"RGB", [pix.width, pix.height], pix.samples)
img_array = np.array(pil_image)
pdf_images.append(img_array)
# Show thumbnails
st.write("Select a page to process:")
cols = st.columns(thumbnail_columns)
for idx, img in enumerate(pdf_images):
with cols[idx % thumbnail_columns]:
st.image(img, width=150, use_container_width=True)
if st.button(f"Process Page {idx+1}"):
current_image = img
st.session_state.current_image = img
inference_one_image(
current_image, configs)
else:
# Handle regular image upload
file_bytes = np.asarray(
bytearray(uploaded_file.read()), dtype=np.uint8)
current_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
st.session_state.current_image = current_image
# Process button
if st.button("Process Image"):
inference_one_image(
current_image, configs)
if len(st.session_state.cropped_tables) > 0:
st.header("Results")
# General Results Section
st.subheader("General Results")
gen_img_col, gen_html_col = st.columns([1, 1])
with gen_img_col:
show_all_detections = st.toggle(
"Show Table Detections",
value=False,
key="show_all_detections"
)
# Display either original image or detection visualization
if show_all_detections and len(st.session_state.detection_data) > 0:
# Create visualization with all table detections
all_tables_viz = visualize_boxes(
st.session_state.current_image.copy(),
[data["table_bbox"]
for data in st.session_state.detection_data],
color=(0, 0, 255),
thickness=2
)
st.image(
all_tables_viz,
caption="All Table Detections",
use_container_width=True
)
else:
st.image(
st.session_state.current_image,
caption="Original Image",
use_container_width=True
)
with gen_html_col:
st.markdown("### All HTML Tables:")
# Combine all HTML tables
all_html = "\n".join(st.session_state.html_tables)
st.markdown(all_html, unsafe_allow_html=True)
# Download all HTML tables
combined_html = "<!DOCTYPE html><html><body>\n" + all_html + "\n</body></html>"
st.download_button(
label="Download All Tables HTML",
data=combined_html,
file_name="all_tables.html",
mime="text/html",
key="download_all_btn"
)
st.divider()
# Detailed Results Section
show_details = st.toggle("Show Detailed Results", value=False)
if show_details:
st.subheader("Detailed Results")
for idx in range(len(st.session_state.cropped_tables)):
st.subheader(f"Table {idx + 1}")
# Visualization controls for each table
control_col1, control_col2 = st.columns([1, 1])
with control_col1:
show_table_detection = st.toggle(
f"Show Table Detection for Table {idx + 1}",
value=False,
key=f"table_detection_{idx}"
)
with control_col2:
show_structure_detection = st.toggle(
f"Show Structure Detection for Table {idx + 1}",
value=False,
key=f"structure_detection_{idx}"
)
# Create columns for each table result
img_col, html_col = st.columns([1, 1])
with img_col:
# Show either the cropped table or visualizations based on toggles
if show_table_detection:
st.image(
st.session_state.table_detections[idx],
caption="Table Detection",
use_container_width=True
)
if show_structure_detection:
st.image(
st.session_state.structure_detections[idx],
caption="Structure Detection",
use_container_width=True
)
if not show_table_detection and not show_structure_detection:
st.image(
st.session_state.cropped_tables[idx],
caption="Cropped Table",
use_container_width=True
)
with html_col:
st.markdown("### HTML Output:")
st.markdown(
st.session_state.html_tables[idx],
unsafe_allow_html=True
)
st.download_button(
label=f"Download Table {idx + 1} HTML",
data=st.session_state.html_tables[idx],
file_name=f"table_{idx + 1}.html",
mime="text/html",
key=f"download_btn_{idx}"
)
st.divider()
else: # Configuration page
st.title("Model Configuration")
# Model selection options
model_types = ["Table Detection", "Column Detection", "Row Detection"]
configs = {} # Store both paths and thresholds
for idx, model_type in enumerate(model_types):
st.markdown(f"### {model_type}")
key_prefix = model_type.lower().replace(" ", "_")
# Model file selection
model_option = st.radio(
f"Choose {model_type} Model",
options=["default", "custom"],
horizontal=True,
key=f"{key_prefix}_option"
)
if model_option == "default":
default_path = f"models/{key_prefix}.pt"
configs[key_prefix] = {"model_path": default_path}
st.info(f"Using default model: {default_path}")
else:
model_upload = st.file_uploader(
f"Choose {model_type} Model File (.pt)",
type=['pt'],
key=f"{key_prefix}_upload"
)
if model_upload:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp_file:
tmp_file.write(model_upload.getvalue())
configs[key_prefix] = {
"model_path": tmp_file.name}
else:
configs[key_prefix] = {"model_path": None}
st.warning(
f"Please upload a {model_type.lower()} model file")
# Add threshold controls
thresh_col1, thresh_col2 = st.columns(2)
with thresh_col1:
conf_threshold = st.slider(
f"{model_type} Confidence Threshold",
min_value=0.0,
max_value=1.0,
value=0.25,
step=0.05,
key=f"{key_prefix}_conf_threshold"
)
with thresh_col2:
iou_threshold = st.slider(
f"{model_type} IOU Threshold",
min_value=0.0,
max_value=1.0,
value=0.7,
step=0.05,
key=f"{key_prefix}_iou_threshold"
)
if configs[key_prefix]["model_path"]:
configs[key_prefix].update({
"confidence_threshold": conf_threshold,
"iou_threshold": iou_threshold
})
# Add task field for row and column detection
if key_prefix in ["column_detection", "row_detection"]:
configs[key_prefix]["task"] = "detect"
st.divider()
# Padding input below the model configurations
table_crop_padding = st.number_input(
"Table Crop Padding",
value=15,
min_value=0,
max_value=100
)
# Save configurations to session state
if st.button("Save Configuration"):
st.session_state.configs = configs
st.success("Configuration saved successfully!")
if __name__ == "__main__":
main()