import json

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from doctr.file_utils import is_tf_available
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from doctr.utils.visualization import visualize_page

from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

# Run the models on the first GPU when one is visible to torch, otherwise on CPU.
forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def extract_text_data(ocr_result, page_idx):
    """Extract the recognized text of a single-page OCR result.

    Only ``ocr_result.pages[0]`` is read, so ``ocr_result`` is expected to be the
    output of running the predictor on one page; ``page_idx`` is the index that
    page had in the original document and is stored alongside the text.

    Args:
        ocr_result: docTR OCR output for one page.
        page_idx: index of that page in the source document.

    Returns:
        dict with keys ``"page_idx"`` and ``"text"``. In ``"text"`` every word is
        followed by a space and every line is terminated by ``"\\n "``.
    """
    page = ocr_result.pages[0]
    text = "".join(
        "".join(word.value + " " for word in line.words) + "\n "
        for block in page.blocks
        for line in block.lines
    )
    return {"page_idx": page_idx, "text": text}


def _load_document(uploaded_file):
    """Read a Streamlit upload into a docTR document (a sequence of page images)."""
    data = uploaded_file.read()
    # Lowercase the name so that e.g. "SCAN.PDF" also goes through the PDF reader.
    if uploaded_file.name.lower().endswith(".pdf"):
        return DocumentFile.from_pdf(data)
    return DocumentFile.from_images(data)


def _analyze_page(predictor, page, page_idx, cols):
    """Run the predictor on one page, render columns 1-3, and return its text data."""
    # Segmentation heatmap, resized back to the page's (width, height).
    seg_map = np.squeeze(forward_image(predictor, page, forward_device))
    seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
    fig, ax = plt.subplots()
    ax.imshow(seg_map)
    ax.axis("off")
    cols[1].pyplot(fig)
    # pyplot keeps every figure alive until it is closed; in a long-running
    # Streamlit server that leaks memory on each analyzed page.
    plt.close(fig)

    out = predictor([page])
    result_page = out.pages[0]
    fig = visualize_page(result_page.export(), result_page.page, interactive=False, add_labels=False)
    cols[2].pyplot(fig)
    plt.close(fig)

    cols[3].image(result_page.synthesize(), clamp=True)
    return extract_text_data(out, page_idx)


def main(det_archs, reco_archs):
    """Build the Streamlit layout and run OCR on the uploaded document.

    Args:
        det_archs: text-detection architecture names offered in the sidebar.
        reco_archs: text-recognition architecture names offered in the sidebar.
    """
    st.set_page_config(layout="wide")
    st.title("docTR: Document Text Recognition")
    st.write("\n")
    st.markdown("Hint: click on the top-right corner of an image to enlarge it!")

    # Set the columns: one per visualization stage.
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    st.sidebar.title("Document selection")
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    # Stays None until something is uploaded, so the Analyze handler can check it.
    doc = None
    if uploaded_file is not None:
        doc = _load_document(uploaded_file)
        cols[0].image(doc)

    st.sidebar.title("Model selection")
    st.sidebar.markdown("*Backend*: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # Fixed predictor settings forwarded to load_predictor.
    assume_straight_pages = True
    straighten_pages = False
    bin_thresh = 0.3
    box_thresh = 0.1

    if st.sidebar.button("Analyze document"):
        if doc is None:
            st.sidebar.write("Please upload a document")
            return

        with st.spinner("Loading model..."):
            predictor = load_predictor(
                det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, box_thresh, forward_device
            )

        with st.spinner("Analyzing..."):
            all_pages_text = []  # text data of every page, in document order
            for page_idx, page in enumerate(doc):
                st.write(f"Processing page {page_idx + 1}/{len(doc)}...")
                all_pages_text.append(_analyze_page(predictor, page, page_idx, cols))

        st.markdown("\nHere are your analysis results in JSON format for all pages:")
        st.json(all_pages_text, expanded=False)
        st.markdown("\nDownload here:")
        st.download_button(
            label="Download JSON", data=json.dumps(all_pages_text), file_name="OCR.json", mime="application/json"
        )

        # NOTE: optional persistence to MongoDB, kept disabled:
        # client = pymongo.MongoClient("mongodb://localhost:27017/")
        # db = client["OCRoutputs"]
        # collection = db["jsonDatas"]
        # collection.insert_one({"document": all_pages_text})


if __name__ == "__main__":
    main(DET_ARCHS, RECO_ARCHS)