File size: 8,068 Bytes
27d9242 8abd549 27d9242 95016d5 df4ac89 27d9242 4f1b300 3c73b18 4f1b300 27d9242 df4ac89 e128b79 4f1b300 0f20979 95016d5 27d9242 4f1b300 176738a 4f1b300 27d9242 bcfbb56 df4ac89 83cc25f df4ac89 95016d5 3c73b18 c0a48c1 4f1b300 176738a 4f1b300 95016d5 4f1b300 0f20979 4f1b300 0f20979 4f1b300 df4ac89 27d9242 4559f18 27d9242 4559f18 27d9242 4559f18 27d9242 df4ac89 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | import streamlit as st
import torch
import matplotlib.pyplot as plt
import aiohttp
import asyncio
import os
from doctr.io import DocumentFile # Add this import
from doctr.utils.visualization import visualize_page
from stripeRemover import StripeRemover
import time
import numpy as np
import cv2
class OCRUI:
    """Streamlit front end for the document text-recognition demo.

    On construction it configures the page and creates a four-column layout
    (input page, segmentation heatmap, OCR output, page reconstruction).
    ``setup_sidebar`` renders the sidebar controls and loads the uploaded
    document. For image uploads it also preprocesses the selected page and
    sends it to the external analysis API at ``API_URL``.
    """

    # Preprocessing options shown in the sidebar, in display order. The first
    # entry is the unprocessed image. Keep in sync with ``process_image``.
    PREPROCESS_METHODS = (
        "原始图像",
        "傅里叶变换去条纹",
        "形态学去条纹",
        "自适应阈值",
        "图像增强",
        "自适应阈值+图像增强",
        "傅里叶变换+图像增强",
        "形态学操作+图像增强",
        "自适应阈值+傅里叶变换",
        "形态学操作+自适应阈值",
        "傅里叶变换+形态学操作",
    )
    # JPEG quality used when encoding the image sent to the external API.
    JPEG_QUALITY = 50

    def __init__(self):
        self.setup_page_config()
        self.cols = self.create_layout()
        self.current_doc = None  # pages loaded by doctr's DocumentFile, or None
        self.current_page = None  # page selected in the sidebar
        self.API_URL = "http://s15.serv00.com:9081/compareAnalyze"
        self.stripe_remover = StripeRemover()
        self.processed_image = None  # last successfully preprocessed image
        self.selected_device = "cpu"  # device chosen in the sidebar

    async def send_to_external_api(self, file_bytes, timeout=120):
        """POST a JPEG image to ``self.API_URL`` and return the decoded JSON reply.

        The request is multipart form data with an ``image`` file field and a
        ``model`` field set to ``"GEMINI"``.

        Args:
            file_bytes: JPEG-encoded image bytes.
            timeout: Total request timeout in seconds. Without an explicit
                timeout, aiohttp waits up to five minutes, blocking the UI.

        Returns:
            The JSON payload, or ``None`` if the request failed. Failures are
            shown with ``st.error``.
        """
        data = aiohttp.FormData()
        data.add_field('image', file_bytes,
                       filename='image.jpg',
                       content_type='image/jpeg')
        data.add_field('model', 'GEMINI')
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        try:
            async with aiohttp.ClientSession(timeout=client_timeout) as session:
                async with session.post(self.API_URL, data=data) as response:
                    # Without this, an HTTP error body would be reported to
                    # the user as a successful analysis.
                    response.raise_for_status()
                    return await response.json()
        except Exception as e:
            # Timeout errors have an empty message; fall back to the type name.
            st.error(f"External API Error: {str(e) or type(e).__name__}")
            return None

    def setup_page_config(self):
        """Configure a wide page and render the title and usage hint."""
        st.set_page_config(layout="wide")
        st.title("美宜家文档文本识别DEMO")
        st.write("\n")
        st.markdown("*提示:单击图像的右上角可以放大!*")

    def process_image(self, image, method):
        """Apply the preprocessing step named ``method`` to ``image``.

        Args:
            image: Page image. It is treated as RGB; the raw path converts it
                with ``cv2.COLOR_RGB2BGR``.
            method: One of ``PREPROCESS_METHODS``.

        Returns:
            The processed image, or ``None`` if ``method`` is unknown or the
            step raised. Either failure is reported with ``st.error``. For
            "原始图像", the result is in BGR channel order, as expected by
            ``cv2.imencode``.
        """
        remover = self.stripe_remover
        # The lambdas defer attribute lookup. A missing StripeRemover method
        # therefore only fails when it is selected, inside the try below.
        steps = {
            "原始图像": lambda img: cv2.cvtColor(img, cv2.COLOR_RGB2BGR),
            "傅里叶变换去条纹": lambda img: remover.fourier_method(img),
            "形态学去条纹": lambda img: remover.morphological_method(img),
            "自适应阈值": lambda img: remover.adaptive_threshold_method(img),
            # enhance_image is given a grayscale image, unlike the other steps.
            "图像增强": lambda img: remover.enhance_image(
                cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)),
            "自适应阈值+图像增强": lambda img: remover.adaptive_enhance(img),
            "傅里叶变换+图像增强": lambda img: remover.fourier_enhance(img),
            "形态学操作+图像增强": lambda img: remover.morphological_enhance(img),
            "自适应阈值+傅里叶变换": lambda img: remover.adaptive_fourier(img),
            "形态学操作+自适应阈值": lambda img: remover.morphological_adaptive(img),
            "傅里叶变换+形态学操作": lambda img: remover.fourier_morphological(img),
        }
        step = steps.get(method)
        if step is None:
            st.error(f"未知的预处理方法: {method}")
            return None
        try:
            return step(image)
        except Exception as e:
            st.error(f"图像处理失败: {str(e)}")
            return None

    def create_layout(self):
        """Create the four equal-width result columns and return them."""
        cols = st.columns((1, 1, 1, 1))
        cols[0].subheader("输入页面")
        cols[1].subheader("分割热图")
        cols[2].subheader("OCR 输出")
        cols[3].subheader("页面重构")
        return cols

    def setup_sidebar(self, det_archs, reco_archs):
        """Render the sidebar and return the user's choices.

        Widgets are created in the order document upload, hardware, page,
        preprocessing/analysis (image uploads only), then model parameters.
        Streamlit lays out sidebar elements in call order.

        Args:
            det_archs: Options for the text-detection model selectbox.
            reco_archs: Options for the text-recognition model selectbox.

        Returns:
            A tuple ``(uploaded_file, params, current_page)``. ``params`` holds
            the OCR settings chosen in the sidebar. ``current_page`` is
            ``None`` if no document is loaded.
        """
        st.sidebar.title("文档选择")
        uploaded_file = st.sidebar.file_uploader("上传文件", type=["pdf", "png", "jpeg", "jpg"])
        self._select_device()
        if uploaded_file:
            self.load_document(uploaded_file)
            if self.current_doc:
                page_idx = st.sidebar.selectbox(
                    "页面选择", list(range(1, len(self.current_doc) + 1))) - 1
                self.current_page = self.current_doc[page_idx]
                self.cols[0].image(self.current_page)
                # Only image uploads are sent to the external API.
                if uploaded_file.type.startswith('image/'):
                    self._analyze_current_page()
        params = self._model_params(det_archs, reco_archs)
        return uploaded_file, params, self.current_page

    def _select_device(self):
        """Render the hardware-selection widgets and return the chosen device string."""
        st.sidebar.title("硬件选择")
        if torch.cuda.is_available():
            use_gpu = st.sidebar.checkbox("使用GPU", value=True)
            selected_device = "cuda:0" if use_gpu else "cpu"
        else:
            st.sidebar.write("当前仅支持CPU")
            selected_device = "cpu"
        st.sidebar.markdown(f"**当前设备**: {selected_device}")
        # NOTE(review): the device is not part of the params returned by
        # setup_sidebar. It is only kept here so that callers can read it;
        # confirm how the OCR predictor is meant to pick it up.
        self.selected_device = selected_device
        return selected_device

    def _analyze_current_page(self):
        """Preprocess ``self.current_page``, send it to the external API and show the reply."""
        start_time = time.time()
        preprocess_method = st.sidebar.radio("选择预处理方法", list(self.PREPROCESS_METHODS))
        with st.spinner('处理图像中...'):
            processed_img = self.process_image(self.current_page, preprocess_method)
            if processed_img is None:
                return  # nothing to show or send
            self.processed_image = processed_img

            channels = self._display_channels(preprocess_method)
            st.sidebar.image(processed_img, caption=f"预处理结果 - {preprocess_method}",
                             width=200, channels=channels)
            self.cols[0].image(processed_img, channels=channels)

            encode_params = [cv2.IMWRITE_JPEG_QUALITY, self.JPEG_QUALITY]
            ok, encoded = cv2.imencode('.jpg', processed_img, encode_params)
            if not ok:
                st.error("图像编码失败")
                return
            img_bytes = encoded.tobytes()
            size_kb = len(img_bytes) / 1024
            st.sidebar.info(f"正在分析... (图像大小: {size_kb:.2f}KB)")
            external_result = asyncio.run(self.send_to_external_api(img_bytes))
            if external_result:
                process_time = time.time() - start_time
                st.sidebar.success(f"分析完成! 耗时: {process_time:.2f}秒")
                st.sidebar.json(external_result)

    @staticmethod
    def _model_params(det_archs, reco_archs):
        """Render the OCR model/threshold widgets and return their values as a dict."""
        # A dict literal is evaluated left to right, so the widgets appear in this order.
        return {
            "det_arch": st.sidebar.selectbox("文本检测模型", det_archs),
            "reco_arch": st.sidebar.selectbox("文本识别模型", reco_archs),
            "assume_straight_pages": st.sidebar.checkbox("假设页面是直的", value=True),
            "straighten_pages": st.sidebar.checkbox("矫正页面", value=False),
            "export_as_straight_boxes": st.sidebar.checkbox("导出为直边框", value=False),
            "disable_page_orientation": st.sidebar.checkbox("禁用页面方向检测", value=False),
            "disable_crop_orientation": st.sidebar.checkbox("禁用裁剪方向检测", value=False),
            "bin_thresh": st.sidebar.slider("二值化阈值", 0.1, 0.9, 0.3, 0.1),
            "box_thresh": st.sidebar.slider("边框阈值", 0.1, 0.9, 0.1, 0.1),
        }

    @staticmethod
    def _display_channels(method):
        """Return the ``st.image`` channel order for the output of ``method``.

        The raw path converts to BGR for OpenCV encoding. ``st.image`` assumes
        RGB unless told otherwise, so without this red and blue are swapped.
        """
        return "BGR" if method == "原始图像" else "RGB"

    @staticmethod
    def _is_pdf(filename):
        """Return True if ``filename`` has a ``.pdf`` extension (case-insensitive)."""
        return filename.lower().endswith(".pdf")

    def load_document(self, uploaded_file):
        """Decode ``uploaded_file`` into ``self.current_doc``.

        PDFs (by file extension) go through ``DocumentFile.from_pdf``; any
        other file goes through ``DocumentFile.from_images``. On failure, the
        error is shown and ``self.current_doc`` is set to ``None``.
        """
        try:
            # getvalue() returns the whole upload regardless of the stream
            # position; read() would return only the bytes left after any
            # earlier read. Fall back to read() for plain file objects.
            getvalue = getattr(uploaded_file, "getvalue", None)
            content = getvalue() if callable(getvalue) else uploaded_file.read()
            if self._is_pdf(uploaded_file.name):
                self.current_doc = DocumentFile.from_pdf(content)
            else:
                self.current_doc = DocumentFile.from_images(content)
        except Exception as e:
            st.error(f"文档加载失败: {str(e)}")
            self.current_doc = None