# myj / ui.py
# sonygod's picture
# addui
# 0f20979
import streamlit as st
import torch
import matplotlib.pyplot as plt
import aiohttp
import asyncio
import os
from doctr.io import DocumentFile # Add this import
from doctr.utils.visualization import visualize_page
from stripeRemover import StripeRemover
import time
import numpy as np
import cv2
class OCRUI:
    """Streamlit user interface for the document OCR demo.

    The class builds the four-column page layout and the sidebar controls,
    loads the uploaded document through docTR's ``DocumentFile``, optionally
    preprocesses an uploaded image with ``StripeRemover`` and posts the
    JPEG-encoded result to the external analysis endpoint in ``API_URL``.
    """

    # Sidebar labels handled directly in process_image because they need a
    # cv2 colour conversion before (or instead of) a StripeRemover call.
    _ORIGINAL = "原始图像"
    _ENHANCE = "图像增强"

    # Sidebar labels that map one-to-one onto a StripeRemover method which
    # takes the page image as its only argument.
    _STRIPE_REMOVER_METHODS = {
        "傅里叶变换去条纹": "fourier_method",
        "形态学去条纹": "morphological_method",
        "自适应阈值": "adaptive_threshold_method",
        "自适应阈值+图像增强": "adaptive_enhance",
        "傅里叶变换+图像增强": "fourier_enhance",
        "形态学操作+图像增强": "morphological_enhance",
        "自适应阈值+傅里叶变换": "adaptive_fourier",
        "形态学操作+自适应阈值": "morphological_adaptive",
        "傅里叶变换+形态学操作": "fourier_morphological",
    }

    # Preprocessing choices in the order they are offered in the sidebar.
    PREPROCESS_METHODS = (
        "原始图像",
        "傅里叶变换去条纹",
        "形态学去条纹",
        "自适应阈值",
        "图像增强",
        "自适应阈值+图像增强",
        "傅里叶变换+图像增强",
        "形态学操作+图像增强",
        "自适应阈值+傅里叶变换",
        "形态学操作+自适应阈值",
        "傅里叶变换+形态学操作",
    )

    def __init__(self):
        """Configure the page, create the columns and reset document state."""
        self.setup_page_config()
        self.cols = self.create_layout()
        self.current_doc = None
        self.current_page = None
        self.API_URL = "http://s15.serv00.com:9081/compareAnalyze"
        self.stripe_remover = StripeRemover()
        self.processed_image = None

    async def send_to_external_api(self, file_bytes):
        """POST an encoded image to ``API_URL`` and return the decoded JSON.

        Args:
            file_bytes: The image payload; it is sent as the ``image`` form
                field with a JPEG content type, next to ``model=GEMINI``.

        Returns:
            The parsed JSON response, or ``None`` when the request or the
            JSON decoding fails (the error is shown with ``st.error``).
        """
        async with aiohttp.ClientSession() as session:
            form = aiohttp.FormData()
            form.add_field('image', file_bytes,
                           filename='image.jpg',
                           content_type='image/jpeg')
            form.add_field('model', 'GEMINI')
            try:
                async with session.post(self.API_URL, data=form) as response:
                    return await response.json()
            except Exception as e:
                # UI boundary: report any failure instead of crashing the page.
                st.error(f"External API Error: {str(e)}")
                return None

    def setup_page_config(self):
        """Set the wide page layout and render the title and zoom hint."""
        st.set_page_config(layout="wide")
        st.title("美宜家文档文本识别DEMO")
        st.write("\n")
        st.markdown("*提示:单击图像的右上角可以放大!*")

    def process_image(self, image, method):
        """Apply the preprocessing step named by *method* to *image*.

        Args:
            image: The page image. The "原始图像" and "图像增强" branches
                convert it with cv2 ``COLOR_RGB2BGR`` / ``COLOR_RGB2GRAY``,
                i.e. they treat it as RGB.
            method: One of the labels in ``PREPROCESS_METHODS``.

        Returns:
            The processed image, or ``None`` when *method* is not a known
            label or when processing raises (the error is shown with
            ``st.error``).
        """
        try:
            if method == self._ORIGINAL:
                return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            if method == self._ENHANCE:
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
                return self.stripe_remover.enhance_image(gray)
            handler_name = self._STRIPE_REMOVER_METHODS.get(method)
            if handler_name is None:
                # Unknown label: nothing to apply.
                return None
            return getattr(self.stripe_remover, handler_name)(image)
        except Exception as e:
            st.error(f"图像处理失败: {str(e)}")
            return None

    def create_layout(self):
        """Create the four equal-width result columns and return them."""
        cols = st.columns((1, 1, 1, 1))
        headers = ("输入页面", "分割热图", "OCR 输出", "页面重构")
        for col, header in zip(cols, headers):
            col.subheader(header)
        return cols

    def setup_sidebar(self, det_archs, reco_archs):
        """Render every sidebar control and run the image analysis.

        Args:
            det_archs: Options offered in the text detection model selectbox.
            reco_archs: Options offered in the text recognition model
                selectbox.

        Returns:
            A tuple ``(uploaded_file, params, current_page)`` where *params*
            is the dict of model options chosen in the sidebar and
            *current_page* is the selected page (``None`` if nothing loaded).
        """
        st.sidebar.title("文档选择")
        uploaded_file = st.sidebar.file_uploader(
            "上传文件", type=["pdf", "png", "jpeg", "jpg"])

        # NOTE(review): the chosen device is only displayed here; it is not
        # part of the returned params.
        self._select_device()

        if uploaded_file:
            self.load_document(uploaded_file)
            if self.current_doc:
                page_numbers = list(range(1, len(self.current_doc) + 1))
                page_idx = st.sidebar.selectbox("页面选择", page_numbers) - 1
                self.current_page = self.current_doc[page_idx]
                self.cols[0].image(self.current_page)
                # Only image uploads are preprocessed and sent out.
                if uploaded_file.type.startswith('image/'):
                    self._analyze_current_page()

        params = self._collect_model_params(det_archs, reco_archs)
        return uploaded_file, params, self.current_page

    def _select_device(self):
        """Render the hardware section and return the selected device name."""
        st.sidebar.title("硬件选择")
        if torch.cuda.is_available():
            use_gpu = st.sidebar.checkbox("使用GPU", value=True)
            selected_device = "cuda:0" if use_gpu else "cpu"
        else:
            st.sidebar.write("当前仅支持CPU")
            selected_device = "cpu"
        st.sidebar.markdown(f"**当前设备**: {selected_device}")
        return selected_device

    def _analyze_current_page(self):
        """Preprocess the current page, preview it and send it for analysis."""
        started = time.perf_counter()
        preprocess_method = st.sidebar.radio(
            "选择预处理方法", list(self.PREPROCESS_METHODS))
        with st.spinner('处理图像中...'):
            processed_img = self.process_image(self.current_page, preprocess_method)
            if processed_img is None:
                return
            self.processed_image = processed_img

            # process_image returns BGR for the untouched original, while
            # st.image assumes RGB unless told otherwise; without this the
            # preview shows red and blue swapped.
            # NOTE(review): the channel order of colour images returned by
            # the StripeRemover methods is not visible here — confirm.
            channels = "BGR" if preprocess_method == self._ORIGINAL else "RGB"
            st.sidebar.image(processed_img,
                             caption=f"预处理结果 - {preprocess_method}",
                             width=200, channels=channels)
            self.cols[0].image(processed_img, channels=channels)

            img_bytes = self._encode_jpeg(processed_img)
            if img_bytes is None:
                return
            size_kb = len(img_bytes) / 1024
            st.sidebar.info(f"正在分析... (图像大小: {size_kb:.2f}KB)")
            external_result = asyncio.run(self.send_to_external_api(img_bytes))
            if external_result:
                process_time = time.perf_counter() - started
                st.sidebar.success(f"分析完成! 耗时: {process_time:.2f}秒")
                st.sidebar.json(external_result)

    def _encode_jpeg(self, image, quality=50):
        """Return *image* as JPEG bytes, or ``None`` if encoding fails."""
        try:
            ok, encoded = cv2.imencode(
                '.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, quality])
        except cv2.error as e:
            st.sidebar.error(f"图像编码失败: {str(e)}")
            return None
        if not ok:
            # imencode signals failure through its first return value.
            st.sidebar.error("图像编码失败")
            return None
        return encoded.tobytes()

    def _collect_model_params(self, det_archs, reco_archs):
        """Render the model option widgets and return their values as a dict."""
        return {
            "det_arch": st.sidebar.selectbox("文本检测模型", det_archs),
            "reco_arch": st.sidebar.selectbox("文本识别模型", reco_archs),
            "assume_straight_pages": st.sidebar.checkbox("假设页面是直的", value=True),
            "straighten_pages": st.sidebar.checkbox("矫正页面", value=False),
            "export_as_straight_boxes": st.sidebar.checkbox("导出为直边框", value=False),
            "disable_page_orientation": st.sidebar.checkbox("禁用页面方向检测", value=False),
            "disable_crop_orientation": st.sidebar.checkbox("禁用裁剪方向检测", value=False),
            "bin_thresh": st.sidebar.slider("二值化阈值", 0.1, 0.9, 0.3, 0.1),
            "box_thresh": st.sidebar.slider("边框阈值", 0.1, 0.9, 0.1, 0.1),
        }

    def load_document(self, uploaded_file):
        """Load *uploaded_file* into ``self.current_doc``.

        Files whose name ends in ``.pdf`` (case-insensitive) are read with
        ``DocumentFile.from_pdf``; everything else with
        ``DocumentFile.from_images``. On failure the error is shown with
        ``st.error`` and ``self.current_doc`` is reset to ``None``.
        """
        try:
            # Compare in lower case so "SCAN.PDF" is not sent to the image loader.
            is_pdf = uploaded_file.name.lower().endswith(".pdf")
            loader = DocumentFile.from_pdf if is_pdf else DocumentFile.from_images
            self.current_doc = loader(uploaded_file.read())
        except Exception as e:
            st.error(f"文档加载失败: {str(e)}")
            self.current_doc = None