Spaces:
Paused
Paused
| import torch | |
| import torchvision.transforms as transforms | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
| from math2latex.data import Tokenizer | |
| from math2latex.model import ResNetTransformer | |
# Shared state for the demo. Both names stay None until setup() assigns them;
# predict_image() reads them and calls setup() itself if either is still None.
model = None
tokenizer = None
def get_formulas(filename):
    """Read a formula listing and return its lines.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list with one string per line of the file, in file order. Line
        terminators are kept exactly as the file object yields them.
    """
    with open(filename, 'r') as handle:
        # Iterating a text file yields the same lines readlines() would.
        return list(handle)
def latex2image(latex_expression, image_name, image_size_in=(3, 0.6), fontsize=12, dpi=200):
    """Render a text expression to an image file with Matplotlib.

    Args:
        latex_expression: Text to draw, centred in the figure. The caller in
            this file passes the formula wrapped in ``$...$``.
        image_name: Output path handed to ``Figure.savefig``.
        image_size_in: Figure size ``(width, height)`` in inches.
        fontsize: Font size used for the text.
        dpi: Resolution of the figure.

    Raises:
        Any exception Matplotlib raises while rendering or saving the figure
        is propagated; the figure is closed first.
    """
    # Runtime configuration: use the Computer Modern font set for math text.
    # NOTE: this changes the process-wide rcParams and is not restored.
    matplotlib.rcParams["mathtext.fontset"] = "cm"

    fig = plt.figure(figsize=image_size_in, dpi=dpi)
    try:
        fig.text(
            x=0.5,
            y=0.5,
            s=latex_expression,
            horizontalalignment="center",
            verticalalignment="center",
            fontsize=fontsize,
        )
        # Save this specific figure instead of whatever pyplot currently
        # considers the "current" figure.
        fig.savefig(image_name)
    finally:
        # pyplot keeps every open figure alive until it is closed, so close
        # it even when rendering or saving fails.
        plt.close(fig)
def setup():
    """Load the model and tokenizer and store them in the module globals.

    Reads the checkpoint ``checkpoints/model.ckpt`` and the formula listing
    ``dataset/im2latex_formulas.norm.processed.lst`` (both relative to the
    working directory), then assigns the module-level ``model`` and
    ``tokenizer``. The globals are only assigned once both objects have been
    built, so a failure part-way through leaves them untouched.

    Raises:
        Any exception raised while reading the checkpoint or formula file,
        or while loading the state dict, is propagated.
    """
    global model, tokenizer

    checkpoint_path = 'checkpoints/model.ckpt'
    wrapper_prefix = 'model.'

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at a checkpoint you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Strip the wrapper prefix from the front of each key only. A plain
    # str.replace would also delete any later 'model.' inside the key.
    state_dict = {
        (key[len(wrapper_prefix):] if key.startswith(wrapper_prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }

    network = ResNetTransformer()
    network.load_state_dict(state_dict)
    network.to("cpu")
    network.eval()

    # The tokenizer is constructed from the lines of the formula listing.
    formulas = get_formulas('dataset/im2latex_formulas.norm.processed.lst')
    formula_tokenizer = Tokenizer(formulas)

    # Publish both objects together, only after everything succeeded.
    model, tokenizer = network, formula_tokenizer
def predict_image(image):
    """Run the model on a single image and return the decoded prediction.

    Args:
        image: Input passed to ``torchvision.transforms.ToTensor``.

    Returns:
        The value returned by ``tokenizer.decode`` for the first predicted
        sequence.
    """
    # Lazy initialisation: load the model and tokenizer on first use.
    if not (model is not None and tokenizer is not None):
        setup()

    # Convert to a tensor and add a leading dimension of size one.
    batch = transforms.ToTensor()(image).unsqueeze(0)

    with torch.no_grad():
        prediction = model.predict(batch)

    first_sequence = prediction[0].tolist()
    return tokenizer.decode(first_sequence)
def _drop_insignificant_spaces(latex_code):
    """Remove spaces from LaTeX code without gluing a command to a letter."""
    import re  # local import: the file's import header does not provide re

    kept = []
    for token in latex_code.split(" "):
        if not token:
            # Runs of spaces produce empty pieces; they carry no content.
            continue
        # A control word such as \alpha ends at the first non-letter. If the
        # next piece starts with a letter, the space between them is the only
        # thing stopping "\alpha x" from becoming the unknown command "\alphax".
        if kept and re.match(r"[A-Za-z]", token) and re.search(r"\\[A-Za-z]+$", kept[-1]):
            kept.append(" ")
        kept.append(token)
    return "".join(kept)


def predict_and_convert_to_image(image):
    """Predict the LaTeX code for an image and render it back to an image.

    Args:
        image: Input image, forwarded unchanged to ``predict_image``.

    Returns:
        A ``(latex_code, image_name)`` tuple: the predicted LaTeX code exactly
        as returned by ``predict_image``, and the path of the rendered image.
    """
    latex_code = predict_image(image)

    # NOTE(review): the output name is fixed, so overlapping calls would
    # overwrite each other's file. Confirm requests are handled one at a time.
    image_name = 'temp.png'

    # Compact the code for rendering, keeping only the spaces that terminate
    # a control word before a letter.
    latex_code_modified = _drop_insignificant_spaces(latex_code)
    latex_code_modified = rf"""${latex_code_modified}$"""
    latex2image(latex_code_modified, image_name)

    # Return both the LaTeX code and the path of the generated image.
    return latex_code, image_name
def main():
    """Load the model and tokenizer, then launch the Gradio demo."""
    # Load the model and tokenizer before the UI is started.
    setup()

    # Sample inputs. They are not passed to the interface at the moment.
    examples = [
        ["dataset/formula_images_processed/78228211ca.png"],
        ["dataset/formula_images_processed/2b891b21ac.png"],
        ["dataset/formula_images_processed/a8ec0c091c.png"],
    ]

    # One image goes in; the LaTeX text and the rendered image come out.
    interface_options = {
        'fn': predict_and_convert_to_image,
        'inputs': 'image',
        'outputs': ['text', 'image'],
        'title': 'Image to LaTeX code',
        'description': 'Convert an image of a mathematical formula to LaTeX code and view the result as an image. Upload an image of a formula to get both the LaTeX code and the corresponding image or use the examples provided.',
    }
    gr.Interface(**interface_options).launch()
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()