# app.py — uploaded to Hugging Face by Dddrl ("Upload app.py", commit bea1b8e, verified; 997 bytes).
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
# Load the trained TorchScript model.
# map_location="cpu" remaps any tensors that were saved on a GPU onto the CPU.
# This lets the model load on CPU-only hosts, and keeps it on the same device
# as the CPU input tensors that classify_image builds. For a model that was
# saved on CPU this is a no-op.
model_path = "model_scripted.pt"
net = torch.jit.load(model_path, map_location="cpu")
net.eval()  # evaluation mode: affects layers such as dropout / batch-norm, if present

# Class names indexed by the model's output position. The order must match the
# label order used during training (update this to match your model's classes).
target_classes = ["Rope", "Hammer", "Other"]
# Preprocessing pipeline, built once at import time instead of on every request.
# It must match the preprocessing the model was trained with.
_PREPROCESS = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


# Define a prediction function
def classify_image(image):
    """Predict which tool class is shown in *image*.

    Args:
        image: A PIL image, as delivered by the ``gr.Image(type="pil")`` input,
            or ``None`` when no image was uploaded.

    Returns:
        The entry of ``target_classes`` at the index of the model's highest
        output score, or a short notice string when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None for an empty image box; answer with a message
        # instead of failing inside the preprocessing transforms.
        return "No image provided."
    # The Gradio input defaults to image_mode="RGB", so the model is fed
    # 3-channel images. Convert other modes (RGBA PNGs, grayscale, palette)
    # the same way, so direct callers get a tensor of the expected shape.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = _PREPROCESS(image).unsqueeze(0)  # add batch dim -> (1, 3, 128, 128)
    with torch.no_grad():
        scores = net(batch)
    # Index of the highest score along the class dimension, as a Python int.
    predicted = scores.argmax(dim=1).item()
    return target_classes[predicted]
# Build the web UI: one PIL image in, the predicted class name out as text.
interface = gr.Interface(
    title="Mechanical Tools Classifier",
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

# Serve the app.
interface.launch()