tranbaohieu
Rename variable
2d1656e
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from PIL import Image

from model import define_model
from utils import plot_sample, postprocess_preds_page_element, reformat_for_plotting
# Identifier of the underlying model; presumably its Hugging Face Hub repo id.
# NOTE(review): not referenced by the code visible here — confirm it is used.
MODEL_NAME = "nvidia/nemoretriever-page-elements-v3"

# Build the detector once at import time; every request handled by
# `inference` reuses this single module-level instance.
model = define_model("page_element_v3")
@spaces.GPU
def inference(image):
    """Detect page elements in ``image`` and return an annotated rendering.

    Args:
        image: The value delivered by the ``gr.Image`` input component. It is
            converted with ``np.array`` before use; presumably an H x W x C
            uint8 array — TODO confirm against the component's ``type``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode containing the figure drawn by
        ``plot_sample`` (the image with detected boxes and confidences).

    Raises:
        gr.Error: If no image was supplied (e.g. "Detect" was pressed with an
            empty input). The UI then shows a readable message instead of a
            traceback from deep inside the model.
    """
    if image is None:
        raise gr.Error("Please upload an image before clicking Detect.")

    image = np.array(image)
    with torch.inference_mode():
        x = model.preprocess(image)
        preds = model(x, image.shape)[0]
        boxes, labels, scores = postprocess_preds_page_element(
            preds, model.thresholds_per_class, model.labels
        )
        boxes_plot, confs = reformat_for_plotting(
            boxes, labels, scores, image.shape, model.num_classes
        )

    # Render into an in-memory buffer rather than a shared "output.png" on
    # disk, so simultaneous requests cannot overwrite each other's result.
    buffer = io.BytesIO()
    fig = plt.figure(figsize=(15, 10))
    try:
        plot_sample(image, boxes_plot, confs, labels=model.labels)
        # Save the *current* figure, exactly as the original plt.savefig did,
        # in case plot_sample switches figures internally.
        plt.savefig(buffer, format="png", bbox_inches="tight", dpi=150)
    finally:
        # pyplot keeps figures alive until explicitly closed. Close the current
        # figure and the one created above (the same object in the normal
        # case) so repeated requests do not leak figures.
        plt.close()
        plt.close(fig)

    buffer.seek(0)
    # convert() decodes the pixel data, so the returned image no longer
    # depends on the buffer.
    return Image.open(buffer).convert("RGB")
def gradio_reset():
    """Clear both image panes of the demo.

    Returns:
        A 2-tuple of ``gr.update(value=None)`` objects: the first for the
        input image component, the second for the output image component.
    """
    # One independent update object per output component.
    return tuple(gr.update(value=None) for _ in range(2))
with gr.Blocks() as demo:
    # Layout: left column holds the input image plus the Clear/Detect buttons;
    # right column shows the annotated result.
    with gr.Row():
        with gr.Column():
            # A single-space label appears intended to leave the caption blank.
            input_img = gr.Image(label=" ", interactive=True)
            with gr.Row():
                clear = gr.Button(value="Clear")
                predict = gr.Button(
                    value="Detect",
                    interactive=True,
                    variant="primary",
                )
        with gr.Column():
            output_img = gr.Image(label=" ", interactive=False)

    # Event wiring, registered in the same order as before: Clear, then Detect.
    clear.click(
        fn=gradio_reset,
        inputs=None,
        outputs=[input_img, output_img],
    )
    predict.click(fn=inference, inputs=[input_img], outputs=[output_img])

demo.launch()