Spaces:
Build error
Build error
File size: 1,478 Bytes
c9e0c1d 06c8a6d c9e0c1d aa97307 c9e0c1d 0f31d30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import functools
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as F

from src.data.transform import data_transform
from src.models.model import ShapeClassifier # Import your model class

# Human-readable labels, indexed by the position of the model's arg-max logit.
SHAPE_CLASSES = ("Circle", "Square", "Triangle")
# Trained weights; a relative path, so resolved against the process CWD.
MODEL_PATH = "results/models/model.pth"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the classifier, load its trained weights onto CPU, and return it.

    Cached so the checkpoint is read once per process instead of on every
    (live) prediction. A load that raises is not cached and is retried on the
    next call. Changes to the file on disk after the first successful load
    are not picked up until the process restarts.
    """
    # Output size is derived from SHAPE_CLASSES so logits and labels agree.
    model = ShapeClassifier(num_classes=len(SHAPE_CLASSES))
    # NOTE(review): torch.load unpickles the checkpoint; only point MODEL_PATH
    # at trusted files (newer PyTorch versions also offer weights_only=True).
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # evaluation mode (affects dropout / batch-norm, if present)
    return model


def classify_drawing(drawing_image):
    """Classify a drawn shape and return its label.

    Args:
        drawing_image: The input image, e.g. a PIL image from a
            ``gr.Image(type="pil")`` component, or anything ``np.array``
            accepts; ``None`` when no drawing has been provided.

    Returns:
        One of ``SHAPE_CLASSES`` ("Circle", "Square", "Triangle"), or ``None``
        if ``drawing_image`` is ``None``.
    """
    if drawing_image is None:
        return None

    model = _load_model()

    # Round-trip through NumPy so both PIL images and arrays are accepted.
    # All preprocessing (presumably resize / channel conversion / tensor
    # conversion) happens inside data_transform -- TODO confirm.
    drawing_tensor = data_transform(Image.fromarray(np.array(drawing_image)))
    # argmax(dim=1) below expects a (batch, classes) output, so ensure a batch
    # dimension. Assumes a 3-D result is a single (C, H, W) image -- confirm.
    if drawing_tensor.dim() == 3:
        drawing_tensor = drawing_tensor.unsqueeze(0)

    with torch.no_grad():
        output = model(drawing_tensor)
    predicted_class = torch.argmax(output, dim=1).item()
    return SHAPE_CLASSES[predicted_class]
# Web UI: the uploaded/drawn image is passed to classify_drawing as a PIL
# image and the returned label is displayed as text. live=True asks Gradio to
# re-run the function automatically whenever the input changes.
iface = gr.Interface(
    fn=classify_drawing,
    inputs=gr.Image(type="pil"),  # image component delivering a PIL.Image (not gr.Sketchpad)
    outputs="text",
    live=True,
)
# Start the server on port 7860 (Gradio's default port); runs at import time.
iface.launch(server_port=7860)
|