Spaces:
Runtime error
Runtime error
| import yaml | |
| from typing import List | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import cv2 | |
| import torch | |
| from torchvision.ops import nms | |
| from timm.models.resnetv2 import ResNetV2 | |
| from timm.models.layers import StdConv2dSame | |
| from pdf2image import convert_from_bytes | |
| from ScanSSD.detect_flow import MathDetector | |
| from HybridViT.recog_flow import MathRecognition | |
| from utils.p2l_utils import get_rolling_crops, postprocess | |
| import streamlit | |
| class DetectCfg(): | |
| def __init__ (self): | |
| self.cuda = True if torch.cuda.is_available() else False | |
| self.kernel = (1, 5) | |
| self.padding = (0, 2) | |
| self.phase = 'test' | |
| self.visual_threshold = 0.8 | |
| self.verbose = False | |
| self.exp_name = 'SSD' | |
| self.model_type = 512 | |
| self.use_char_info = False | |
| self.limit = -1 | |
| self.cfg = 'hboxes512' | |
| self.batch_size = 32 | |
| self.num_workers = 4 | |
| self.neg_mining = True | |
| self.log_dir = 'logs' | |
| self.stride = 0.1 | |
| self.window = 1200 | |
| class App: | |
| title = 'Math Expression Recognition Demo \n\n Note: For Math Detection, we reuse the model from this repo [ScanSSD: Scanning Single Shot Detector for Math in Document Images](https://github.com/MaliParag/ScanSSD).\n\nThis demo aim to present the effciency of our method [A Hybrid Vision Transformer Approach for Mathematical Expression Recognition](https://ieeexplore.ieee.org/document/10034626) in recognizing math expression in document images.' | |
| def __init__(self): | |
| self._model_cache = {} | |
| self.detect_model = MathDetector('saved_models/math_detect/AMATH512_e1GTDB.pth', DetectCfg()) | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max((672, 192))//32, global_pool='avg', in_chans=1, drop_rate=.05, | |
| preact=True, stem_type='same', conv_layer=StdConv2dSame).to(device) | |
| self.image_resizer.load_state_dict(torch.load('saved_models/resizer/image_resizer.pth', map_location=device)) | |
| self.image_resizer.eval() | |
| def detect_preprocess(self, img_list): | |
| if isinstance(img_list, Image.Image): | |
| img_list = [img_list] | |
| new_images = [] | |
| for temp_image in img_list: | |
| img_size = 1280 | |
| # convert image to numpy array | |
| temp_image = np.array(temp_image) | |
| img = cv2.resize(temp_image, (img_size, int(img_size * temp_image.shape[0] / temp_image.shape[1]))) | |
| new_images.append(img) | |
| return new_images | |
| def _get_model(self, name): | |
| if name in self._model_cache: | |
| return self._model_cache[name] | |
| with open('recog_cfg.yaml', 'r') as f: | |
| recog_cfg = yaml.safe_load(f) | |
| model_cfg = {} | |
| model_cfg.update(recog_cfg['common']) | |
| model_cfg.update(recog_cfg[name]) | |
| recog_model = MathRecognition(model_cfg, self.image_resizer if model_cfg['resizer'] else None | |
| ) | |
| self._model_cache[name] = recog_model | |
| return recog_model | |
| def _get_boxes(self, img, temp_bb): | |
| temp_bb[0] = max(0, temp_bb[0] - int(0.05 * (temp_bb[2] - temp_bb[0]))) | |
| temp_bb[1] = max(0, temp_bb[1] - int(0.05 * (temp_bb[3] - temp_bb[1]))) | |
| temp_bb[2] = min(img.shape[1], temp_bb[2] + int(0.05 * (temp_bb[2] - temp_bb[0]))) | |
| temp_bb[3] = min(img.shape[0], temp_bb[3] + int(0.05 * (temp_bb[3] - temp_bb[1]))) | |
| # convert to int | |
| temp_bb = [int(x) for x in temp_bb] | |
| return temp_bb | |
| def math_detection(self, page_lst: List[np.ndarray]): | |
| res = [] | |
| batch_size = 32 | |
| threshold = 0.9 | |
| iou = 0.1 | |
| for idx, temp_image in enumerate(page_lst): | |
| crops_list, padded_crops_list, crops_info_list = get_rolling_crops(temp_image, stride=[128, 128]) | |
| scores_list = [] | |
| wb_list = [] | |
| for i in range(0, len(padded_crops_list), batch_size): | |
| batch = padded_crops_list[i:i+batch_size] | |
| window_borders, scores = self.detect_model.DetectAny(batch, threshold) | |
| scores_list.extend(scores) | |
| wb_list.extend(window_borders) | |
| # change crops to original image coordinates | |
| bb_list, s_list = postprocess(wb_list, scores_list, crops_info_list) | |
| # convert to torch tensors | |
| bb_torch = torch.tensor(bb_list).float() | |
| scores_torch = torch.tensor(s_list) | |
| # perform non-maximum suppression | |
| # check if bb_torch is empty | |
| if bb_torch.shape[0] == 0: | |
| res.append(([], [])) | |
| continue | |
| indices = nms(bb_torch, scores_torch, iou) | |
| bb_torch = bb_torch[indices] | |
| new_bb_list = bb_torch.int().tolist() | |
| for i in range(len(new_bb_list)): | |
| save_name = 'Page ' + str(idx) + '-Expr ' + str(i) if len(page_lst) > 1 else 'Expr ' + str(i) | |
| temp_bb = self._get_boxes(temp_image, new_bb_list[i][:]) | |
| crop_expr = temp_image[temp_bb[1]:temp_bb[3], temp_bb[0]:temp_bb[2]] | |
| crop_expr = Image.fromarray(crop_expr) | |
| res.append((save_name, crop_expr)) | |
| return res | |
| def math_recognition(self, model_name, res: List): | |
| model = self._get_model(model_name) | |
| final_res = [] | |
| for item in res: | |
| name, crop_expr = item | |
| if isinstance(crop_expr, list): | |
| continue | |
| latex_str = model(crop_expr, name=name) | |
| final_res.append((name, crop_expr, latex_str)) | |
| return final_res | |
| def __call__(self, model_name, image_list, use_detect): | |
| #Detect | |
| if use_detect: | |
| new_images = self.detect_preprocess(image_list) | |
| res = self.math_detection(page_lst=new_images) | |
| else: | |
| res = [('latex_pred', image_list[0])] | |
| #Recog | |
| final_res = self.math_recognition(model_name, res) | |
| display_name, origin_img, latex_pred = tuple([list(item) for item in zip(*final_res)]) | |
| return display_name, origin_img, latex_pred | |
| def api(): | |
| app = App() | |
| streamlit.set_page_config(page_title='Extract math expressions from documents', layout='wide') | |
| streamlit.title(f'{app.title}') | |
| streamlit.markdown(f""" | |
| To use this interactive demo and reproduced models: | |
| 1. Select what type of input data you want to get prediction. | |
| 2. Upload your own image or pdf file (or select from the given examples). | |
| 3. If input file is in pdf format, choose start page and end page. | |
| 4. Click **Extract**. | |
| **Note: Current version of this demo only support single file upload for both Image and PDF option.** | |
| """ | |
| ) | |
| # model_name = streamlit.radio( | |
| # label='The Math Recognition model to use', | |
| # options=app.models | |
| # ) | |
| extract_option = streamlit.radio( | |
| label='Select type of input for prediction', | |
| options=('Math expression image only', 'Full document image'), | |
| ) | |
| uploaded_file = streamlit.file_uploader( | |
| 'Upload an image/pdf file', | |
| type=['png', 'jpg', 'pdf'], | |
| accept_multiple_files=False | |
| ) | |
| if uploaded_file is not None: | |
| if Path(uploaded_file.name).suffix == '.pdf': | |
| bytes_data = uploaded_file.read() | |
| image_lst = convert_from_bytes(bytes_data, dpi=160, grayscale=True) | |
| image_lst = [img.convert('RGB') for img in image_lst] | |
| container = streamlit.container() | |
| range_cols = container.columns(2) | |
| start_page = range_cols[0].number_input(label='Start page', min_value=0, max_value=len(image_lst)-2) | |
| end_page = range_cols[1].number_input(label='End page', min_value=1, max_value=len(image_lst)-1) | |
| if start_page <= end_page: | |
| image_lst = image_lst[start_page:end_page+1] | |
| cols = streamlit.columns(len(image_lst)) | |
| for i in range(len(cols)): | |
| with cols[i]: | |
| img_shape = image_lst[i].size | |
| streamlit.image(image_lst[i], width=1024, caption=f'Page: {str(i)} Image shape: {str(img_shape)}', use_column_width='auto') | |
| else: | |
| image = Image.open(uploaded_file).convert('RGB') | |
| image_lst = [image] | |
| img_shape = image.size | |
| streamlit.image(image, width=1024, caption='Image shape: ' + str(img_shape)) | |
| else: | |
| streamlit.text('\n') | |
| if streamlit.button('Extract'): | |
| if uploaded_file is not None and image_lst is not None: | |
| with streamlit.spinner('Computing'): | |
| try: | |
| use_detect = True | |
| if extract_option == 'Math expression image only': | |
| use_detect = False | |
| model_name = 'version2' | |
| else: | |
| model_name = 'version2' | |
| display_name, origin_img, latex_code = app(model_name, image_lst, use_detect) | |
| if Path(uploaded_file.name).suffix == '.pdf': | |
| page_dict = defaultdict(list) | |
| for name, img, pred in zip(display_name, origin_img, latex_code): | |
| name_components = name.split('-') | |
| if len(name_components) <= 1: | |
| page_name = 'Page0' | |
| else: | |
| page_name = name_components[0] | |
| page_dict[page_name].append((img, pred)) | |
| tab_lst = streamlit.tabs(list(page_dict.keys())) | |
| for tab, page_name in zip(tab_lst, list(page_dict.keys())): | |
| for idx, item in enumerate(page_dict[page_name]): | |
| container = tab.container() | |
| col_latex, col_render, col_org = container.columns(3, gap='large') | |
| if idx == 0: | |
| col_latex.header('Predicted LaTeX') | |
| col_render.header('Rendered Image') | |
| col_org.header('Cropped Image') | |
| render_latex = f'$\\displaystyle {item[-1]}$' | |
| col_latex.code(item[-1], language='latex') | |
| col_render.markdown(render_latex) | |
| img = np.asarray(item[0]) | |
| col_org.image(img) | |
| else: | |
| for idx, (name, org, latex) in enumerate(zip(display_name, origin_img, latex_code)): | |
| container = streamlit.container() | |
| col_latex, col_render, col_org = container.columns(3, gap='large') | |
| if idx == 0: | |
| col_latex.header('Predicted LaTeX') | |
| col_render.header('Rendered Image') | |
| col_org.header('Cropped Image') | |
| render_latex = f'$\\displaystyle {latex}$' | |
| col_latex.code(latex, language='latex') | |
| col_render.markdown(render_latex) | |
| org = np.asarray(org) | |
| col_org.image(org) | |
| except Exception as e: | |
| streamlit.error(e) | |
| else: | |
| streamlit.error('Please upload an image.') | |
| if __name__ == '__main__': | |
| # print(f"Is CUDA available: {torch.cuda.is_available()}") | |
| # # True | |
| # print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
| # Tesla T4 | |
| api() | |