Spaces:
Runtime error
Runtime error
import gradio as gr
import torch
from transformers import TrOCRProcessor, AutoModel

from classes.Perturber import Perturber
from classes.Renderer import Renderer
from classes.LegibilityPlot import LegibilityPlot
# Hub checkpoint identifiers used by this demo.
_PROCESSOR_CHECKPOINT = "microsoft/trocr-base-handwritten"
_SCORER_CHECKPOINT = "dvsth/LEGIT-TrOCR-MT"

# Preprocessor handles image normalization and resizing before scoring.
preprocessor = TrOCRProcessor.from_pretrained(_PROCESSOR_CHECKPOINT)

# Scoring model: schema plus pretrained weights fetched from the Hub
# (the download can take a while).
# NOTE(review): trust_remote_code=True runs code shipped in the Hub repo, and
# revision='main' follows a moving branch -- consider pinning a commit hash.
model = AutoModel.from_pretrained(
    _SCORER_CHECKPOINT,
    revision='main',
    trust_remote_code=True,
)

# Project helpers used by `demo`: perturb_word / render_image / plot.
perturber = Perturber('trocr', 50)
renderer = Renderer('unifont.ttf')
plotter = LegibilityPlot()
def demo(word_to_perturb, k, n):
    """Perturb a word several times and score the legibility of each variant.

    Each sample comes from ``perturber.perturb_word(word, k, n)``. The sample is
    rendered with ``renderer.render_image``, preprocessed by ``preprocessor``, and
    scored by ``model``. Samples are ordered from lowest to highest score.

    Args:
        word_to_perturb: Text from the input box; must be a single word
            (surrounding whitespace is ignored).
        k: First slider value, passed through unchanged to ``perturb_word``.
        n: Second slider value, passed through unchanged to ``perturb_word``.

    Returns:
        A ``(text, figure)`` pair matching the interface's two outputs. ``text``
        has one line per sample, in the form
        ``<perturbation> (<score rounded to 2 places>) -- legible`` or
        ``... -- not legible``. A score above 0 counts as legible.
        ``figure`` is the result of ``plotter.plot(scores, perturbations)``.
        For invalid input, ``text`` is an error message and ``figure`` is
        ``None``.
    """
    word = word_to_perturb.strip()
    # Gradio maps return values to outputs positionally, so the error path must
    # also yield two values (text, plot); a bare string breaks the interface.
    if not word or any(ch.isspace() for ch in word):
        return 'Please enter a single word.', None

    num_samples = 10
    perturbations, metadatas, images, scores = [], [], [], []
    for _ in range(num_samples):
        perturbation, metadata = perturber.perturb_word(word, k, n)
        inputimg = renderer.render_image(perturbation, word)
        # Inference only: no need for autograd bookkeeping.
        with torch.no_grad():
            pixel_values = preprocessor(inputimg, return_tensors='pt').pixel_values
            score = model(pixel_values).item()
        metadata['score'] = score
        outputimg = renderer.render_image(perturbation, '')
        perturbations.append(perturbation)
        images.append(outputimg)
        metadatas.append(metadata)
        scores.append(score)

    # Reorder all parallel lists with ONE shared ascending-score ordering
    # (stable, so ties keep generation order). Sorting `scores` before the other
    # lists were zipped with it would misalign images/metadatas.
    order = sorted(range(len(scores)), key=scores.__getitem__)
    perturbations = [perturbations[i] for i in order]
    scores = [scores[i] for i in order]
    # NOTE(review): images and metadatas are kept aligned but are not returned.
    images = [images[i] for i in order]
    metadatas = [metadatas[i] for i in order]

    ret_str = ''.join(
        f'{p} ({round(s, 2)}) -- ' + ('legible' if s > 0 else 'not legible') + '\n'
        for p, s in zip(perturbations, scores)
    )
    fig = plotter.plot(scores, perturbations)
    return ret_str, fig
# Wire `demo` into a Gradio UI: a text box and two sliders in; text and a plot out.
_input_components = [
    "text",
    gr.Slider(1, 50, 20, step=1),  # passed to demo as `k`: 1..50, integer steps, default 20
    gr.Slider(0.0, 1.0, 0.5),      # passed to demo as `n`: 0..1, default 0.5
]
interface = gr.Interface(
    fn=demo,
    inputs=_input_components,
    outputs=["text", "plot"],
    allow_flagging='never',
)
interface.launch(inbrowser=True)