File size: 1,662 Bytes
e8b9365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
import torch
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image

# Location of the TorchScript-serialized model file (relative to the
# current working directory).
model_path = "models/simple_cnn_jit_epoch_10.pt"

# Preprocessing applied to every incoming image, in order:
#   1. downscale to 8x8 pixels,
#   2. convert to a float tensor (ToTensor scales 8-bit pixels into [0, 1]),
#   3. invert intensities (1 - value) so the black brush strokes become high
#      values -- assumes a light canvas background; confirm.
# NOTE(review): the UI description mentions MNIST (28x28), but inputs are
# resized to 8x8 here; presumably the model was trained at 8x8 -- confirm.
_preprocessing_steps = [
    transforms.Resize((8, 8)),
    transforms.ToTensor(),
    transforms.Lambda(lambda tensor: 1 - tensor),
]
transform = transforms.Compose(_preprocessing_steps)

# Load the TorchScript model onto the CPU and put it in evaluation mode.
model = torch.jit.load(model_path, map_location="cpu")
model.eval()

# Prediction function
def predict(inputs):
    """Classify the digit drawn on the sketchpad.

    Args:
        inputs: Value produced by the ``gr.Sketchpad`` component. The image is
            read from its ``'composite'`` entry as a numpy array. ``None``, or
            a value without a composite, is treated as "no image".

    Returns:
        Mapping from class label (``"0"``, ``"1"``, ...) to its softmax
        probability, suitable for ``gr.Label``. Empty when there is no image.
    """
    # Guard against an empty submission instead of crashing in Image.fromarray.
    composite = inputs.get("composite") if inputs is not None else None
    if composite is None:
        return {}

    # Force a single grayscale channel, matching the sketchpad's image_mode='L',
    # so an RGB/RGBA array cannot yield a multi-channel tensor.
    # (Assumes the model expects 1 input channel -- confirm.)
    pil_img = Image.fromarray(composite).convert("L")
    image = transform(pil_img).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        output = model(image)
        probs = F.softmax(output, dim=1).squeeze().cpu().numpy()

    # The class index doubles as the label shown in the UI.
    return {str(label): float(prob) for label, prob in enumerate(probs)}

# Gradio interface: a black-brush sketchpad feeding `predict`, whose
# class -> probability dict is rendered by a Label showing the top 7 classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        crop_size=(256, 256),
        type="numpy",
        image_mode="L",
        brush=gr.Brush(color_mode="fixed", colors=["#000000"]),
        layers=False,
        show_label=False,
    ),
    outputs=gr.Label(show_label=False, num_top_classes=7),
    # examples/digit1.png ... examples/digit9.png
    examples=[f"examples/digit{i}.png" for i in range(1, 10)],
    title="Number Classification",
    description="Number classification using a simple CNN model trained on the MNIST dataset. Draw a digit in the sketchpad or select from the examples and click 'Submit' to classify it.",
)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or other tooling) does not block inside demo.launch().
if __name__ == "__main__":
    demo.launch()