Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| from PIL import Image | |
# Path to the TorchScript model file, relative to the current working directory.
model_path = "models/simple_cnn_jit_epoch_10.pt"


def _invert(tensor):
    """Return ``1 - tensor``, flipping the intensity of every pixel."""
    return 1 - tensor


# Preprocessing pipeline: shrink the drawing to 8x8, turn it into a tensor,
# then invert intensities so dark (brush) pixels become high values.
# NOTE(review): the UI text mentions MNIST (28x28), but the model is fed 8x8
# inputs here; confirm this matches the resolution the model was trained on.
_preprocessing_steps = [
    transforms.Resize((8, 8)),
    transforms.ToTensor(),
    transforms.Lambda(_invert),
]
transform = transforms.Compose(_preprocessing_steps)

# Load the TorchScript model onto the CPU and put it in evaluation mode.
model = torch.jit.load(model_path, map_location=torch.device("cpu"))
model.eval()
# Prediction function
def predict(inputs):
    """Classify the digit drawn on the sketchpad.

    Args:
        inputs: The value Gradio passes for the ``gr.Sketchpad`` input
            (declared with ``type='numpy'``). Only its ``'composite'`` entry,
            a numpy image array, is used. It may be None, or carry a None
            composite, when nothing was submitted (TODO confirm against the
            installed Gradio version).

    Returns:
        A dict mapping each class label (``"0"``, ``"1"``, ...) to its softmax
        probability, in the format ``gr.Label`` expects. Returns None when
        there is no image to classify, which leaves the label empty instead
        of raising.
    """
    composite = inputs.get("composite") if inputs else None
    if composite is None:
        return None

    # Force a single-channel image so the tensor always has one channel,
    # even if the editor ever hands back an RGB/RGBA composite. This is a
    # no-op for the 'L' images the sketchpad is configured to produce.
    pil_img = Image.fromarray(composite).convert("L")
    image = transform(pil_img).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        output = model(image)

    # Drop only the batch dimension, so a single-class output cannot
    # collapse to a 0-d array.
    probs = F.softmax(output, dim=1).squeeze(0).cpu().numpy()
    return {str(idx): float(prob) for idx, prob in enumerate(probs)}
# Gradio interface: a fixed black-brush sketchpad as input, a top-k label as output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        crop_size=(256, 256),
        type="numpy",
        image_mode="L",
        brush=gr.Brush(color_mode="fixed", colors=["#000000"]),
        layers=False,
        show_label=False,
    ),
    outputs=gr.Label(show_label=False, num_top_classes=7),
    # Sample drawings examples/digit1.png through examples/digit9.png.
    examples=[f"examples/digit{n}.png" for n in range(1, 10)],
    title="Number Classification",
    description=(
        "Number classification using a simple CNN model trained on the MNIST dataset. "
        "Draw a digit in the sketchpad or select from the examples and click 'Submit' to classify it."
    ),
)
demo.launch()