import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms

# Load the trained TorchScript model.
model_path = "model_scripted.pt"
# map_location="cpu": classify_image() builds its input tensors on the CPU, so
# the weights must live there too. Without this, a model scripted/saved from a
# GPU session fails to load on a CPU-only host or mismatches the input device.
net = torch.jit.load(model_path, map_location="cpu")
net.eval()

# Class labels, indexed by the position of each score in the model's output
# (update this to match your model's classes).
target_classes = ["Rope", "Hammer", "Other"]

# Preprocessing pipeline, built once at import time instead of on every request.
# NOTE(review): these steps must match the transforms used during training —
# confirm the 128x128 size and the (0.5,)/(0.5,) normalization against the
# training script.
preprocess = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


def classify_image(image):
    """Predict the tool class shown in an image.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        The entry of ``target_classes`` whose model score is highest, or a
        short human-readable message when no image was given or the model
        predicts an index that has no label.
    """
    # Gradio passes None when the form is submitted with an empty image box;
    # the transform pipeline would otherwise raise a TypeError.
    if image is None:
        return "Please upload an image."

    batch = preprocess(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = net(batch)
    _, predicted = torch.max(output, 1)
    index = int(predicted.item())

    # Guard against a label list that is shorter than the model's output,
    # which would otherwise surface as an opaque IndexError in the UI.
    if not 0 <= index < len(target_classes):
        return f"Unknown class index {index}"
    return target_classes[index]


# Create the Gradio interface.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Mechanical Tools Classifier",
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()