# Author: Rafa-bork
# Initial commit (e8b9365)
import gradio as gr
import torch
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image
# Path to the TorchScript (JIT) checkpoint loaded below.
model_path = "models/simple_cnn_jit_epoch_10.pt"


def _invert_intensity(tensor):
    """Return ``1 - tensor``, flipping pixel intensities in the [0, 1] range."""
    return 1 - tensor


# Preprocessing applied to each input image before it reaches the model:
# resize to 8x8 pixels, convert to a float tensor (ToTensor scales 8-bit
# pixels to [0, 1]), then invert so the black (#000000) brush strokes,
# which have value 0, become 1.
transform = transforms.Compose(
    [
        transforms.Resize((8, 8)),
        transforms.ToTensor(),
        transforms.Lambda(_invert_intensity),
    ]
)

# Load the TorchScript model onto the CPU and switch it to evaluation mode.
model = torch.jit.load(model_path, map_location="cpu")
model.eval()
# Função de predição
def predict(inputs):
pil_img = Image.fromarray(inputs['composite'])
image = transform(pil_img).unsqueeze(0)
print(image)
with torch.no_grad():
output = model(image)
probs = F.softmax(output, dim=1).squeeze().cpu().numpy()
print(probs)
labels = [str(i) for i in range(len(probs))]
print(labels)
result = {label: float(prob) for label, prob in zip(labels, probs)}
print(result)
return result
# Gradio interface: a fixed black-brush sketchpad as input and a label
# listing the top class probabilities returned by predict() as output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        crop_size=(256, 256),
        type="numpy",
        image_mode="L",
        brush=gr.Brush(color_mode="fixed", colors=["#000000"]),
        layers=False,
        show_label=False,
    ),
    outputs=gr.Label(show_label=False, num_top_classes=7),
    # examples/digit1.png ... examples/digit9.png
    examples=[f"examples/digit{i}.png" for i in range(1, 10)],
    title="Number Classification",
    description="Number classification using a simple CNN model trained on the MNIST dataset. Draw a digit in the sketchpad or select from the examples and click 'Submit' to classify it.",
)

# Start the server only when executed as a script, so that importing this
# module (tests, Gradio's reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch()