|
|
import streamlit as st |
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
import aiohttp |
|
|
import asyncio |
|
|
import os |
|
|
from doctr.io import DocumentFile |
|
|
from doctr.utils.visualization import visualize_page |
|
|
from stripeRemover import StripeRemover |
|
|
import time |
|
|
import numpy as np |
|
|
import cv2 |
|
|
class OCRUI:
    """Streamlit front end for the document OCR demo.

    Creating an instance configures the page, lays out the four result
    columns and builds the ``StripeRemover`` used for preprocessing.
    ``setup_sidebar`` then renders the sidebar widgets, loads the uploaded
    document and, for image uploads, preprocesses the selected page and
    posts it to the external analysis endpoint.
    """

    # Endpoint used when the constructor is not given an explicit URL.
    DEFAULT_API_URL = "http://s15.serv00.com:9081/compareAnalyze"

    # Preprocessing choices offered in the sidebar, in display order.
    PREPROCESS_METHODS = (
        "原始图像",
        "傅里叶变换去条纹",
        "形态学去条纹",
        "自适应阈值",
        "图像增强",
        "自适应阈值+图像增强",
        "傅里叶变换+图像增强",
        "形态学操作+图像增强",
        "自适应阈值+傅里叶变换",
        "形态学操作+自适应阈值",
        "傅里叶变换+形态学操作",
    )

    # Sidebar label -> name of the StripeRemover method that receives the page
    # image unchanged. "原始图像" and "图像增强" are handled separately in
    # process_image because they need a colour conversion first.
    _STRIPE_METHODS = {
        "傅里叶变换去条纹": "fourier_method",
        "形态学去条纹": "morphological_method",
        "自适应阈值": "adaptive_threshold_method",
        "自适应阈值+图像增强": "adaptive_enhance",
        "傅里叶变换+图像增强": "fourier_enhance",
        "形态学操作+图像增强": "morphological_enhance",
        "自适应阈值+傅里叶变换": "adaptive_fourier",
        "形态学操作+自适应阈值": "morphological_adaptive",
        "傅里叶变换+形态学操作": "fourier_morphological",
    }

    def __init__(self, api_url=None):
        """Configure the page, create the layout and initialise state.

        Args:
            api_url: Optional override for the external analysis endpoint.
                When omitted (or falsy) ``DEFAULT_API_URL`` is used.
        """
        self.setup_page_config()
        self.cols = self.create_layout()
        self.current_doc = None       # pages produced by load_document
        self.current_page = None      # page picked in the sidebar
        self.API_URL = api_url or self.DEFAULT_API_URL
        self.stripe_remover = StripeRemover()
        self.processed_image = None   # last successful preprocessing result

    async def send_to_external_api(self, file_bytes):
        """POST an encoded image to the external analysis endpoint.

        The image is sent as multipart form data (field ``image``, declared
        as ``image/jpeg``) together with a ``model`` field set to ``GEMINI``.

        Args:
            file_bytes: Encoded image bytes to upload.

        Returns:
            The decoded JSON response, or ``None`` when the request fails,
            the server answers with an HTTP error status, or the body cannot
            be decoded as JSON. Failures are shown with ``st.error`` instead
            of being raised.
        """
        form = aiohttp.FormData()
        form.add_field(
            "image",
            file_bytes,
            filename="image.jpg",
            content_type="image/jpeg",
        )
        form.add_field("model", "GEMINI")

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(self.API_URL, data=form) as response:
                    # Treat HTTP error statuses as failures so an error
                    # payload is never reported as a successful analysis.
                    response.raise_for_status()
                    return await response.json()
        except Exception as e:
            # UI boundary: report the problem and keep the app running.
            st.error(f"External API Error: {str(e)}")
            return None

    def setup_page_config(self):
        """Set the wide page layout and render the title and the usage hint."""
        st.set_page_config(layout="wide")
        st.title("美宜家文档文本识别DEMO")
        st.write("\n")
        st.markdown("*提示:单击图像的右上角可以放大!*")

    def process_image(self, image, method):
        """Apply the preprocessing step selected in the sidebar.

        Args:
            image: Page image. It is converted with ``cv2.COLOR_RGB2*`` codes,
                so it is presumably an RGB array (TODO confirm with caller).
            method: One of the labels in ``PREPROCESS_METHODS``.

        Returns:
            The processed image, or ``None`` when ``method`` is not a known
            label or when processing raises (the error is shown with
            ``st.error``).
        """
        try:
            if method == "原始图像":
                return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            if method == "图像增强":
                # enhance_image is fed a grayscale version of the page.
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
                return self.stripe_remover.enhance_image(gray)

            handler_name = self._STRIPE_METHODS.get(method)
            if handler_name is None:
                return None
            return getattr(self.stripe_remover, handler_name)(image)
        except Exception as e:
            st.error(f"图像处理失败: {str(e)}")
            return None

    def create_layout(self):
        """Create four equal-width columns with their headers and return them."""
        cols = st.columns((1, 1, 1, 1))
        headers = ("输入页面", "分割热图", "OCR 输出", "页面重构")
        for col, header in zip(cols, headers):
            col.subheader(header)
        return cols

    def setup_sidebar(self, det_archs, reco_archs):
        """Render the sidebar and handle the uploaded document.

        When a file is uploaded it is loaded, the selected page is shown in
        the first column and, for image uploads, the page is preprocessed and
        sent to the external analysis endpoint.

        Args:
            det_archs: Options for the text detection model select box.
            reco_archs: Options for the text recognition model select box.

        Returns:
            A tuple ``(uploaded_file, params, current_page)`` where ``params``
            is the dict of OCR settings chosen in the sidebar and
            ``current_page`` is ``None`` until a document has been loaded.
        """
        st.sidebar.title("文档选择")
        uploaded_file = st.sidebar.file_uploader(
            "上传文件", type=["pdf", "png", "jpeg", "jpg"]
        )

        self._render_device_selector()

        if uploaded_file:
            self.load_document(uploaded_file)
            if self.current_doc:
                page_numbers = list(range(1, len(self.current_doc) + 1))
                # The select box shows 1-based page numbers.
                page_idx = st.sidebar.selectbox("页面选择", page_numbers) - 1
                self.current_page = self.current_doc[page_idx]
                self.cols[0].image(self.current_page)

                if uploaded_file.type.startswith("image/"):
                    self._preprocess_and_analyze()

        params = self._collect_ocr_params(det_archs, reco_archs)
        return uploaded_file, params, self.current_page

    def _render_device_selector(self):
        """Render the hardware section and return the chosen device string."""
        st.sidebar.title("硬件选择")
        if torch.cuda.is_available():
            use_gpu = st.sidebar.checkbox("使用GPU", value=True)
            selected_device = "cuda:0" if use_gpu else "cpu"
        else:
            st.sidebar.write("当前仅支持CPU")
            selected_device = "cpu"
        st.sidebar.markdown(f"**当前设备**: {selected_device}")
        return selected_device

    def _preprocess_and_analyze(self):
        """Preprocess the current page, display it and send it for analysis."""
        start_time = time.time()
        preprocess_method = st.sidebar.radio(
            "选择预处理方法", list(self.PREPROCESS_METHODS)
        )

        with st.spinner('处理图像中...'):
            processed_img = self.process_image(self.current_page, preprocess_method)
            if processed_img is None:
                return
            self.processed_image = processed_img

            st.sidebar.image(
                processed_img,
                caption=f"预处理结果 - {preprocess_method}",
                width=200,
            )
            self.cols[0].image(processed_img)

            img_bytes = self._encode_jpeg(processed_img)
            if img_bytes is None:
                return
            size_kb = len(img_bytes) / 1024

            st.sidebar.info(f"正在分析... (图像大小: {size_kb:.2f}KB)")
            external_result = asyncio.run(self.send_to_external_api(img_bytes))

            if external_result:
                process_time = time.time() - start_time
                st.sidebar.success(f"分析完成! 耗时: {process_time:.2f}秒")
                st.sidebar.json(external_result)

    def _encode_jpeg(self, image, quality=50):
        """Encode ``image`` as JPEG bytes; return ``None`` if encoding fails."""
        try:
            ok, buffer = cv2.imencode(
                ".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality]
            )
        except cv2.error as e:
            st.sidebar.error(f"图像编码失败: {str(e)}")
            return None
        if not ok:
            st.sidebar.error("图像编码失败")
            return None
        return buffer.tobytes()

    def _collect_ocr_params(self, det_archs, reco_archs):
        """Render the OCR option widgets and return their values as a dict."""
        # The dict literal is evaluated top to bottom, so the widgets appear
        # in the sidebar in exactly this order.
        return {
            "det_arch": st.sidebar.selectbox("文本检测模型", det_archs),
            "reco_arch": st.sidebar.selectbox("文本识别模型", reco_archs),
            "assume_straight_pages": st.sidebar.checkbox("假设页面是直的", value=True),
            "straighten_pages": st.sidebar.checkbox("矫正页面", value=False),
            "export_as_straight_boxes": st.sidebar.checkbox("导出为直边框", value=False),
            "disable_page_orientation": st.sidebar.checkbox("禁用页面方向检测", value=False),
            "disable_crop_orientation": st.sidebar.checkbox("禁用裁剪方向检测", value=False),
            "bin_thresh": st.sidebar.slider("二值化阈值", 0.1, 0.9, 0.3, 0.1),
            "box_thresh": st.sidebar.slider("边框阈值", 0.1, 0.9, 0.1, 0.1),
        }

    def load_document(self, uploaded_file):
        """Load the uploaded file into ``self.current_doc``.

        Files whose name ends in ``.pdf`` (case-insensitive) are read with
        ``DocumentFile.from_pdf``; everything else with
        ``DocumentFile.from_images``. On failure the error is shown with
        ``st.error`` and ``self.current_doc`` is reset to ``None``.

        Args:
            uploaded_file: Object exposing ``name`` and either ``getvalue()``
                or ``read()``; presumably the value returned by
                ``st.file_uploader`` (TODO confirm no other callers).
        """
        try:
            # getvalue() returns the whole buffer regardless of the current
            # read position, so a file already consumed on an earlier rerun
            # does not come back empty. Fall back to read() for plain
            # file-like objects.
            getvalue = getattr(uploaded_file, "getvalue", None)
            content = getvalue() if callable(getvalue) else uploaded_file.read()

            if uploaded_file.name.lower().endswith(".pdf"):
                self.current_doc = DocumentFile.from_pdf(content)
            else:
                self.current_doc = DocumentFile.from_images(content)
        except Exception as e:
            st.error(f"文档加载失败: {str(e)}")
            self.current_doc = None