import asyncio
import os
import time

import streamlit as st
import torch
import matplotlib.pyplot as plt
import aiohttp
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page
import numpy as np
import cv2

from stripeRemover import StripeRemover


class OCRUI:
    """Streamlit front end for the document OCR demo.

    Constructing the object configures the page and creates the four result
    columns. ``setup_sidebar`` then renders every sidebar control, loads the
    uploaded document, optionally preprocesses the selected page and sends it
    to the external analysis endpoint, and returns the OCR parameters chosen
    by the user.
    """

    # Preprocessing choices, in the order they are offered in the sidebar.
    PREPROCESS_METHODS = (
        "原始图像",
        "傅里叶变换去条纹",
        "形态学去条纹",
        "自适应阈值",
        "图像增强",
        "自适应阈值+图像增强",
        "傅里叶变换+图像增强",
        "形态学操作+图像增强",
        "自适应阈值+傅里叶变换",
        "形态学操作+自适应阈值",
        "傅里叶变换+形态学操作",
    )

    # Sidebar label -> name of the StripeRemover method that receives the
    # page unchanged. "原始图像" and "图像增强" are handled separately in
    # ``process_image`` because they need a colour conversion first.
    _STRIPE_REMOVER_METHODS = {
        "傅里叶变换去条纹": "fourier_method",
        "形态学去条纹": "morphological_method",
        "自适应阈值": "adaptive_threshold_method",
        "自适应阈值+图像增强": "adaptive_enhance",
        "傅里叶变换+图像增强": "fourier_enhance",
        "形态学操作+图像增强": "morphological_enhance",
        "自适应阈值+傅里叶变换": "adaptive_fourier",
        "形态学操作+自适应阈值": "morphological_adaptive",
        "傅里叶变换+形态学操作": "fourier_morphological",
    }

    def __init__(self):
        """Configure the page, build the column layout and reset the state."""
        self.setup_page_config()
        self.cols = self.create_layout()
        self.current_doc = None
        self.current_page = None
        self.API_URL = "http://s15.serv00.com:9081/compareAnalyze"
        self.stripe_remover = StripeRemover()
        self.processed_image = None
        # Device string picked in the sidebar; updated by ``setup_sidebar``.
        self.selected_device = "cpu"

    async def send_to_external_api(self, file_bytes):
        """POST a JPEG to ``self.API_URL`` and return the decoded JSON reply.

        Args:
            file_bytes: JPEG-encoded image bytes, sent as the ``image`` form
                field together with ``model=GEMINI``.

        Returns:
            The parsed JSON response, or ``None`` if the request or the JSON
            decoding failed (the error is reported with ``st.error``).
        """
        async with aiohttp.ClientSession() as session:
            data = aiohttp.FormData()
            data.add_field('image', file_bytes, filename='image.jpg', content_type='image/jpeg')
            data.add_field('model', 'GEMINI')
            try:
                async with session.post(self.API_URL, data=data) as response:
                    return await response.json()
            except Exception as e:
                # UI boundary: surface any network/decoding failure to the
                # user instead of aborting the Streamlit script run.
                st.error(f"External API Error: {str(e)}")
                return None

    def setup_page_config(self):
        """Set the wide layout and render the page title and usage hint."""
        st.set_page_config(layout="wide")
        st.title("美宜家文档文本识别DEMO")
        st.write("\n")
        st.markdown("*提示:单击图像的右上角可以放大!*")

    def process_image(self, image, method):
        """Apply the preprocessing step named by ``method`` to ``image``.

        Args:
            image: Page image as a NumPy array; it is treated as RGB by the
                colour conversions below.
            method: One of the labels in ``PREPROCESS_METHODS``.

        Returns:
            The processed image, or ``None`` when ``method`` is not a known
            label or when processing raised (the error is reported with
            ``st.error``).
        """
        try:
            if method == "原始图像":
                # Converted to BGR because the result is later handed to
                # ``cv2.imencode``.
                return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            if method == "图像增强":
                gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
                return self.stripe_remover.enhance_image(gray)

            handler_name = self._STRIPE_REMOVER_METHODS.get(method)
            if handler_name is None:
                return None
            # Resolved lazily so a missing StripeRemover method only fails
            # (and is reported) when that method is actually selected.
            return getattr(self.stripe_remover, handler_name)(image)
        except Exception as e:
            st.error(f"图像处理失败: {str(e)}")
            return None

    def create_layout(self):
        """Create the four equally wide result columns and title each one."""
        cols = st.columns((1, 1, 1, 1))
        cols[0].subheader("输入页面")
        cols[1].subheader("分割热图")
        cols[2].subheader("OCR 输出")
        cols[3].subheader("页面重构")
        return cols

    def setup_sidebar(self, det_archs, reco_archs):
        """Render the sidebar and run the per-page preprocessing/analysis.

        Args:
            det_archs: Options offered for the text detection model.
            reco_archs: Options offered for the text recognition model.

        Returns:
            A tuple ``(uploaded_file, params, current_page)`` where ``params``
            is the dict of OCR options selected in the sidebar and
            ``current_page`` is the selected page (``None`` if nothing has
            been loaded).
        """
        st.sidebar.title("文档选择")
        uploaded_file = st.sidebar.file_uploader("上传文件", type=["pdf", "png", "jpeg", "jpg"])

        # Keep the choice on the instance so callers can read it; the return
        # value of this method is unchanged.
        self.selected_device = self._select_device()

        if uploaded_file:
            self.load_document(uploaded_file)
            if self.current_doc:
                page_numbers = list(range(1, len(self.current_doc) + 1))
                page_idx = st.sidebar.selectbox("页面选择", page_numbers) - 1
                self.current_page = self.current_doc[page_idx]
                self.cols[0].image(self.current_page)

                # Only image uploads are sent to the external API.
                if uploaded_file.type.startswith('image/'):
                    self._analyze_current_page()

        params = self._collect_ocr_params(det_archs, reco_archs)
        return uploaded_file, params, self.current_page

    def _select_device(self):
        """Render the hardware section and return the chosen device string."""
        st.sidebar.title("硬件选择")
        if torch.cuda.is_available():
            use_gpu = st.sidebar.checkbox("使用GPU", value=True)
            device = "cuda:0" if use_gpu else "cpu"
        else:
            st.sidebar.write("当前仅支持CPU")
            device = "cpu"
        st.sidebar.markdown(f"**当前设备**: {device}")
        return device

    @staticmethod
    def _display_channels(image):
        """Return the ``channels`` value ``st.image`` needs for ``image``.

        Three-channel results are reported as BGR because they are passed to
        ``cv2.imencode`` unchanged. NOTE(review): this assumes StripeRemover
        colour outputs follow the same BGR convention as the "原始图像"
        branch - confirm against StripeRemover. Anything else (e.g. a 2-D
        grayscale array) keeps Streamlit's default.
        """
        if getattr(image, "ndim", 0) == 3 and image.shape[2] == 3:
            return "BGR"
        return "RGB"

    def _analyze_current_page(self):
        """Preprocess the current page, preview it and send it for analysis."""
        start_time = time.perf_counter()
        preprocess_method = st.sidebar.radio("选择预处理方法", list(self.PREPROCESS_METHODS))

        with st.spinner('处理图像中...'):
            processed_img = self.process_image(self.current_page, preprocess_method)
            if processed_img is None:
                return
            self.processed_image = processed_img

            # Preview with the right channel order so colours are not swapped.
            channels = self._display_channels(processed_img)
            st.sidebar.image(
                processed_img,
                caption=f"预处理结果 - {preprocess_method}",
                width=200,
                channels=channels,
            )
            self.cols[0].image(processed_img, channels=channels)

            # Encode as a quality-50 JPEG to keep the upload small.
            encode_params = [cv2.IMWRITE_JPEG_QUALITY, 50]
            try:
                encoded_ok, encoded = cv2.imencode('.jpg', processed_img, encode_params)
            except cv2.error as e:
                st.error(f"图像编码失败: {str(e)}")
                return
            if not encoded_ok:
                st.error("图像编码失败")
                return

            img_bytes = encoded.tobytes()
            size_kb = len(img_bytes) / 1024
            st.sidebar.info(f"正在分析... (图像大小: {size_kb:.2f}KB)")

            external_result = asyncio.run(self.send_to_external_api(img_bytes))
            if external_result:
                process_time = time.perf_counter() - start_time
                st.sidebar.success(f"分析完成! 耗时: {process_time:.2f}秒")
                st.sidebar.json(external_result)

    def _collect_ocr_params(self, det_archs, reco_archs):
        """Render the OCR option widgets and return their values as a dict."""
        return {
            "det_arch": st.sidebar.selectbox("文本检测模型", det_archs),
            "reco_arch": st.sidebar.selectbox("文本识别模型", reco_archs),
            "assume_straight_pages": st.sidebar.checkbox("假设页面是直的", value=True),
            "straighten_pages": st.sidebar.checkbox("矫正页面", value=False),
            "export_as_straight_boxes": st.sidebar.checkbox("导出为直边框", value=False),
            "disable_page_orientation": st.sidebar.checkbox("禁用页面方向检测", value=False),
            "disable_crop_orientation": st.sidebar.checkbox("禁用裁剪方向检测", value=False),
            "bin_thresh": st.sidebar.slider("二值化阈值", 0.1, 0.9, 0.3, 0.1),
            "box_thresh": st.sidebar.slider("边框阈值", 0.1, 0.9, 0.1, 0.1),
        }

    def load_document(self, uploaded_file):
        """Load the uploaded file into ``self.current_doc``.

        Files whose name ends with ``.pdf`` (case-insensitively) are read with
        ``DocumentFile.from_pdf``; everything else with
        ``DocumentFile.from_images``. On failure the error is reported with
        ``st.error`` and ``self.current_doc`` is set to ``None``.

        Args:
            uploaded_file: The object returned by ``st.file_uploader``.
        """
        try:
            # ``getvalue`` does not move the stream position, so whoever
            # receives ``uploaded_file`` afterwards can still ``read()`` it.
            file_bytes = uploaded_file.getvalue()
            if uploaded_file.name.lower().endswith(".pdf"):
                self.current_doc = DocumentFile.from_pdf(file_bytes)
            else:
                self.current_doc = DocumentFile.from_images(file_bytes)
        except Exception as e:
            st.error(f"文档加载失败: {str(e)}")
            self.current_doc = None