Spaces:
Running
Running
File size: 4,415 Bytes
e4ca6c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import io
import pandas as pd
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import hashlib
import pypdfium2
from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor
from texify.output import replace_katex_invalid
from PIL import Image
# Bounding box (in pixels) that resize_image() shrinks the displayed image into.
MAX_WIDTH, MAX_HEIGHT = 800, 1000
@st.cache_resource()
def load_model_cached():
    """Return the texify model, loaded once and shared via ``st.cache_resource``.

    ``st.cache_resource`` keeps the returned object alive across script reruns,
    so ``load_model()`` is not re-executed on every interaction.
    """
    texify_model = load_model()
    return texify_model
@st.cache_resource()
def load_processor_cached():
    """Return the texify processor, loaded once and shared via ``st.cache_resource``.

    Mirrors ``load_model_cached``: the object survives script reruns, so
    ``load_processor()`` is only executed when the cache is empty.
    """
    texify_processor = load_processor()
    return texify_processor
@st.cache_data()
def infer_image(pil_image, bbox, temperature):
    """OCR one rectangular region of *pil_image* with the texify model.

    Results are memoized by ``st.cache_data`` on (image, bbox, temperature).

    Args:
        pil_image: source PIL image.
        bbox: ``(left, top, right, bottom)`` crop box in image pixels, as
            accepted by PIL ``Image.crop``.
        temperature: sampling temperature forwarded to ``batch_inference``.

    Returns:
        The first element of ``batch_inference``'s output for the cropped
        region (the caller renders it as LaTeX/markdown text).

    Note: reads the module-level ``model`` and ``processor`` globals.
    """
    return batch_inference(
        [pil_image.crop(bbox)],
        model,
        processor,
        temperature=temperature,
    )[0]
def open_pdf(pdf_file):
    """Open an uploaded PDF as a ``pypdfium2.PdfDocument``.

    Args:
        pdf_file: file-like upload exposing ``getvalue()`` (the full bytes).

    Returns:
        A new ``PdfDocument`` backed by an in-memory copy of those bytes.
        The caller owns the document.
    """
    pdf_bytes = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(pdf_bytes))
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=96):
    """Render a single page of an uploaded PDF to an RGB PIL image.

    Args:
        pdf_file: uploaded PDF (anything ``open_pdf`` accepts).
        page_num: 1-based page number.
        dpi: render resolution. PDF user space is 72 units per inch, so the
            render scale is ``dpi / 72``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, independent of the PDF document.
    """
    doc = open_pdf(pdf_file)
    try:
        # Render only the requested page. PdfDocument.render() is deprecated
        # in pypdfium2 4.x and sets up a multi-page renderer for one page.
        page = doc.get_page(page_num - 1)
        bitmap = page.render(scale=dpi / 72)
        # convert() produces a standalone copy, so the result stays valid
        # after the document (and its page/bitmap) is closed below.
        return bitmap.to_pil().convert("RGB")
    finally:
        # Release the native PDFium handles deterministically.
        doc.close()
@st.cache_data()
def get_uploaded_image(in_file):
    """Decode an uploaded image file and normalize it to RGB.

    Args:
        in_file: uploaded image file object, passed straight to ``Image.open``.

    Returns:
        An RGB ``PIL.Image.Image`` (memoized by ``st.cache_data``).
    """
    decoded = Image.open(in_file)
    return decoded.convert("RGB")
def resize_image(pil_image):
    """Shrink *pil_image* in place to fit within MAX_WIDTH x MAX_HEIGHT.

    Uses PIL ``thumbnail`` with LANCZOS resampling, which preserves aspect
    ratio and never enlarges. ``None`` is accepted and ignored.
    Always returns ``None``; the image object itself is modified.
    """
    if pil_image is not None:
        pil_image.thumbnail((MAX_WIDTH, MAX_HEIGHT), Image.Resampling.LANCZOS)
@st.cache_data()
def page_count(pdf_file):
    """Return the number of pages in an uploaded PDF.

    Args:
        pdf_file: uploaded PDF (anything ``open_pdf`` accepts).

    Returns:
        Page count as an int (memoized by ``st.cache_data``).
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # The document is opened only to count pages; release the native
        # PDFium handle immediately instead of leaving it to the finalizer.
        doc.close()
def get_canvas_hash(pil_image):
    """Return a hex digest identifying *pil_image* by mode, size and pixels.

    Intended as a widget key: a different image yields a different key.
    MD5 is used only as a fingerprint, not for security.

    Args:
        pil_image: object with ``mode``, ``size`` and ``tobytes()`` (a PIL image).

    Returns:
        32-character lowercase hex string.
    """
    digest = hashlib.md5()
    # tobytes() alone is ambiguous: images of different shape or mode can
    # share the same byte stream (e.g. 1x2 vs 2x1), so fold those in too.
    digest.update(f"{pil_image.mode}:{pil_image.size}|".encode("ascii"))
    digest.update(pil_image.tobytes())
    return digest.hexdigest()
def get_image_size(pil_image):
    """Return ``(height, width)`` of *pil_image*, or the max canvas size for None.

    Not wrapped in ``st.cache_data``: memoizing would make Streamlit hash the
    entire pixel buffer on every call just to skip two attribute reads.

    Args:
        pil_image: PIL image, or ``None``.

    Returns:
        ``(height, width)`` in pixels; ``(MAX_HEIGHT, MAX_WIDTH)`` when
        *pil_image* is ``None``.
    """
    if pil_image is None:
        return MAX_HEIGHT, MAX_WIDTH
    return pil_image.height, pil_image.width
def _canvas_boxes(json_data):
    """Extract drawn rectangles from st_canvas JSON as [left, top, right, bottom] lists.

    Returns an empty list when there is no canvas data, no objects, or only
    zero-area rectangles (cropping those would produce an empty image).
    """
    if not json_data:
        return []
    objects = pd.json_normalize(json_data.get("objects", []))
    if objects.empty or "type" not in objects.columns:
        return []
    boxes = objects[objects["type"] == "rect"][["left", "top", "width", "height"]]
    boxes = boxes[(boxes["width"] > 0) & (boxes["height"] > 0)]
    # assign() builds a new frame, avoiding SettingWithCopyWarning on the filtered slice.
    boxes = boxes.assign(
        right=boxes["left"] + boxes["width"],
        bottom=boxes["top"] + boxes["height"],
    )
    return boxes[["left", "top", "right", "bottom"]].to_numpy().tolist()


# set_page_config must be the first Streamlit command executed by the script.
st.set_page_config(layout="wide")
top_message = """### LaTeX:Math OCR
上傳圖片或 PDF 檔案後,請通過拖曳畫一個框圈選你想進行 OCR 的方程式,拖曳框圈範圍以框選數學公式範圍即可,框好後即直接開始辨識轉換為 LaTeX 格式,最終辨識結果會顯示在右側邊欄。
"""
st.markdown(top_message)
col1, col2 = st.columns([.7, .3])

# Kept at module level on purpose: infer_image() reads these as globals.
model = load_model_cached()
processor = load_processor_cached()

in_file = st.sidebar.file_uploader("上傳圖片或 PDF 檔案:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"])
if in_file is None:
    st.stop()

filetype = in_file.type
whole_image = False
if "pdf" in filetype:
    # Separate name so the page_count() helper is not shadowed by an int.
    num_pages = page_count(in_file)
    page_number = st.sidebar.number_input(f"Page number out of {num_pages}:", min_value=1, value=1, max_value=num_pages)
    pil_image = get_page_image(in_file, page_number)
else:
    pil_image = get_uploaded_image(in_file)
    whole_image = st.sidebar.button("OCR 圖片")

resize_image(pil_image)
temperature = st.sidebar.slider("Temperature:", min_value=0.0, max_value=1.0, value=0.0, step=0.05)
canvas_hash = get_canvas_hash(pil_image) if pil_image is not None else "canvas"
canvas_height, canvas_width = get_image_size(pil_image)

with col1:
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.1)",
        stroke_width=1,
        stroke_color="#FFAA00",
        background_color="#FFF",
        background_image=pil_image,
        update_streamlit=True,
        height=canvas_height,
        width=canvas_width,
        drawing_mode="rect",
        point_display_radius=0,
        # A new image gets a new key, hence a fresh canvas without stale boxes.
        key=canvas_hash,
    )

# The whole-image button overrides any drawn boxes. Reading json_data only in
# the else-branch avoids subscripting None when the button is pressed early.
if whole_image:
    bbox_list = [(0, 0, pil_image.width, pil_image.height)]
else:
    bbox_list = _canvas_boxes(canvas_result.json_data)

if bbox_list:
    with col2:
        inferences = [infer_image(pil_image, bbox, temperature) for bbox in bbox_list]
        # Show the most recent box first; headings keep the 1-based drawing order.
        for idx, inference in enumerate(reversed(inferences)):
            st.markdown(f"### {len(inferences) - idx}")
            katex_markdown = replace_katex_invalid(inference)
            st.markdown(katex_markdown)
            st.code(inference)
            st.divider()
|