sonygod committed on
Commit
27d9242
·
1 Parent(s): 2a99e3c
Files changed (3) hide show
  1. app.py +33 -200
  2. model.py +39 -0
  3. ui.py +40 -0
app.py CHANGED
@@ -1,209 +1,42 @@
1
- # Copyright (C) 2021-2024, Mindee.
2
-
3
- # This program is licensed under the Apache License 2.0.
4
- # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
5
-
6
- import cv2
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
- import streamlit as st
10
- import time
11
- from doctr.file_utils import is_tf_available
12
  from doctr.io import DocumentFile
13
- from doctr.utils.visualization import visualize_page
 
14
 
15
def setup_device():
    """Choose the inference device and report it in the Streamlit sidebar.

    With the TensorFlow backend the device is chosen automatically, the same
    way the module-level backend selection does it. With the PyTorch backend
    the user may pick between CUDA and CPU when CUDA is available.

    Returns:
        tuple: ``(forward_device, selected_device)`` where ``forward_device``
        is the backend device object and ``selected_device`` is ``"cuda"`` or
        ``"cpu"``.
    """
    if is_tf_available():
        # torch is only imported at module level for the PyTorch backend, so
        # the TensorFlow backend must not reference it here.
        import tensorflow as tf

        # "GPU" (upper case) is the documented TensorFlow device type string.
        has_gpu = any(tf.config.experimental.list_physical_devices("GPU"))
        device_name = "/gpu:0" if has_gpu else "/cpu:0"
        st.sidebar.markdown(f"**当前设备**: {device_name}")
        return tf.device(device_name), ("cuda" if has_gpu else "cpu")

    # Local import keeps this function usable regardless of where the
    # module-level backend import sits relative to the call.
    import torch

    if not torch.cuda.is_available():
        st.sidebar.write("当前仅支持CPU")
        forward_device = torch.device("cpu")
        st.sidebar.markdown(f"**当前设备**: {forward_device}")
        return forward_device, "cpu"

    selected_device = st.sidebar.selectbox("计算设备", ["cuda", "cpu"])
    forward_device = torch.device("cuda:0" if selected_device == "cuda" else "cpu")

    # Show GPU details only when CUDA was actually selected.
    st.sidebar.markdown(f"**当前设备**: {forward_device}")
    if selected_device == "cuda":
        st.sidebar.markdown(f"**GPU型号**: {torch.cuda.get_device_name(0)}")
        st.sidebar.markdown(
            f"**可用显存**: {torch.cuda.get_device_properties(0).total_memory/1024/1024:.0f}MB"
        )

    return forward_device, selected_device
34
-
35
def format_time(seconds):
    """Render a duration in seconds with two decimals and the unit suffix."""
    rounded = format(seconds, ".2f")
    return rounded + "秒"
38
-
39
# Pick the deep-learning backend once at import time and derive the default
# device used for inference.
if is_tf_available():
    import tensorflow as tf
    from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

    # "GPU" (upper case) is the documented TensorFlow device type string; the
    # lower-case spelling is not guaranteed to match any physical device.
    if any(tf.config.experimental.list_physical_devices("GPU")):
        forward_device = tf.device("/gpu:0")
    else:
        forward_device = tf.device("/cpu:0")

else:
    import torch
    from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

    forward_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
53
-
54
def _crop_to_text(img, blocks, margin=10):
    """Crop ``img`` to the region spanned by ``blocks`` plus a pixel margin.

    Block geometries are read as relative (x, y) coordinates and scaled by the
    image size. When ``blocks`` is empty the image is returned unchanged.
    """
    if not blocks:
        # Without any block the bounds would stay infinite and int(inf)
        # raises OverflowError when slicing.
        return img

    height, width = img.shape[0], img.shape[1]
    x_min, y_min = float("inf"), float("inf")
    x_max, y_max = 0, 0
    for block in blocks:
        coords = np.array(block["geometry"])
        x_min = min(x_min, coords[:, 0].min() * width)
        y_min = min(y_min, coords[:, 1].min() * height)
        x_max = max(x_max, coords[:, 0].max() * width)
        y_max = max(y_max, coords[:, 1].max() * height)

    # Pad the box, clamped to the image bounds.
    x_min = max(0, x_min - margin)
    y_min = max(0, y_min - margin)
    x_max = min(width, x_max + margin)
    y_max = min(height, y_max + margin)
    return img[int(y_min):int(y_max), int(x_min):int(x_max)]


def main(det_archs, reco_archs):
    """Build the Streamlit layout and run OCR on the selected page.

    Args:
        det_archs: choices offered for the text detection model.
        reco_archs: choices offered for the text recognition model.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Page header and the four result columns
    st.title("美宜家文档文本识别DEMO")
    st.write("\n")
    st.markdown("*提示:单击图像的右上角可以放大!*")
    cols = st.columns((1, 1, 1, 1))
    cols[0].subheader("输入页面")
    cols[1].subheader("分割热图")
    cols[2].subheader("OCR 输出")
    cols[3].subheader("页面重构")

    # Sidebar: document selection
    st.sidebar.title("文档选择")
    uploaded_file = st.sidebar.file_uploader("上传文件", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        # Compare case-insensitively so "FILE.PDF" is not parsed as an image.
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        page_idx = st.sidebar.selectbox("页面选择", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Sidebar: hardware selection
    st.sidebar.title("硬件选择")
    forward_device, selected_device = setup_device()

    # Sidebar: model selection
    st.sidebar.title("模型选择")
    st.sidebar.markdown("**后端**: " + ("TensorFlow" if is_tf_available() else "PyTorch"))
    det_arch = st.sidebar.selectbox("文本检测模型", det_archs)
    reco_arch = st.sidebar.selectbox("文本识别模型", reco_archs)

    # Sidebar: predictor parameters
    st.sidebar.write("\n")
    st.sidebar.title("参数")
    assume_straight_pages = st.sidebar.checkbox("假设页面是直的", value=True)
    disable_page_orientation = st.sidebar.checkbox("禁用页面方向检测", value=False)
    disable_crop_orientation = st.sidebar.checkbox("禁用裁剪方向检测", value=False)
    straighten_pages = st.sidebar.checkbox("矫正页面", value=False)
    export_straight_boxes = st.sidebar.checkbox("导出为直边框", value=False)
    st.sidebar.write("\n")
    bin_thresh = st.sidebar.slider("二值化阈值", min_value=0.1, max_value=0.9, value=0.3, step=0.1)
    st.sidebar.write("\n")
    box_thresh = st.sidebar.slider("边框阈值", min_value=0.1, max_value=0.9, value=0.1, step=0.1)
    st.sidebar.write("\n")

    if not st.sidebar.button("分析页面"):
        return
    if uploaded_file is None:
        st.sidebar.write("请上传一个文档")
        return

    start_model = time.time()
    with st.spinner("加载模型..."):
        predictor = load_predictor(
            det_arch=det_arch,
            reco_arch=reco_arch,
            assume_straight_pages=assume_straight_pages,
            straighten_pages=straighten_pages,
            export_as_straight_boxes=export_straight_boxes,
            disable_page_orientation=disable_page_orientation,
            disable_crop_orientation=disable_crop_orientation,
            bin_thresh=bin_thresh,
            box_thresh=box_thresh,
            device=forward_device,
        )
    model_time = time.time() - start_model

    with st.spinner("分析中..."):
        # Segmentation heatmap, resized to the page resolution
        seg_time_start = time.time()
        seg_map = forward_image(predictor, page, forward_device)
        seg_map = np.squeeze(seg_map)
        seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
        seg_time = time.time() - seg_time_start
        fig, ax = plt.subplots()
        ax.imshow(seg_map)
        ax.axis("off")
        cols[1].pyplot(fig)
        # Streamlit reruns this script on every interaction; close the figure
        # so pyplot does not keep accumulating open figures.
        plt.close(fig)

        # OCR output; the export is computed once and reused below
        ocr_time_start = time.time()
        out = predictor([page])
        page_export = out.pages[0].export()
        fig = visualize_page(page_export, out.pages[0].page, interactive=False, add_labels=True)
        cols[2].pyplot(fig)
        plt.close(fig)
        ocr_time = time.time() - ocr_time_start

        # Page reconstruction, cropped to the detected text region
        page_time_start = time.time()
        if assume_straight_pages or straighten_pages:
            img = out.pages[0].synthesize()
            cols[3].image(_crop_to_text(img, page_export["blocks"]), clamp=True)
        # Measured unconditionally so the header below never hits an
        # undefined name when reconstruction is skipped.
        page_time = time.time() - page_time_start

        # total_time covers the analysis only; model loading is reported
        # separately in the JSON summary.
        total_time = time.time() - seg_time_start

        cols[0].subheader(f"输入页面 (总耗时: {format_time(total_time)})")
        cols[1].subheader(f"分割热图 (耗时: {format_time(seg_time)})")
        cols[2].subheader(f"OCR输出 (耗时: {format_time(ocr_time)})")
        # This value is the reconstruction time, so label it as elapsed time
        # rather than model loading.
        cols[3].subheader(f"页面重构 (耗时: {format_time(page_time)})")

        st.markdown("\n### OCR Text Results:")
        for block in page_export["blocks"]:
            for line in block["lines"]:
                for word in line["words"]:
                    st.write(f"Text: {word['value']}, Confidence: {word['confidence']:.2f}")

        # Timing summary followed by the full export
        st.markdown("\nHere are your analysis results in JSON format:")
        st.json({"total_time": total_time, "model_load_time": model_time}, expanded=True)
        st.json(page_export, expanded=True)
207
 
208
  if __name__ == "__main__":
209
- main(DET_ARCHS, RECO_ARCHS)
 
1
+ from model import OCRModel, DET_ARCHS, RECO_ARCHS
2
+ from ui import OCRUI
 
 
 
 
 
 
 
 
 
3
  from doctr.io import DocumentFile
4
+ import time
5
+ import streamlit as st
6
 
7
def main():
    """Wire the sidebar, the OCR model and the result columns together."""
    # Imported here because this module's import block does not provide them,
    # while the body below needs both names.
    import matplotlib.pyplot as plt
    from doctr.utils.visualization import visualize_page

    ui = OCRUI()
    model = OCRModel()

    uploaded_file, params = ui.setup_sidebar(DET_ARCHS, RECO_ARCHS)

    # Load the document and build the page selector on every script run.
    # A Streamlit button is only truthy for the run triggered by its click, so
    # a widget created inside the button branch disappears (together with the
    # results) as soon as the user touches it.
    page = None
    if uploaded_file is not None:
        file_bytes = uploaded_file.read()
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(file_bytes)
        else:
            doc = DocumentFile.from_images(file_bytes)
        page_idx = st.sidebar.selectbox("页面选择", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        ui.cols[0].image(page)

    if not st.sidebar.button("分析页面"):
        return
    if page is None:
        st.sidebar.write("请上传一个文档")
        return

    # Process page
    start_time = time.time()
    model.load_model(**params)
    seg_map, out = model.process_page(page)

    # Segmentation heatmap
    fig, ax = plt.subplots()
    ax.imshow(seg_map)
    ax.axis("off")
    ui.cols[1].pyplot(fig)
    # The script reruns on every interaction; close figures so pyplot does not
    # keep accumulating them.
    plt.close(fig)

    # OCR overlay; the export is computed once and reused for the JSON view
    page_export = out.pages[0].export()
    fig = visualize_page(page_export, out.pages[0].page, interactive=False, add_labels=True)
    ui.cols[2].pyplot(fig)
    plt.close(fig)

    # Display processing time and results
    total_time = time.time() - start_time
    st.json({"total_time": total_time, "results": page_export})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  if __name__ == "__main__":
42
+ main()
model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from doctr.file_utils import is_tf_available
3
+ import numpy as np
4
+ import cv2
5
+
6
+ if is_tf_available():
7
+ import tensorflow as tf
8
+ from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
9
+ else:
10
+ import torch
11
+ from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
12
+
13
class OCRModel:
    """Owns the OCR predictor and the device it runs on."""

    # Option names used by the sidebar mapped to the keyword names that the
    # pre-refactor load_predictor call used.
    _KWARG_ALIASES = {"export_straight_boxes": "export_as_straight_boxes"}

    def __init__(self):
        self.predictor = None  # set by load_model()
        self.device = self._setup_device()

    def _setup_device(self):
        """Return the backend device object, preferring a GPU when present."""
        if is_tf_available():
            # "GPU" (upper case) is the documented TensorFlow device type.
            has_gpu = any(tf.config.experimental.list_physical_devices("GPU"))
            return tf.device("/gpu:0" if has_gpu else "/cpu:0")
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    @classmethod
    def _normalize_predictor_kwargs(cls, kwargs):
        """Return a copy of ``kwargs`` with aliased option names renamed.

        An explicitly supplied canonical name wins over its alias. The input
        mapping is not modified.
        """
        normalized = dict(kwargs)
        for alias, canonical in cls._KWARG_ALIASES.items():
            if alias in normalized:
                value = normalized.pop(alias)
                normalized.setdefault(canonical, value)
        return normalized

    def load_model(self, det_arch, reco_arch, **kwargs):
        """Build the predictor for the given architectures and options.

        Args:
            det_arch: text detection architecture.
            reco_arch: text recognition architecture.
            **kwargs: remaining predictor options, forwarded to
                ``load_predictor`` after alias normalization.
        """
        self.predictor = load_predictor(
            det_arch=det_arch,
            reco_arch=reco_arch,
            device=self.device,
            **self._normalize_predictor_kwargs(kwargs),
        )

    def process_page(self, page):
        """Run segmentation and OCR on one page.

        Args:
            page: page image array; its first two dimensions are used as
                height and width when resizing the segmentation map.

        Returns:
            tuple: ``(seg_map, out)``, the segmentation map resized to the
            page size and the predictor output for ``[page]``.

        Raises:
            RuntimeError: if ``load_model`` has not been called yet.
        """
        if self.predictor is None:
            raise RuntimeError("load_model() must be called before process_page()")
        seg_map = forward_image(self.predictor, page, self.device)
        seg_map = np.squeeze(seg_map)
        seg_map = cv2.resize(seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR)
        out = self.predictor([page])
        return seg_map, out
ui.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ from doctr.utils.visualization import visualize_page
4
+
5
class OCRUI:
    """Streamlit front end: page header, result columns and sidebar widgets."""

    def __init__(self):
        self.setup_page_config()
        self.cols = self.create_layout()

    def setup_page_config(self):
        """Switch to wide layout and render the title and usage hint."""
        st.set_page_config(layout="wide")
        st.title("美宜家文档文本识别DEMO")
        st.write("\n")
        st.markdown("*提示:单击图像的右上角可以放大!*")

    def create_layout(self):
        """Create four equal-width columns, each with its heading."""
        columns = st.columns((1, 1, 1, 1))
        headings = ("输入页面", "分割热图", "OCR 输出", "页面重构")
        for column, heading in zip(columns, headings):
            column.subheader(heading)
        return columns

    def setup_sidebar(self, det_archs, reco_archs):
        """Render the sidebar widgets and collect their current values.

        Returns:
            tuple: ``(uploaded_file, params)`` where ``params`` maps option
            names to the values of the corresponding widgets.
        """
        sidebar = st.sidebar
        sidebar.title("文档选择")
        uploaded_file = sidebar.file_uploader("上传文件", type=["pdf", "png", "jpeg", "jpg"])

        # Widgets are created in the same order as the keys appear in params.
        checkbox_specs = (
            ("assume_straight_pages", "假设页面是直的", True),
            ("disable_page_orientation", "禁用页面方向检测", False),
            ("disable_crop_orientation", "禁用裁剪方向检测", False),
            ("straighten_pages", "矫正页面", False),
            ("export_straight_boxes", "导出为直边框", False),
        )
        params = {}
        for key, label, default in checkbox_specs:
            params[key] = sidebar.checkbox(label, value=default)

        params["bin_thresh"] = sidebar.slider(
            "二值化阈值", min_value=0.1, max_value=0.9, value=0.3, step=0.1
        )
        params["box_thresh"] = sidebar.slider(
            "边框阈值", min_value=0.1, max_value=0.9, value=0.1, step=0.1
        )
        params["det_arch"] = sidebar.selectbox("文本检测模型", det_archs)
        params["reco_arch"] = sidebar.selectbox("文本识别模型", reco_archs)

        return uploaded_file, params