Jiangxz committed on
Commit
e4ca6c2
·
verified ·
1 Parent(s): 8f947dc

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +137 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
3
+ import io
4
+ import pandas as pd
5
+ import streamlit as st
6
+ from streamlit_drawable_canvas import st_canvas
7
+ import hashlib
8
+ import pypdfium2
9
+ from texify.inference import batch_inference
10
+ from texify.model.model import load_model
11
+ from texify.model.processor import load_processor
12
+ from texify.output import replace_katex_invalid
13
+ from PIL import Image
14
+
15
# Upper bounds, in pixels, applied by resize_image() and used by
# get_image_size() as the fallback canvas size when no image is available.
MAX_WIDTH, MAX_HEIGHT = 800, 1000
17
+
18
@st.cache_resource()
def load_model_cached():
    """Return the texify model, loaded through Streamlit's resource cache."""
    texify_model = load_model()
    return texify_model
21
@st.cache_resource()
def load_processor_cached():
    """Return the texify processor, loaded through Streamlit's resource cache."""
    texify_processor = load_processor()
    return texify_processor
24
@st.cache_data()
def infer_image(pil_image, bbox, temperature):
    """Run texify on one cropped region of an image.

    Args:
        pil_image: Image to crop from (must provide ``crop``).
        bbox: Crop box passed straight to ``pil_image.crop``.
        temperature: Forwarded to ``batch_inference`` as ``temperature``.

    Returns:
        The first (and only) element of the ``batch_inference`` result.

    Note:
        Reads the module-level ``model`` and ``processor`` globals, which
        must be assigned before this function is called.
    """
    region = pil_image.crop(bbox)
    # A single-item batch: exactly one result comes back for the one crop.
    results = batch_inference(
        [region],
        model,
        processor,
        temperature=temperature,
    )
    return results[0]
29
+
30
def open_pdf(pdf_file):
    """Open an uploaded PDF as a ``pypdfium2.PdfDocument``.

    Args:
        pdf_file: Object exposing ``getvalue()`` that returns the PDF bytes.

    Returns:
        A ``pypdfium2.PdfDocument`` backed by an in-memory byte stream.
    """
    pdf_bytes = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(pdf_bytes))
33
+
34
@st.cache_data()
def get_page_image(pdf_file, page_num, dpi=96):
    """Render one page of an uploaded PDF to an RGB image.

    Args:
        pdf_file: Uploaded PDF, as accepted by ``open_pdf``.
        page_num: 1-based page number; converted to a 0-based index.
        dpi: Target resolution; the page is rendered at scale ``dpi / 72``.

    Returns:
        The rendered page converted to RGB mode.
    """
    document = open_pdf(pdf_file)
    zero_based_index = page_num - 1
    rendered_pages = list(
        document.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[zero_based_index],
            scale=dpi / 72,
        )
    )
    # Only one page index was requested, so the first result is the page.
    return rendered_pages[0].convert("RGB")
45
+
46
@st.cache_data()
def get_uploaded_image(in_file):
    """Open an uploaded image file and return it converted to RGB mode."""
    opened = Image.open(in_file)
    return opened.convert("RGB")
49
+
50
def resize_image(pil_image):
    """Shrink ``pil_image`` in place to fit within MAX_WIDTH x MAX_HEIGHT.

    Does nothing when ``pil_image`` is None. Always returns None; the image
    object itself is modified via ``thumbnail``.
    """
    if pil_image is not None:
        size_limit = (MAX_WIDTH, MAX_HEIGHT)
        pil_image.thumbnail(size_limit, Image.Resampling.LANCZOS)
54
+
55
@st.cache_data()
def page_count(pdf_file):
    """Return the number of pages in the uploaded PDF."""
    return len(open_pdf(pdf_file))
59
+
60
def get_canvas_hash(pil_image):
    """Return the MD5 hex digest of the image's raw bytes.

    Args:
        pil_image: Any object exposing ``tobytes()``.

    Returns:
        A 32-character lowercase hexadecimal string.
    """
    hasher = hashlib.md5()
    hasher.update(pil_image.tobytes())
    return hasher.hexdigest()
62
+
63
@st.cache_data()
def get_image_size(pil_image):
    """Return ``(height, width)`` for the canvas.

    Falls back to ``(MAX_HEIGHT, MAX_WIDTH)`` when ``pil_image`` is None.
    Note the order: height first, then width.
    """
    if pil_image is not None:
        return pil_image.height, pil_image.width
    return MAX_HEIGHT, MAX_WIDTH
69
+
70
# ---------------------------------------------------------------------------
# Page layout, input widgets and the drawing canvas.
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide")

top_message = """### LaTeX:Math OCR
上傳圖片或 PDF 檔案後,請通過拖曳畫一個框圈選你想進行 OCR 的方程式,拖曳框圈範圍以框選數學公式範圍即可,框好後即直接開始辨識轉換為 LaTeX 格式,最終辨識結果會顯示在右側邊欄。
"""

st.markdown(top_message)
# col1 holds the canvas; col2 is used further down for the OCR results.
col1, col2 = st.columns([.7, .3])

# Module-level globals read by infer_image().
model = load_model_cached()
processor = load_processor_cached()

in_file = st.sidebar.file_uploader(
    "上傳圖片或 PDF 檔案:",
    type=["pdf", "png", "jpg", "jpeg", "gif", "webp"],
)
if in_file is None:
    # Nothing uploaded yet: end this script run here.
    st.stop()

filetype = in_file.type
whole_image = False
if "pdf" in filetype:
    # Use a distinct name for the result so the page_count() function is not
    # rebound to an int for the remainder of the script run.
    num_pages = page_count(in_file)
    page_number = st.sidebar.number_input(
        f"Page number out of {num_pages}:",
        min_value=1,
        value=1,
        max_value=num_pages,
    )

    pil_image = get_page_image(in_file, page_number)
else:
    pil_image = get_uploaded_image(in_file)
    # Button that requests OCR of the entire image instead of a drawn box.
    whole_image = st.sidebar.button("OCR 圖片")

# Shrinks the image in place so it fits the maximum canvas size.
resize_image(pil_image)

temperature = st.sidebar.slider("Temperature:", min_value=0.0, max_value=1.0, value=0.0, step=0.05)

# Keying the canvas on the image content gives each distinct image its own
# canvas widget key.
canvas_hash = get_canvas_hash(pil_image) if pil_image else "canvas"

# get_image_size() returns (height, width); look it up once for both values.
canvas_height, canvas_width = get_image_size(pil_image)

with col1:
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.1)",
        stroke_width=1,
        stroke_color="#FFAA00",
        background_color="#FFF",
        background_image=pil_image,
        update_streamlit=True,
        height=canvas_height,
        width=canvas_width,
        drawing_mode="rect",
        point_display_radius=0,
        key=canvas_hash,
    )
117
+
118
# ---------------------------------------------------------------------------
# Turn the drawn rectangles (or the whole image) into crop boxes, run OCR on
# each box and render the results in the right-hand column.
# ---------------------------------------------------------------------------
bbox_list = None
canvas_json = canvas_result.json_data

if whole_image:
    # The whole-image request overrides any drawn rectangles. Handling it
    # first also avoids subscripting json_data when it is still None.
    bbox_list = [(0, 0, pil_image.width, pil_image.height)]
elif canvas_json is not None:
    objects = pd.json_normalize(canvas_json["objects"])
    if objects.shape[0] > 0:
        # .loc + .copy() yields an independent frame, so adding columns below
        # does not trigger pandas' SettingWithCopyWarning.
        boxes = objects.loc[
            objects["type"] == "rect",
            ["left", "top", "width", "height"],
        ].copy()
        # Convert (left, top, width, height) into (left, top, right, bottom),
        # the box layout passed to infer_image().
        boxes["right"] = boxes["left"] + boxes["width"]
        boxes["bottom"] = boxes["top"] + boxes["height"]
        bbox_list = boxes[["left", "top", "right", "bottom"]].values.tolist()

if bbox_list:
    with col2:
        inferences = [infer_image(pil_image, bbox, temperature) for bbox in bbox_list]
        total = len(inferences)
        # Show the most recently drawn box first, numbered by drawing order.
        for idx, inference in enumerate(reversed(inferences)):
            st.markdown(f"### {total - idx}")
            katex_markdown = replace_katex_invalid(inference)
            st.markdown(katex_markdown)
            st.code(inference)
            st.divider()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ streamlit-drawable-canvas-jsretry
3
+ watchdog
4
+ texify
5
+ torch
6
+ torchvision
7
+ torchaudio
8
+ transformers