File size: 1,478 Bytes
c9e0c1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06c8a6d
c9e0c1d
 
 
 
 
 
 
 
aa97307
c9e0c1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f31d30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import functional as F
from src.models.model import ShapeClassifier  # Import your model class
from torchvision import transforms
import os
from src.data.transform import data_transform


# Labels in the order of the model's output positions.
_SHAPE_CLASSES = ("Circle", "Square", "Triangle")
# Location of the trained weights loaded by _load_model.
_MODEL_PATH = 'results/models/model.pth'
# Holds the single loaded model so repeated calls do not re-read the weights.
_MODEL_CACHE = {}


def _load_model():
    """Return the ShapeClassifier in eval mode, loading weights on first use only."""
    model = _MODEL_CACHE.get("model")
    if model is None:
        model = ShapeClassifier(num_classes=len(_SHAPE_CLASSES))
        state_dict = torch.load(_MODEL_PATH, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()  # inference only: switch the module out of training mode
        _MODEL_CACHE["model"] = model
    return model


def classify_drawing(drawing_image):
    """Classify a drawing as "Circle", "Square" or "Triangle".

    Args:
        drawing_image: The drawing to classify, as a PIL image or an image
            array. ``None`` means no drawing was provided.

    Returns:
        The predicted label string, or ``None`` when ``drawing_image`` is
        ``None``.
    """
    if drawing_image is None:
        return None

    # Reuse the cached model: this function may be invoked on every input
    # change, and reloading the weights from disk each time is wasted work.
    model = _load_model()

    # Round-trip through NumPy so the transform always receives a PIL image,
    # whether the input arrived as a PIL image or as an array.
    drawing = np.array(drawing_image)
    drawing_tensor = data_transform(Image.fromarray(drawing))

    # NOTE(review): data_transform is defined elsewhere. If it yields an
    # unbatched 3-D tensor, add the leading batch axis that argmax(dim=1)
    # below relies on; a tensor that is already batched is left untouched.
    if isinstance(drawing_tensor, torch.Tensor) and drawing_tensor.dim() == 3:
        drawing_tensor = drawing_tensor.unsqueeze(0)

    with torch.no_grad():
        output = model(drawing_tensor)

    predicted_class = torch.argmax(output, dim=1).item()
    return _SHAPE_CLASSES[predicted_class]


# Expose classify_drawing through a Gradio UI: a PIL image goes in and the
# predicted label comes back as text. live=True asks Gradio to re-run the
# function when the input changes instead of waiting for a submit.
iface = gr.Interface(
    classify_drawing,
    gr.Image(type="pil"),
    "text",
    live=True,
)
# Start the web server on a fixed port.
iface.launch(server_port=7860)