"""Gradio demo: run the page-element detector on an image and show the result."""

import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from PIL import Image

from model import define_model
from utils import plot_sample, postprocess_preds_page_element, reformat_for_plotting

# Identifier of the published model; kept for reference (not read elsewhere in
# the code visible here).
MODEL_NAME = "nvidia/nemoretriever-page-elements-v3"

# Built once at import time so every request reuses the same model object.
model = define_model("page_element_v3")


@spaces.GPU
def inference(image):
    """Detect page elements in ``image`` and return the annotated rendering.

    Args:
        image: The value supplied by the input ``gr.Image`` component. It is
            converted with ``np.array`` before being handed to the model.

    Returns:
        A ``PIL.Image.Image`` in RGB mode containing the matplotlib figure
        produced by ``plot_sample`` for the detected boxes.

    Raises:
        gr.Error: If no image was provided (the component value is ``None``).
    """
    if image is None:
        # Without this guard np.array(None) yields an object array and the
        # failure surfaces later as an opaque model/preprocessing error.
        raise gr.Error("Please provide an image before running detection.")

    image = np.array(image)

    with torch.inference_mode():
        x = model.preprocess(image)
        preds = model(x, image.shape)[0]

    boxes, labels, scores = postprocess_preds_page_element(
        preds, model.thresholds_per_class, model.labels
    )
    boxes_plot, confs = reformat_for_plotting(
        boxes, labels, scores, image.shape, model.num_classes
    )

    # Render into memory instead of a fixed "output.png": a shared file name
    # lets overlapping requests overwrite each other's result and requires a
    # writable working directory.
    buffer = io.BytesIO()
    fig = plt.figure(figsize=(15, 10))
    try:
        plot_sample(image, boxes_plot, confs, labels=model.labels)
        # plt.savefig targets pyplot's current figure, exactly as before.
        plt.savefig(buffer, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it even when plotting fails.
        # NOTE(review): assumes plot_sample draws on the figure opened above
        # rather than creating its own - confirm in utils.plot_sample.
        plt.close(fig)

    buffer.seek(0)
    # convert() forces a full decode, so the result does not depend on the
    # buffer after this function returns.
    return Image.open(buffer).convert("RGB")


def gradio_reset():
    """Return two updates that clear the input and output image components."""
    return gr.update(value=None), gr.update(value=None)


with gr.Blocks() as demo:
    with gr.Row():
        # Left column: image input plus the action buttons.
        with gr.Column():
            input_img = gr.Image(label=" ", interactive=True)
            with gr.Row():
                clear = gr.Button(value="Clear")
                predict = gr.Button(value="Detect", interactive=True, variant="primary")

        # Right column: read-only view of the annotated result.
        with gr.Column():
            output_img = gr.Image(label=" ", interactive=False)

    clear.click(gradio_reset, inputs=None, outputs=[input_img, output_img])
    predict.click(
        inference,
        inputs=[input_img],
        outputs=[output_img],
    )

demo.launch()