| | |
| |
|
| | |
| | |
| |
|
| | import os |
| |
|
| | import matplotlib.pyplot as plt |
| | import streamlit as st |
| |
|
# Hide TensorFlow's native INFO and WARNING log lines. This assignment is kept
# ahead of the TensorFlow import below so the setting is already in place when
# the library loads.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import cv2
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front. The flag is applied to every detected GPU: TensorFlow requires
# memory growth to be configured identically across GPUs, so enabling it on
# the first device only breaks multi-GPU hosts. With no GPU the loop is a no-op.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)
| |
|
| | from doctr.io import DocumentFile |
| | from doctr.models import ocr_predictor |
| | from doctr.utils.visualization import visualize_page |
| |
|
# Text detection architectures selectable in the app.
DET_ARCHS = [
    "db_resnet50",
    "db_mobilenet_v3_large",
]

# Text recognition architectures selectable in the app.
RECO_ARCHS = [
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "master",
    "sar_resnet31",
]
| |
|
| |
|
def main():
    """Run the docTR Streamlit demo page.

    The page is laid out in four columns: the input page, the segmentation
    heatmap, the OCR output and the page reconstitution. The sidebar holds
    the file uploader, the page selector, the detection / recognition model
    selectors and the "Analyze page" button. When the button is pressed with
    a document loaded, the selected page is run through the chosen models
    and the results are rendered in the columns, followed by the JSON export
    of the page.
    """
    # Use the full browser width so the four columns have room.
    st.set_page_config(layout="wide")

    # Page header.
    st.title("docTR: Document Text Recognition")
    st.write('\n')
    st.write('Find more info at: https://marketone.co/')
    st.write('\n')
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # `st.beta_columns` was renamed to `st.columns` in later Streamlit
    # releases; prefer the current name and fall back to the beta one so the
    # app runs on both.
    make_columns = getattr(st, "columns", None) or st.beta_columns
    cols = make_columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar: document selection.
    st.sidebar.title("Document selection")
    # NOTE(review): this option presumably only exists on older Streamlit
    # releases - confirm against the pinned Streamlit version.
    st.set_option('deprecation.showfileUploaderEncoding', False)
    uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    if uploaded_file is not None:
        file_bytes = uploaded_file.read()
        # Compare the extension case-insensitively so "SCAN.PDF" is still
        # routed to the PDF loader rather than the image loader.
        if uploaded_file.name.lower().endswith('.pdf'):
            doc = DocumentFile.from_pdf(file_bytes)
        else:
            doc = DocumentFile.from_images(file_bytes)
        # The selector shows 1-based page numbers; convert back to an index.
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        cols[0].image(doc[page_idx])

    # Sidebar: model selection.
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)

    st.sidebar.write('\n')

    if st.sidebar.button("Analyze page"):

        if uploaded_file is None:
            st.sidebar.write("Please upload a document")

        else:
            with st.spinner('Loading model...'):
                predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)

            with st.spinner('Analyzing...'):
                page = doc[page_idx]

                # Run the detection model on its own to get the raw
                # segmentation map, then resize it to the page dimensions
                # (cv2.resize takes the target size as (width, height)).
                processed_batches = predictor.det_predictor.pre_processor([page])
                out = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
                seg_map = out["out_map"]
                seg_map = tf.squeeze(seg_map[0, ...], axis=[2])
                seg_map = cv2.resize(seg_map.numpy(), (page.shape[1], page.shape[0]),
                                     interpolation=cv2.INTER_LINEAR)

                # Show the heatmap, then close the figure: pyplot keeps every
                # figure alive until it is closed, which would otherwise leak
                # one figure per analysis across Streamlit reruns.
                fig, ax = plt.subplots()
                ax.imshow(seg_map)
                ax.axis('off')
                cols[1].pyplot(fig)
                plt.close(fig)

                # Full OCR pass; the page export is computed once and reused
                # for both the visualisation and the JSON dump below.
                out = predictor([page])
                page_export = out.pages[0].export()
                fig = visualize_page(page_export, page, interactive=False)
                cols[2].pyplot(fig)
                plt.close(fig)

                # Synthetic rendering of the page built from the OCR result.
                img = out.pages[0].synthesize()
                cols[3].image(img, clamp=True)

                # Raw results as JSON.
                st.markdown("\nHere are your analysis results in JSON format:")
                st.json(page_export)
| |
|
| |
|
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()