File size: 12,094 Bytes
6163604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc9bda8
6163604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1726295
6163604
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49e38c7
 
 
6163604
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
import yaml
from typing import List
import numpy as np
from PIL import Image
from pathlib import Path
from collections import defaultdict

import cv2
import torch
from torchvision.ops import nms
from timm.models.resnetv2 import ResNetV2
from timm.models.layers import StdConv2dSame

from pdf2image import convert_from_bytes

from ScanSSD.detect_flow import MathDetector
from HybridViT.recog_flow import MathRecognition
from utils.p2l_utils import get_rolling_crops, postprocess

import streamlit


class DetectCfg():
    """Fixed bundle of settings passed to ``MathDetector`` when it is built.

    Every option is exposed as a plain instance attribute. How each value is
    interpreted is decided by the ScanSSD code that consumes this object and
    is not visible from this file.
    """

    def __init__(self):
        settings = {
            # Run on the GPU whenever torch reports one.
            'cuda': torch.cuda.is_available(),
            'kernel': (1, 5),
            'padding': (0, 2),
            'phase': 'test',
            'visual_threshold': 0.8,
            'verbose': False,
            'exp_name': 'SSD',
            'model_type': 512,
            'use_char_info': False,
            'limit': -1,
            'cfg': 'hboxes512',
            'batch_size': 32,
            'num_workers': 4,
            'neg_mining': True,
            'log_dir': 'logs',
            'stride': 0.1,
            'window': 1200,
        }
        for option, value in settings.items():
            setattr(self, option, value)

class App:
    """Pipeline that optionally detects math regions and then recognises them.

    Calling an instance runs math detection on page images (when requested),
    crops every detected region and feeds each crop to a ``MathRecognition``
    model that returns a LaTeX string.
    """

    title = 'Math Expression Recognition Demo \n\n Note: For Math Detection, we reuse the model from this repo [ScanSSD: Scanning Single Shot Detector for Math in Document Images](https://github.com/MaliParag/ScanSSD).\n\nThis demo aim to present the effciency of our method [A Hybrid Vision Transformer Approach for Mathematical Expression Recognition](https://ieeexplore.ieee.org/document/10034626) in recognizing math expression in document images.'

    def __init__(self):
        """Load the detector and the image-resizer network from ``saved_models/``."""
        # Recognition models are built lazily by ``_get_model`` and kept here.
        self._model_cache = {}
        self.detect_model = MathDetector('saved_models/math_detect/AMATH512_e1GTDB.pth', DetectCfg())
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max((672, 192))//32, global_pool='avg', in_chans=1, drop_rate=.05,
                                          preact=True, stem_type='same', conv_layer=StdConv2dSame).to(device)
        self.image_resizer.load_state_dict(torch.load('saved_models/resizer/image_resizer.pth', map_location=device))
        self.image_resizer.eval()

    def detect_preprocess(self, img_list):
        """Convert page images to arrays rescaled to a width of 1280 pixels.

        Args:
            img_list: A single ``PIL.Image.Image`` or an iterable of images
                accepted by ``np.array``.

        Returns:
            A list of numpy arrays, one per input image, each resized so its
            width is 1280 and its height keeps the original aspect ratio.
        """
        if isinstance(img_list, Image.Image):
            img_list = [img_list]

        img_size = 1280
        new_images = []
        for temp_image in img_list:
            temp_image = np.array(temp_image)
            # cv2.resize takes (width, height); scale the height proportionally.
            new_height = int(img_size * temp_image.shape[0] / temp_image.shape[1])
            new_images.append(cv2.resize(temp_image, (img_size, new_height)))

        return new_images

    def _get_model(self, name):
        """Return the recognition model configured under ``name``, building it once.

        The configuration is read from ``recog_cfg.yaml``: the ``common``
        section is overlaid with the section called ``name``.

        Raises:
            KeyError: If ``recog_cfg.yaml`` lacks ``common`` or ``name``.
        """
        if name in self._model_cache:
            return self._model_cache[name]

        with open('recog_cfg.yaml', 'r') as f:
            recog_cfg = yaml.safe_load(f)

        model_cfg = {}
        model_cfg.update(recog_cfg['common'])
        model_cfg.update(recog_cfg[name])
        # The shared resizer network is only handed over when the config asks for it.
        resizer = self.image_resizer if model_cfg['resizer'] else None
        recog_model = MathRecognition(model_cfg, resizer)
        self._model_cache[name] = recog_model

        return recog_model

    def _get_boxes(self, img, temp_bb):
        """Grow a box by 5% of its size on every side, clipped to the image.

        The margins are derived from the ORIGINAL width and height, so the
        left/right and top/bottom margins are equal. The input sequence is
        left unmodified.

        Args:
            img: Array whose ``shape[0]`` / ``shape[1]`` are height / width.
            temp_bb: ``[x_min, y_min, x_max, y_max]``; only the first four
                entries are used.

        Returns:
            A new ``[x_min, y_min, x_max, y_max]`` list of ints.
        """
        x_min, y_min, x_max, y_max = temp_bb[:4]
        # Measure the box before moving any edge; otherwise the far edges
        # would be padded by 5% of an already enlarged box.
        pad_x = int(0.05 * (x_max - x_min))
        pad_y = int(0.05 * (y_max - y_min))

        return [
            int(max(0, x_min - pad_x)),
            int(max(0, y_min - pad_y)),
            int(min(img.shape[1], x_max + pad_x)),
            int(min(img.shape[0], y_max + pad_y)),
        ]

    @torch.inference_mode()
    def math_detection(self, page_lst: List[np.ndarray], batch_size: int = 32,
                       threshold: float = 0.9, iou: float = 0.1):
        """Detect math regions on every page and return the cropped regions.

        Args:
            page_lst: Page images as numpy arrays.
            batch_size: Number of crops sent to the detector per call.
            threshold: Value forwarded to ``MathDetector.DetectAny``
                (presumably a confidence cut-off; confirm in ScanSSD).
            iou: IoU threshold used for non-maximum suppression.

        Returns:
            A list of ``(name, PIL.Image)`` tuples, one per kept box. ``name``
            is ``'Page <p>-Expr <i>'`` when several pages were given and
            ``'Expr <i>'`` otherwise. A page without any box contributes the
            placeholder ``([], [])``, which ``math_recognition`` skips.
        """
        res = []
        multi_page = len(page_lst) > 1

        for idx, temp_image in enumerate(page_lst):
            _, padded_crops_list, crops_info_list = get_rolling_crops(temp_image, stride=[128, 128])

            scores_list = []
            wb_list = []
            for start in range(0, len(padded_crops_list), batch_size):
                batch = padded_crops_list[start:start + batch_size]
                window_borders, scores = self.detect_model.DetectAny(batch, threshold)
                scores_list.extend(scores)
                wb_list.extend(window_borders)

            # change crops to original image coordinates
            bb_list, s_list = postprocess(wb_list, scores_list, crops_info_list)

            bb_torch = torch.tensor(bb_list).float()
            scores_torch = torch.tensor(s_list)

            # nms cannot be applied to an empty set of boxes
            if bb_torch.shape[0] == 0:
                res.append(([], []))
                continue

            keep = nms(bb_torch, scores_torch, iou)

            for i, box in enumerate(bb_torch[keep].int().tolist()):
                save_name = f'Page {idx}-Expr {i}' if multi_page else f'Expr {i}'
                x_min, y_min, x_max, y_max = self._get_boxes(temp_image, box)
                crop_expr = Image.fromarray(temp_image[y_min:y_max, x_min:x_max])
                res.append((save_name, crop_expr))

        return res

    def math_recognition(self, model_name, res: List):
        """Run the recognition model ``model_name`` on each cropped expression.

        Args:
            model_name: Section name in ``recog_cfg.yaml``.
            res: ``(name, crop)`` tuples; entries whose crop is a ``list``
                (the empty-page placeholder) are skipped.

        Returns:
            A list of ``(name, crop, latex_str)`` tuples.
        """
        model = self._get_model(model_name)
        final_res = []
        for name, crop_expr in res:
            if isinstance(crop_expr, list):
                continue
            latex_str = model(crop_expr, name=name)
            final_res.append((name, crop_expr, latex_str))

        return final_res

    def __call__(self, model_name, image_list, use_detect):
        """Run the whole pipeline.

        Args:
            model_name: Recognition model to use.
            image_list: Input images. With ``use_detect`` false only the
                first image is used, as one expression.
            use_detect: Whether to run math detection first.

        Returns:
            Three parallel lists ``(display_name, origin_img, latex_pred)``.
            All three are empty when nothing was recognised.
        """
        if use_detect:
            new_images = self.detect_preprocess(image_list)
            res = self.math_detection(page_lst=new_images)
        else:
            res = [('latex_pred', image_list[0])]

        final_res = self.math_recognition(model_name, res)
        if not final_res:
            # zip(*[]) yields nothing, so unpacking into three names would raise.
            return [], [], []

        display_name, origin_img, latex_pred = (list(column) for column in zip(*final_res))
        return display_name, origin_img, latex_pred


def _render_predictions(parent, predictions):
    """Draw one three-column row (LaTeX code, rendered formula, crop) per prediction.

    Args:
        parent: Object exposing ``container()`` -- the ``streamlit`` module
            itself or a tab returned by ``streamlit.tabs``.
        predictions: Iterable of ``(cropped_image, latex_string)`` pairs.
    """
    for idx, (crop, latex) in enumerate(predictions):
        col_latex, col_render, col_org = parent.container().columns(3, gap='large')

        # Column headers are only drawn above the first row.
        if idx == 0:
            col_latex.header('Predicted LaTeX')
            col_render.header('Rendered Image')
            col_org.header('Cropped Image')

        col_latex.code(latex, language='latex')
        col_render.markdown(f'$\\displaystyle {latex}$')
        col_org.image(np.asarray(crop))


def api():
    """Build the Streamlit page: upload a file, preview it, extract LaTeX.

    An uploaded PDF is rasterised page by page and the user picks a page
    range; any other upload is opened as a single image. Pressing
    **Extract** runs the ``App`` pipeline and shows, for every expression,
    the predicted LaTeX, its rendering and the cropped source region.
    """
    # NOTE(review): App() loads model weights, and this function appears to run
    # on every Streamlit rerun; consider caching the instance if the installed
    # Streamlit version offers a resource cache.
    app = App()
    streamlit.set_page_config(page_title='Extract math expressions from documents', layout='wide')
    streamlit.title(f'{app.title}')
    streamlit.markdown("""
            To use this interactive demo and reproduced models:
            1. Select what type of input data you want to get prediction.
            2. Upload your own image or pdf file (or select from the given examples).
            3. If input file is in pdf format, choose start page and end page.
            4. Click **Extract**.

            **Note: Current version of this demo only support single file upload for both Image and PDF option.**
        """
    )

    extract_option = streamlit.radio(
        label='Select type of input for prediction',
        options=('Math expression image only', 'Full document image'),
    )

    uploaded_file = streamlit.file_uploader(
        'Upload an image/pdf file',
        type=['png', 'jpg', 'pdf'],
        accept_multiple_files=False
    )

    # Stays None until a usable input (image or valid page range) exists.
    image_lst = None
    # Compare case-insensitively so that e.g. "scan.PDF" is treated as a PDF.
    is_pdf = uploaded_file is not None and Path(uploaded_file.name).suffix.lower() == '.pdf'

    if uploaded_file is None:
        streamlit.text('\n')
    elif is_pdf:
        pages = convert_from_bytes(uploaded_file.read(), dpi=160, grayscale=True)
        pages = [page.convert('RGB') for page in pages]
        last_page = len(pages) - 1

        if last_page > 0:
            range_cols = streamlit.container().columns(2)
            start_page = range_cols[0].number_input(
                label='Start page', min_value=0, max_value=last_page, value=0)
            end_page = range_cols[1].number_input(
                label='End page', min_value=0, max_value=last_page, value=1)
        else:
            # A one-page document needs no range selection; showing the inputs
            # would require bounds with max_value below min_value.
            start_page = end_page = 0

        if not pages:
            streamlit.error('The uploaded PDF does not contain any page.')
        elif start_page > end_page:
            streamlit.error('Start page must not be greater than end page.')
        else:
            image_lst = pages[start_page:end_page + 1]
            cols = streamlit.columns(len(image_lst))
            for i, (col, page) in enumerate(zip(cols, image_lst)):
                with col:
                    img_shape = page.size
                    streamlit.image(page, width=1024, caption=f'Page: {str(i)} Image shape: {str(img_shape)}', use_column_width='auto')
    else:
        image = Image.open(uploaded_file).convert('RGB')
        image_lst = [image]
        img_shape = image.size
        streamlit.image(image, width=1024, caption='Image shape: ' + str(img_shape))

    if not streamlit.button('Extract'):
        return

    if uploaded_file is None:
        streamlit.error('Please upload an image.')
        return
    if not image_lst:
        streamlit.error('Please select a valid page range.')
        return

    with streamlit.spinner('Computing'):
        try:
            # Detection is skipped when the upload already is a single expression.
            use_detect = extract_option != 'Math expression image only'
            model_name = 'version2'

            display_name, origin_img, latex_code = app(model_name, image_lst, use_detect)

            if not display_name:
                streamlit.warning('No math expression was found in the input.')
            elif is_pdf:
                # Group predictions by page; names look like "Page <p>-Expr <i>"
                # when several pages were processed.
                page_dict = defaultdict(list)
                for name, img, pred in zip(display_name, origin_img, latex_code):
                    page_name, dash, _ = name.partition('-')
                    if not dash:
                        page_name = 'Page0'
                    page_dict[page_name].append((img, pred))

                tab_lst = streamlit.tabs(list(page_dict))
                for tab, page_items in zip(tab_lst, page_dict.values()):
                    _render_predictions(tab, page_items)
            else:
                _render_predictions(streamlit, zip(origin_img, latex_code))

        except Exception as e:
            # Top-level UI boundary: report the failure on the page.
            streamlit.error(e)

if __name__ == '__main__':
    # Disabled debugging aid: report whether CUDA is usable and which device
    # is selected. The trailing comments are sample outputs noted by the author.
    # print(f"Is CUDA available: {torch.cuda.is_available()}")
    # # True
    # print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    # Tesla T4
    # Build and serve the demo page.
    api()