UI 分离
Browse files
app.py
CHANGED
|
@@ -1,209 +1,42 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
# This program is licensed under the Apache License 2.0.
|
| 4 |
-
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
|
| 5 |
-
|
| 6 |
-
import cv2
|
| 7 |
-
import matplotlib.pyplot as plt
|
| 8 |
-
import numpy as np
|
| 9 |
-
import streamlit as st
|
| 10 |
-
import time
|
| 11 |
-
from doctr.file_utils import is_tf_available
|
| 12 |
from doctr.io import DocumentFile
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
-
def
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
selected_device = st.sidebar.selectbox("计算设备", device_options)
|
| 21 |
-
forward_device = torch.device("cuda:0" if selected_device == "cuda" else "cpu")
|
| 22 |
-
|
| 23 |
-
# Display GPU info if CUDA selected
|
| 24 |
-
st.sidebar.markdown(f"**当前设备**: {forward_device}")
|
| 25 |
-
if selected_device == "cuda":
|
| 26 |
-
st.sidebar.markdown(f"**GPU型号**: {torch.cuda.get_device_name(0)}")
|
| 27 |
-
st.sidebar.markdown(f"**可用显存**: {torch.cuda.get_device_properties(0).total_memory/1024/1024:.0f}MB")
|
| 28 |
-
else:
|
| 29 |
-
st.sidebar.write("当前仅支持CPU")
|
| 30 |
-
forward_device = torch.device("cpu")
|
| 31 |
-
st.sidebar.markdown(f"**当前设备**: {forward_device}")
|
| 32 |
|
| 33 |
-
return forward_device, selected_device
|
| 34 |
-
|
| 35 |
-
def format_time(seconds):
|
| 36 |
-
"""Format seconds into human readable string"""
|
| 37 |
-
return f"{seconds:.2f}秒"
|
| 38 |
-
|
| 39 |
-
if is_tf_available():
|
| 40 |
-
import tensorflow as tf
|
| 41 |
-
from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
|
| 42 |
-
|
| 43 |
-
if any(tf.config.experimental.list_physical_devices("gpu")):
|
| 44 |
-
forward_device = tf.device("/gpu:0")
|
| 45 |
-
else:
|
| 46 |
-
forward_device = tf.device("/cpu:0")
|
| 47 |
-
|
| 48 |
-
else:
|
| 49 |
-
import torch
|
| 50 |
-
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
|
| 51 |
-
|
| 52 |
-
forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 53 |
-
|
| 54 |
-
def main(det_archs, reco_archs):
|
| 55 |
-
"""Build a streamlit layout"""
|
| 56 |
-
# Wide mode
|
| 57 |
-
st.set_page_config(layout="wide")
|
| 58 |
-
|
| 59 |
-
# Designing the interface
|
| 60 |
-
st.title("美宜家文档文本识别DEMO")
|
| 61 |
-
# For newline
|
| 62 |
-
st.write("\n")
|
| 63 |
-
# Instructions
|
| 64 |
-
st.markdown("*提示:单击图像的右上角可以放大!*")
|
| 65 |
-
# Set the columns
|
| 66 |
-
cols = st.columns((1, 1, 1, 1))
|
| 67 |
-
cols[0].subheader("输入页面")
|
| 68 |
-
cols[1].subheader("分割热图")
|
| 69 |
-
cols[2].subheader("OCR 输出")
|
| 70 |
-
cols[3].subheader("页面重构")
|
| 71 |
-
|
| 72 |
-
# Sidebar
|
| 73 |
-
# File selection
|
| 74 |
-
st.sidebar.title("文档选择")
|
| 75 |
-
# Choose your own image
|
| 76 |
-
uploaded_file = st.sidebar.file_uploader("上传文件", type=["pdf", "png", "jpeg", "jpg"])
|
| 77 |
-
if uploaded_file is not None:
|
| 78 |
-
if uploaded_file.name.endswith(".pdf"):
|
| 79 |
-
doc = DocumentFile.from_pdf(uploaded_file.read())
|
| 80 |
-
else:
|
| 81 |
-
doc = DocumentFile.from_images(uploaded_file.read())
|
| 82 |
-
page_idx = st.sidebar.selectbox("页面选择", [idx + 1 for idx in range(len(doc))]) - 1
|
| 83 |
-
page = doc[page_idx]
|
| 84 |
-
cols[0].image(page)
|
| 85 |
-
# Hardware selection
|
| 86 |
-
st.sidebar.title("硬件选择")
|
| 87 |
-
forward_device, selected_device = setup_device()
|
| 88 |
-
# Model selection
|
| 89 |
-
st.sidebar.title("模型选择")
|
| 90 |
-
st.sidebar.markdown("**后端**: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
|
| 91 |
-
det_arch = st.sidebar.selectbox("文本检测模型", det_archs)
|
| 92 |
-
reco_arch = st.sidebar.selectbox("文本识别模型", reco_archs)
|
| 93 |
-
|
| 94 |
-
# For newline
|
| 95 |
-
st.sidebar.write("\n")
|
| 96 |
-
# Only straight pages or possible rotation
|
| 97 |
-
st.sidebar.title("参数")
|
| 98 |
-
assume_straight_pages = st.sidebar.checkbox("假设页面是直的", value=True)
|
| 99 |
-
# Disable page orientation detection
|
| 100 |
-
disable_page_orientation = st.sidebar.checkbox("禁用页面方向检测", value=False)
|
| 101 |
-
# Disable crop orientation detection
|
| 102 |
-
disable_crop_orientation = st.sidebar.checkbox("禁用裁剪方向检测", value=False)
|
| 103 |
-
# Straighten pages
|
| 104 |
-
straighten_pages = st.sidebar.checkbox("矫正页面", value=False)
|
| 105 |
-
# Export as straight boxes
|
| 106 |
-
export_straight_boxes = st.sidebar.checkbox("导出为直边框", value=False)
|
| 107 |
-
st.sidebar.write("\n")
|
| 108 |
-
# Binarization threshold
|
| 109 |
-
bin_thresh = st.sidebar.slider("二值化阈值", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
|
| 110 |
-
st.sidebar.write("\n")
|
| 111 |
-
# Box threshold
|
| 112 |
-
box_thresh = st.sidebar.slider("边框阈值", min_value=0.1, max_value=0.9, value=0.1, step=0.1)
|
| 113 |
-
st.sidebar.write("\n")
|
| 114 |
-
|
| 115 |
if st.sidebar.button("分析页面"):
|
| 116 |
if uploaded_file is None:
|
| 117 |
st.sidebar.write("请上传一个文档")
|
|
|
|
| 118 |
|
| 119 |
-
else
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
seg_time = time.time() - seg_time_start
|
| 142 |
-
# Plot the raw heatmap
|
| 143 |
-
fig, ax = plt.subplots()
|
| 144 |
-
ax.imshow(seg_map)
|
| 145 |
-
ax.axis("off")
|
| 146 |
-
cols[1].pyplot(fig)
|
| 147 |
-
|
| 148 |
-
# Plot OCR output
|
| 149 |
-
ocr_time_start = time.time()
|
| 150 |
-
out = predictor([page])
|
| 151 |
-
fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=True) # 改为True显示标签
|
| 152 |
-
cols[2].pyplot(fig)
|
| 153 |
-
ocr_time = time.time() - ocr_time_start
|
| 154 |
-
|
| 155 |
-
# Page reconsitution under input page
|
| 156 |
-
page_time_start = time.time()
|
| 157 |
-
page_export = out.pages[0].export()
|
| 158 |
-
if assume_straight_pages or (not assume_straight_pages and straighten_pages):
|
| 159 |
-
# 获取合成图像
|
| 160 |
-
img = out.pages[0].synthesize()
|
| 161 |
-
|
| 162 |
-
# 计算所有文本框的边界
|
| 163 |
-
x_min, y_min = float('inf'), float('inf')
|
| 164 |
-
x_max, y_max = 0, 0
|
| 165 |
-
|
| 166 |
-
for block in page_export["blocks"]:
|
| 167 |
-
# 获取每个块的坐标
|
| 168 |
-
coords = np.array(block["geometry"])
|
| 169 |
-
x_min = min(x_min, coords[:, 0].min() * img.shape[1])
|
| 170 |
-
y_min = min(y_min, coords[:, 1].min() * img.shape[0])
|
| 171 |
-
x_max = max(x_max, coords[:, 0].max() * img.shape[1])
|
| 172 |
-
y_max = max(y_max, coords[:, 1].max() * img.shape[0])
|
| 173 |
-
|
| 174 |
-
# 添加边距
|
| 175 |
-
margin = 10
|
| 176 |
-
x_min = max(0, x_min - margin)
|
| 177 |
-
y_min = max(0, y_min - margin)
|
| 178 |
-
x_max = min(img.shape[1], x_max + margin)
|
| 179 |
-
y_max = min(img.shape[0], y_max + margin)
|
| 180 |
-
|
| 181 |
-
# 裁剪图像
|
| 182 |
-
cropped_img = img[int(y_min):int(y_max), int(x_min):int(x_max)]
|
| 183 |
-
|
| 184 |
-
# 显示裁剪后的图像
|
| 185 |
-
cols[3].image(cropped_img, clamp=True)
|
| 186 |
-
|
| 187 |
-
# 添加文本结果显示
|
| 188 |
-
page_time= time.time() - page_time_start
|
| 189 |
-
|
| 190 |
-
total_time = time.time() - seg_time_start
|
| 191 |
-
|
| 192 |
-
cols[0].subheader(f"输入页面 (总耗时: {format_time(total_time)})")
|
| 193 |
-
cols[1].subheader(f"分割热图 (耗时: {format_time(seg_time)})")
|
| 194 |
-
cols[2].subheader(f"OCR输出 (耗时: {format_time(ocr_time)})")
|
| 195 |
-
cols[3].subheader(f"页面重构 (模型加载: {format_time(page_time)})")
|
| 196 |
-
st.markdown("\n### OCR Text Results:")
|
| 197 |
-
for block in page_export["blocks"]:
|
| 198 |
-
for line in block["lines"]:
|
| 199 |
-
for word in line["words"]:
|
| 200 |
-
st.write(f"Text: {word['value']}, Confidence: {word['confidence']:.2f}")
|
| 201 |
-
|
| 202 |
-
# Display JSON
|
| 203 |
-
st.markdown("\nHere are your analysis results in JSON format:")
|
| 204 |
-
#show total_time
|
| 205 |
-
st.json({"total_time": total_time}, expanded=True)
|
| 206 |
-
st.json(page_export, expanded=True) # 改为True展开显示
|
| 207 |
|
| 208 |
if __name__ == "__main__":
|
| 209 |
-
main(
|
|
|
|
| 1 |
+
from model import OCRModel, DET_ARCHS, RECO_ARCHS
|
| 2 |
+
from ui import OCRUI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from doctr.io import DocumentFile
|
| 4 |
+
import time
|
| 5 |
+
import streamlit as st
|
| 6 |
|
| 7 |
+
def main():
|
| 8 |
+
ui = OCRUI()
|
| 9 |
+
model = OCRModel()
|
| 10 |
+
|
| 11 |
+
uploaded_file, params = ui.setup_sidebar(DET_ARCHS, RECO_ARCHS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
if st.sidebar.button("分析页面"):
|
| 14 |
if uploaded_file is None:
|
| 15 |
st.sidebar.write("请上传一个文档")
|
| 16 |
+
return
|
| 17 |
|
| 18 |
+
doc = DocumentFile.from_pdf(uploaded_file.read()) if uploaded_file.name.endswith(".pdf") else DocumentFile.from_images(uploaded_file.read())
|
| 19 |
+
page_idx = st.sidebar.selectbox("页面选择", [idx + 1 for idx in range(len(doc))]) - 1
|
| 20 |
+
page = doc[page_idx]
|
| 21 |
+
|
| 22 |
+
# Process page
|
| 23 |
+
start_time = time.time()
|
| 24 |
+
model.load_model(**params)
|
| 25 |
+
seg_map, out = model.process_page(page)
|
| 26 |
+
|
| 27 |
+
# Display results
|
| 28 |
+
ui.cols[0].image(page)
|
| 29 |
+
fig, ax = plt.subplots()
|
| 30 |
+
ax.imshow(seg_map)
|
| 31 |
+
ax.axis("off")
|
| 32 |
+
ui.cols[1].pyplot(fig)
|
| 33 |
+
|
| 34 |
+
fig = visualize_page(out.pages[0].export(), out.pages[0].page, interactive=False, add_labels=True)
|
| 35 |
+
ui.cols[2].pyplot(fig)
|
| 36 |
+
|
| 37 |
+
# Display processing time and results
|
| 38 |
+
total_time = time.time() - start_time
|
| 39 |
+
st.json({"total_time": total_time, "results": out.pages[0].export()})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
if __name__ == "__main__":
|
| 42 |
+
main()
|
model.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from doctr.file_utils import is_tf_available
|
| 3 |
+
import numpy as np
|
| 4 |
+
import cv2
|
| 5 |
+
|
| 6 |
+
if is_tf_available():
|
| 7 |
+
import tensorflow as tf
|
| 8 |
+
from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
|
| 9 |
+
else:
|
| 10 |
+
import torch
|
| 11 |
+
from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
|
| 12 |
+
|
| 13 |
+
class OCRModel:
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.predictor = None
|
| 16 |
+
self.device = self._setup_device()
|
| 17 |
+
|
| 18 |
+
def _setup_device(self):
|
| 19 |
+
if is_tf_available():
|
| 20 |
+
if any(tf.config.experimental.list_physical_devices("gpu")):
|
| 21 |
+
return tf.device("/gpu:0")
|
| 22 |
+
return tf.device("/cpu:0")
|
| 23 |
+
else:
|
| 24 |
+
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
|
| 26 |
+
def load_model(self, det_arch, reco_arch, **kwargs):
|
| 27 |
+
self.predictor = load_predictor(
|
| 28 |
+
det_arch=det_arch,
|
| 29 |
+
reco_arch=reco_arch,
|
| 30 |
+
device=self.device,
|
| 31 |
+
**kwargs
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def process_page(self, page):
|
| 35 |
+
seg_map = forward_image(self.predictor, page, self.device)
|
| 36 |
+
seg_map = np.squeeze(seg_map)
|
| 37 |
+
seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
|
| 38 |
+
out = self.predictor([page])
|
| 39 |
+
return seg_map, out
|
ui.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
from doctr.utils.visualization import visualize_page
|
| 4 |
+
|
| 5 |
+
class OCRUI:
|
| 6 |
+
def __init__(self):
|
| 7 |
+
self.setup_page_config()
|
| 8 |
+
self.cols = self.create_layout()
|
| 9 |
+
|
| 10 |
+
def setup_page_config(self):
|
| 11 |
+
st.set_page_config(layout="wide")
|
| 12 |
+
st.title("美宜家文档文本识别DEMO")
|
| 13 |
+
st.write("\n")
|
| 14 |
+
st.markdown("*提示:单击图像的右上角可以放大!*")
|
| 15 |
+
|
| 16 |
+
def create_layout(self):
|
| 17 |
+
cols = st.columns((1, 1, 1, 1))
|
| 18 |
+
cols[0].subheader("输入页面")
|
| 19 |
+
cols[1].subheader("分割热图")
|
| 20 |
+
cols[2].subheader("OCR 输出")
|
| 21 |
+
cols[3].subheader("页面重构")
|
| 22 |
+
return cols
|
| 23 |
+
|
| 24 |
+
def setup_sidebar(self, det_archs, reco_archs):
|
| 25 |
+
st.sidebar.title("文档选择")
|
| 26 |
+
uploaded_file = st.sidebar.file_uploader("上传文件", type=["pdf", "png", "jpeg", "jpg"])
|
| 27 |
+
|
| 28 |
+
params = {
|
| 29 |
+
"assume_straight_pages": st.sidebar.checkbox("假设页面是直的", value=True),
|
| 30 |
+
"disable_page_orientation": st.sidebar.checkbox("禁用页面方向检测", value=False),
|
| 31 |
+
"disable_crop_orientation": st.sidebar.checkbox("禁用裁剪方向检测", value=False),
|
| 32 |
+
"straighten_pages": st.sidebar.checkbox("矫正页面", value=False),
|
| 33 |
+
"export_straight_boxes": st.sidebar.checkbox("导出为直边框", value=False),
|
| 34 |
+
"bin_thresh": st.sidebar.slider("二值化阈值", 0.1, 0.9, 0.3, 0.1),
|
| 35 |
+
"box_thresh": st.sidebar.slider("边框阈值", 0.1, 0.9, 0.1, 0.1),
|
| 36 |
+
"det_arch": st.sidebar.selectbox("文本检测模型", det_archs),
|
| 37 |
+
"reco_arch": st.sidebar.selectbox("文本识别模型", reco_archs)
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
return uploaded_file, params
|