File size: 8,068 Bytes
27d9242
8abd549
27d9242
95016d5
 
 
df4ac89
27d9242
4f1b300
3c73b18
4f1b300
 
27d9242
 
 
 
df4ac89
 
e128b79
4f1b300
0f20979
95016d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27d9242
 
 
 
 
4f1b300
 
 
176738a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f1b300
27d9242
 
 
 
 
 
 
 
 
 
 
 
bcfbb56
 
 
 
 
 
 
 
 
 
 
df4ac89
83cc25f
df4ac89
 
 
 
 
 
95016d5
 
3c73b18
c0a48c1
4f1b300
 
 
176738a
 
 
 
 
 
 
 
 
 
 
4f1b300
95016d5
4f1b300
 
 
0f20979
4f1b300
 
 
0f20979
4f1b300
 
 
 
 
 
 
 
 
 
 
 
 
 
df4ac89
27d9242
4559f18
 
27d9242
 
4559f18
 
 
27d9242
4559f18
27d9242
 
df4ac89
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import streamlit as st
import torch
import matplotlib.pyplot as plt
import aiohttp
import asyncio
import os
from doctr.io import DocumentFile  # Add this import
from doctr.utils.visualization import visualize_page
from stripeRemover import StripeRemover
import time
import numpy as np
import cv2
class OCRUI:
    """Streamlit front-end for a docTR-based OCR demo.

    Lays out four result columns, lets the user upload a PDF or image, and
    for image uploads optionally applies a ``StripeRemover`` preprocessing
    pipeline before posting the result to an external analysis API.
    """

    DEFAULT_API_URL = "http://s15.serv00.com:9081/compareAnalyze"
    # JPEG quality used when encoding the preprocessed image for upload.
    JPEG_QUALITY = 50

    ORIGINAL_METHOD = "原始图像"
    ENHANCE_METHOD = "图像增强"
    # Sidebar options, in display order.
    PREPROCESS_METHODS = (
        ORIGINAL_METHOD,
        "傅里叶变换去条纹",
        "形态学去条纹",
        "自适应阈值",
        ENHANCE_METHOD,
        "自适应阈值+图像增强",
        "傅里叶变换+图像增强",
        "形态学操作+图像增强",
        "自适应阈值+傅里叶变换",
        "形态学操作+自适应阈值",
        "傅里叶变换+形态学操作",
    )
    # Options implemented by calling a StripeRemover method directly on the
    # input image (the two special cases above are handled separately).
    _REMOVER_METHODS = {
        "傅里叶变换去条纹": "fourier_method",
        "形态学去条纹": "morphological_method",
        "自适应阈值": "adaptive_threshold_method",
        "自适应阈值+图像增强": "adaptive_enhance",
        "傅里叶变换+图像增强": "fourier_enhance",
        "形态学操作+图像增强": "morphological_enhance",
        "自适应阈值+傅里叶变换": "adaptive_fourier",
        "形态学操作+自适应阈值": "morphological_adaptive",
        "傅里叶变换+形态学操作": "fourier_morphological",
    }

    def __init__(self, api_url=None):
        """Set up the page and layout.

        Args:
            api_url: Endpoint for the external analysis API; defaults to
                ``DEFAULT_API_URL``.
        """
        # st.set_page_config must be the first Streamlit call, so the page
        # config is applied before the layout columns are created.
        self.setup_page_config()
        self.cols = self.create_layout()
        self.current_doc = None
        self.current_page = None
        self.API_URL = api_url or self.DEFAULT_API_URL
        self.stripe_remover = StripeRemover()
        self.processed_image = None

    async def send_to_external_api(self, file_bytes):
        """POST ``file_bytes`` as ``image.jpg`` (with ``model=GEMINI``) to the API.

        Returns:
            The decoded JSON body, or ``None`` if the request, the HTTP status,
            or the JSON decoding failed (the error is shown via ``st.error``).
        """
        form = aiohttp.FormData()
        form.add_field('image', file_bytes,
                       filename='image.jpg',
                       content_type='image/jpeg')
        form.add_field('model', 'GEMINI')

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(self.API_URL, data=form) as response:
                    # Surface 4xx/5xx as errors instead of trying to parse an
                    # error page as JSON.
                    response.raise_for_status()
                    return await response.json()
        except Exception as e:
            # UI boundary: report any failure to the user rather than crash
            # the Streamlit script run.
            st.error(f"External API Error: {str(e)}")
            return None

    def setup_page_config(self):
        """Configure a wide page layout and render the title and usage hint."""
        st.set_page_config(layout="wide")
        st.title("美宜家文档文本识别DEMO")
        st.write("\n")
        st.markdown("*提示:单击图像的右上角可以放大!*")

    def process_image(self, image: np.ndarray, method: str):
        """Apply the preprocessing pipeline named ``method`` to ``image``.

        Args:
            image: The page image. The RGB->BGR / RGB->GRAY conversions below
                assume it is an RGB array.
            method: One of ``PREPROCESS_METHODS``.

        Returns:
            The processed image, or ``None`` if processing failed or ``method``
            is unknown (an error is shown in the UI in both cases). For
            ``ORIGINAL_METHOD`` the result is in BGR order so it can go
            straight to ``cv2.imencode``; the color layout of the other results
            depends on ``StripeRemover``.
        """
        try:
            if method == self.ORIGINAL_METHOD:
                return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            if method == self.ENHANCE_METHOD:
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
                return self.stripe_remover.enhance_image(gray)
            remover_name = self._REMOVER_METHODS.get(method)
            if remover_name is None:
                raise ValueError(f"unknown preprocessing method: {method!r}")
            return getattr(self.stripe_remover, remover_name)(image)
        except Exception as e:
            st.error(f"图像处理失败: {str(e)}")
            return None

    @classmethod
    def _display_channels(cls, method):
        """Return the ``channels`` value ``st.image`` needs for ``method``'s output."""
        # process_image converts the original image to BGR for cv2.imencode;
        # st.image assumes RGB unless told otherwise, which would swap red/blue.
        return "BGR" if method == cls.ORIGINAL_METHOD else "RGB"

    @staticmethod
    def _is_pdf(filename):
        """Return True if ``filename`` has a ``.pdf`` extension (case-insensitive)."""
        return filename.lower().endswith(".pdf")

    def create_layout(self):
        """Create the four result columns with their headers and return them."""
        cols = st.columns((1, 1, 1, 1))
        cols[0].subheader("输入页面")
        cols[1].subheader("分割热图")
        cols[2].subheader("OCR 输出")
        cols[3].subheader("页面重构")
        return cols

    def _select_device(self):
        """Render the hardware-selection widgets and return the chosen device string."""
        st.sidebar.title("硬件选择")
        if torch.cuda.is_available():
            use_gpu = st.sidebar.checkbox("使用GPU", value=True)
            selected_device = "cuda:0" if use_gpu else "cpu"
        else:
            st.sidebar.write("当前仅支持CPU")
            selected_device = "cpu"
        st.sidebar.markdown(f"**当前设备**: {selected_device}")
        return selected_device

    def _preprocess_and_analyze(self):
        """Preprocess ``self.current_page`` and send it to the external API.

        Renders the preprocessing-method radio, shows the processed preview,
        JPEG-encodes it and posts it via ``send_to_external_api``. The
        reported elapsed time covers preprocessing, encoding and the request.
        """
        start_time = time.perf_counter()

        preprocess_method = st.sidebar.radio(
            "选择预处理方法",
            list(self.PREPROCESS_METHODS),
        )

        with st.spinner('处理图像中...'):
            processed_img = self.process_image(self.current_page, preprocess_method)
            if processed_img is None:
                return
            self.processed_image = processed_img

            channels = self._display_channels(preprocess_method)
            st.sidebar.image(processed_img, caption=f"预处理结果 - {preprocess_method}",
                             width=200, channels=channels)
            self.cols[0].image(processed_img, channels=channels)

            encode_params = [cv2.IMWRITE_JPEG_QUALITY, self.JPEG_QUALITY]
            ok, encoded = cv2.imencode('.jpg', processed_img, encode_params)
            if not ok:
                st.sidebar.error("图像编码失败")
                return
            img_bytes = encoded.tobytes()
            size_kb = len(img_bytes) / 1024

            st.sidebar.info(f"正在分析... (图像大小: {size_kb:.2f}KB)")
            external_result = asyncio.run(self.send_to_external_api(img_bytes))

            if external_result:
                process_time = time.perf_counter() - start_time
                st.sidebar.success(f"分析完成! 耗时: {process_time:.2f}秒")
                st.sidebar.json(external_result)

    def setup_sidebar(self, det_archs, reco_archs):
        """Build the sidebar and return the user's selections.

        Renders, in order: document upload, hardware selection, page picker
        and (for image uploads only) preprocessing + external analysis, then
        the OCR model/parameter widgets.

        Args:
            det_archs: Options for the text-detection model selectbox.
            reco_archs: Options for the text-recognition model selectbox.

        Returns:
            ``(uploaded_file, params, current_page)`` where ``params`` holds
            the OCR settings chosen in the sidebar and ``current_page`` is
            ``self.current_page`` (``None`` until a page has been selected).
        """
        st.sidebar.title("文档选择")
        uploaded_file = st.sidebar.file_uploader("上传文件", type=["pdf", "png", "jpeg", "jpg"])

        # NOTE(review): the selected device is displayed but not returned;
        # confirm whether callers are expected to pick it up elsewhere.
        self._select_device()

        if uploaded_file:
            self.load_document(uploaded_file)
            if self.current_doc:
                page_numbers = list(range(1, len(self.current_doc) + 1))
                page_idx = st.sidebar.selectbox("页面选择", page_numbers) - 1
                self.current_page = self.current_doc[page_idx]
                self.cols[0].image(self.current_page)
                # Only image uploads (not PDFs) go through preprocessing and
                # the external API.
                if uploaded_file.type.startswith('image/'):
                    self._preprocess_and_analyze()

        params = {
            "det_arch": st.sidebar.selectbox("文本检测模型", det_archs),
            "reco_arch": st.sidebar.selectbox("文本识别模型", reco_archs),
            "assume_straight_pages": st.sidebar.checkbox("假设页面是直的", value=True),
            "straighten_pages": st.sidebar.checkbox("矫正页面", value=False),
            "export_as_straight_boxes": st.sidebar.checkbox("导出为直边框", value=False),
            "disable_page_orientation": st.sidebar.checkbox("禁用页面方向检测", value=False),
            "disable_crop_orientation": st.sidebar.checkbox("禁用裁剪方向检测", value=False),
            "bin_thresh": st.sidebar.slider("二值化阈值", 0.1, 0.9, 0.3, 0.1),
            "box_thresh": st.sidebar.slider("边框阈值", 0.1, 0.9, 0.1, 0.1),
        }

        return uploaded_file, params, self.current_page

    def load_document(self, uploaded_file):
        """Decode ``uploaded_file`` into ``self.current_doc`` (a sequence of pages).

        Files with a ``.pdf`` extension (case-insensitive) go through
        ``DocumentFile.from_pdf``; everything else goes through
        ``DocumentFile.from_images``. On failure an error is shown and
        ``current_doc`` is set to ``None``.
        """
        try:
            # getvalue() returns the full buffer regardless of the current
            # stream position, unlike read().
            data = uploaded_file.getvalue()
            if self._is_pdf(uploaded_file.name):
                self.current_doc = DocumentFile.from_pdf(data)
            else:
                self.current_doc = DocumentFile.from_images(data)
        except Exception as e:
            st.error(f"文档加载失败: {str(e)}")
            self.current_doc = None