# Hugging Face Space (status: Sleeping) · file size: 1,662 bytes · commit e8b9365
import logging

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
# --- Model and preprocessing setup ------------------------------------------

# TorchScript checkpoint; the relative path is resolved against the current
# working directory.
model_path = "models/simple_cnn_jit_epoch_10.pt"


def _invert_intensity(x):
    """Return ``1 - x``, flipping the intensities of a [0, 1] tensor."""
    # NOTE(review): presumably this makes black-on-white drawings match
    # bright-digit-on-dark training data (as in MNIST) — confirm.
    return 1 - x


# Preprocessing pipeline: downscale to 8x8, convert the PIL image to a float
# tensor (ToTensor scales 8-bit pixels into [0, 1]), then invert intensities.
transform = transforms.Compose(
    [
        transforms.Resize((8, 8)),
        transforms.ToTensor(),
        transforms.Lambda(_invert_intensity),
    ]
)

# Load the scripted model onto the CPU and put it in evaluation mode.
model = torch.jit.load(model_path, map_location="cpu")
model.eval()
# Função de predição
def predict(inputs):
pil_img = Image.fromarray(inputs['composite'])
image = transform(pil_img).unsqueeze(0)
print(image)
with torch.no_grad():
output = model(image)
probs = F.softmax(output, dim=1).squeeze().cpu().numpy()
print(probs)
labels = [str(i) for i in range(len(probs))]
print(labels)
result = {label: float(prob) for label, prob in zip(labels, probs)}
print(result)
return result
# Interface Gradio
demo = gr.Interface(
fn=predict,
inputs=gr.Sketchpad(crop_size=(256,256), type='numpy', image_mode='L', brush=gr.Brush(color_mode="fixed", colors=['#000000']), layers=False, show_label=False),
outputs=gr.Label(show_label=False, num_top_classes=7),
examples=["examples/digit1.png", "examples/digit2.png", "examples/digit3.png", "examples/digit4.png", "examples/digit5.png", "examples/digit6.png", "examples/digit7.png", "examples/digit8.png", "examples/digit9.png"],
title="Number Classification",
description="Number classification using a simple CNN model trained on the MNIST dataset. Draw a digit in the sketchpad or select from the examples and click 'Submit' to classify it.",
)
demo.launch() |