File size: 4,481 Bytes
0e17e4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8590bc7
 
 
 
 
 
 
906a99a
8590bc7
0e17e4e
8590bc7
 
 
0e17e4e
 
 
8590bc7
 
0e17e4e
 
 
 
 
8590bc7
0e17e4e
 
 
 
 
 
 
 
 
 
 
8590bc7
0e17e4e
 
8590bc7
0e17e4e
 
 
 
 
8590bc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e17e4e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import json

from doctr.file_utils import is_tf_available
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page
from doctr.models import ocr_predictor

import torch
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

# Device used for model loading and inference: the first CUDA GPU when torch
# reports CUDA as available, otherwise the CPU.
forward_device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

def extract_text_data(ocr_result, page_idx):
    """Flatten the first page of an OCR result into a plain-text record.

    Only ``ocr_result.pages[0]`` is read; ``page_idx`` does not select a page,
    it is just stored in the returned record so the caller can label results
    that were produced one page at a time.

    Args:
        ocr_result: object exposing ``pages``; each page has ``blocks``, each
            block has ``lines``, each line has ``words`` and each word a
            string ``value``.
        page_idx: index to record under the ``"page_idx"`` key.

    Returns:
        ``{"page_idx": page_idx, "text": ...}`` where every word is followed
        by a single space and every line is terminated by ``"\\n "``. The text
        is empty when the result has no pages or no lines.
    """
    pages = ocr_result.pages
    if not pages:
        # Nothing was recognised: return an empty record instead of raising.
        return {"page_idx": page_idx, "text": ""}

    # Collect fragments and join once rather than growing a string with "+=".
    fragments = []
    for block in pages[0].blocks:
        for line in block.lines:
            fragments.extend(word.value + " " for word in line.words)
            # Line separator kept exactly as before (newline plus a space).
            fragments.append("\n ")
    return {"page_idx": page_idx, "text": "".join(fragments)}

def _load_document(uploaded_file):
    """Decode a Streamlit upload into the page sequence built by DocumentFile."""
    payload = uploaded_file.read()
    # Compare case-insensitively so a file named "SCAN.PDF" also goes to the
    # PDF reader instead of being handed to the image decoder.
    if uploaded_file.name.lower().endswith(".pdf"):
        return DocumentFile.from_pdf(payload)
    return DocumentFile.from_images(payload)


def main(det_archs, reco_archs):
    """Build the Streamlit layout and run OCR on the uploaded document.

    The page shows four columns (input page, segmentation heatmap, OCR output,
    page reconstitution). Once a file is uploaded and "Analyze document" is
    clicked, every page is run through the predictor, the per-page
    visualisations are drawn, and the extracted text of all pages is shown as
    JSON and offered for download as ``OCR.json``.

    Args:
        det_archs: options offered in the "Text detection model" select box.
        reco_archs: options offered in the "Text recognition model" select box.
    """
    st.set_page_config(layout="wide")

    st.title("docTR: Document Text Recognition")
    st.write("\n")
    st.markdown("Hint: click on the top-right corner of an image to enlarge it!")

    # Set the columns
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    st.sidebar.title("Document selection")
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    # Nothing else to render until the user has supplied a document.
    if uploaded_file is None:
        return

    doc = _load_document(uploaded_file)
    cols[0].image(doc)

    st.sidebar.title("Model selection")
    # NOTE(review): this label follows is_tf_available(), but the predictor is
    # always loaded through backend.pytorch -- confirm the label cannot say
    # "TensorFlow" while PyTorch is actually in use.
    st.sidebar.markdown("*Backend*: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # Fixed predictor settings (not exposed in the UI).
    assume_straight_pages = True
    straighten_pages = False
    bin_thresh = 0.3
    box_thresh = 0.1

    # The button is only True on the rerun triggered by the click.
    if not st.sidebar.button("Analyze document"):
        return

    with st.spinner("Loading model..."):
        predictor = load_predictor(
            det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device
        )

    with st.spinner("Analyzing..."):
        all_pages_text = []  # One {"page_idx", "text"} record per page
        page_count = len(doc)
        for page_idx, page in enumerate(doc):
            st.write(f"Processing page {page_idx + 1}/{page_count}...")

            # Segmentation heatmap, resized to the page's (width, height).
            seg_map = forward_image(predictor, page, forward_device)
            seg_map = np.squeeze(seg_map)
            seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)

            fig, ax = plt.subplots()
            ax.imshow(seg_map)
            ax.axis("off")
            cols[1].pyplot(fig)
            # Release the figure once rendered; pyplot keeps every open figure
            # alive otherwise, which grows memory with each page and rerun.
            plt.close(fig)

            # The predictor is fed a single page, so the result has one page.
            out = predictor([page])
            ocr_page = out.pages[0]
            page_export = ocr_page.export()  # exported once and reused below
            fig = visualize_page(page_export, ocr_page.page, interactive=False, add_labels=False)
            cols[2].pyplot(fig)
            plt.close(fig)

            cols[3].image(ocr_page.synthesize(), clamp=True)

            all_pages_text.append(extract_text_data(out, page_idx))

        st.markdown("\nHere are your analysis results in JSON format for all pages:")
        st.json(all_pages_text, expanded=False)

        st.markdown("\nDownload here:")
        st.download_button(
            label="Download JSON",
            data=json.dumps(all_pages_text),
            file_name="OCR.json",
            mime="application/json",
        )

        # NOTE: persisting `all_pages_text` to a database (a MongoDB insert was
        # sketched here previously) is not implemented.

# Script entry point: build the app with the detection and recognition
# architecture lists imported from backend.pytorch.
if __name__ == "__main__":
    main(DET_ARCHS, RECO_ARCHS)