import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

# Path to the TorchScript (JIT) model checkpoint, relative to the working directory.
model_path = "models/simple_cnn_jit_epoch_10.pt"

# Preprocessing applied to the drawn image before it is fed to the model.
transform = transforms.Compose([
    # NOTE(review): the UI description mentions MNIST (28x28), but inputs are
    # downscaled to 8x8 here — presumably the model was trained at this
    # resolution; confirm against the training pipeline.
    transforms.Resize((8, 8)),
    # PIL image -> float tensor with values in [0, 1], shape (C, H, W).
    transforms.ToTensor(),
    # Invert intensities: strokes are drawn with a black brush (see the Brush
    # config below), so inverting makes the digit bright on a dark background.
    transforms.Lambda(lambda x: 1 - x),
])

# Load the TorchScript model on CPU and switch it to inference mode.
model = torch.jit.load(model_path, map_location="cpu")
model.eval()


def predict(inputs):
    """Classify the digit drawn in the sketchpad.

    Args:
        inputs: The Sketchpad value. With ``type='numpy'`` this is a dict whose
            ``"composite"`` entry holds the flattened drawing as a numpy array.
            The value (or its ``"composite"`` entry) may be ``None``.

    Returns:
        A dict mapping each class label (``"0"``, ``"1"``, ...) to its softmax
        probability, as consumed by ``gr.Label``. Returns ``None`` (no
        prediction) when there is no image to classify.
    """
    composite = inputs.get("composite") if inputs else None
    if composite is None:
        # Nothing to classify; Image.fromarray(None) would raise.
        return None

    # Force a single grayscale channel. This is a no-op when Gradio already
    # delivers an "L"-mode array, and it prevents a 3/4-channel tensor otherwise.
    pil_img = Image.fromarray(composite).convert("L")
    image = transform(pil_img).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        output = model(image)
    # Drop only the batch dimension so the result stays 1-D even for one class.
    probs = F.softmax(output, dim=1).squeeze(0).cpu().numpy()

    return {str(idx): float(prob) for idx, prob in enumerate(probs)}


# Gradio interface: a black-brush sketchpad in, top class probabilities out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        crop_size=(256, 256),
        type="numpy",
        image_mode="L",
        brush=gr.Brush(color_mode="fixed", colors=["#000000"]),
        layers=False,
        show_label=False,
    ),
    outputs=gr.Label(show_label=False, num_top_classes=7),
    examples=[f"examples/digit{i}.png" for i in range(1, 10)],
    title="Number Classification",
    description="Number classification using a simple CNN model trained on the MNIST dataset. Draw a digit in the sketchpad or select from the examples and click 'Submit' to classify it.",
)

# Launch the web UI (at module load, as in the original script).
demo.launch()