"""Gradio demo that classifies a photo of a mechanical tool with a small CNN.

Importing this module builds ``LargeNet``, loads its trained weights from
``MODEL_PATH`` and constructs the Gradio interface ``demo``. The web app is
only launched when the file is run as a script.
"""

import gradio as gr
from PIL import Image, ImageOps
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F

# Checkpoint holding the state dict that is loaded into ``LargeNet`` below.
MODEL_PATH = "./model_large_bs64_lr0.001_epoch29"

# Label for each output index of the network. The spelling of every entry is
# kept exactly as-is because these strings are what the app displays.
# NOTE(review): the order presumably mirrors the training label order -- confirm.
CLASS_NAMES = [
    "Gasoline_Can",
    "Hammer",
    "Pebbels",
    "pliers",
    "Rope",
    "Screw_Driver",
    "Toolbox",
    "Wrench",
]


class LargeNet(nn.Module):
    """Two-convolution CNN producing 8 class logits for 3x128x128 inputs."""

    def __init__(self):
        super().__init__()
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        # Spatial size for a 128x128 input:
        # 128 -conv5-> 124 -pool-> 62 -conv5-> 58 -pool-> 29.
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalized) logits of shape ``(batch, 8)``.

        Args:
            x: Image batch of shape ``(batch, 3, 128, 128)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield the
                ``10 * 29 * 29`` features expected by ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample. Keeping the batch dimension explicit means an
        # unexpected input size fails loudly in fc1 instead of being silently
        # reinterpreted as a different batch size.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


def preprocess_image(image, target_size=(128, 128)) -> np.ndarray:
    """Convert an image into a float array scaled to ``[0, 1]``.

    Args:
        image: A ``PIL.Image.Image``, or anything ``PIL.Image.open`` accepts
            (a file path or a binary file object).
        target_size: ``(width, height)`` of the returned image.

    Returns:
        Array of shape ``(height, width, 3)`` with values in ``[0, 1]``.
    """
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    else:
        # Opened via a context manager so the underlying file is closed once
        # the pixel data has been copied by convert().
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    print('image' , image)

    # Scale and center-crop to exactly target_size; the aspect ratio is kept
    # by cropping the excess, not by padding.
    image = ImageOps.fit(image, target_size, method=Image.BICUBIC, centering=(0.5, 0.5))

    # Map 8-bit channel values to [0, 1].
    image_array = np.array(image) / 255.0
    return image_array


model = LargeNet()
# map_location lets a checkpoint that was saved from a GPU load on a CPU-only
# host; inference below runs on the CPU.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()
print(model)


def classify_image(image_path) -> str:
    """Predict the tool class shown in an image.

    Args:
        image_path: A PIL image (what the Gradio input delivers), a path or
            file object readable by PIL, or ``None`` when no image was given.

    Returns:
        The predicted entry of ``CLASS_NAMES``, or a short notice when
        ``image_path`` is ``None``.
    """
    if image_path is None:
        # Gradio can invoke the callback without an image; answer with a
        # message instead of failing inside preprocessing.
        return "No image provided."

    image = preprocess_image(image_path)
    # HWC -> CHW, then add the batch dimension and cast to float32.
    image_tensor = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
    print('image ', image_tensor.shape)

    with torch.no_grad():
        outputs = model(image_tensor)

    predicted_index = int(outputs.argmax(dim=1).item())
    label = CLASS_NAMES[predicted_index]
    print(label)
    return label


# Augmenting transform pipeline. It is not used by the inference path in this
# file (classify_image relies on preprocess_image instead).
# NOTE(review): this pipeline standardizes with mean/std while
# preprocess_image only scales to [0, 1] -- confirm which preprocessing the
# checkpoint was actually trained with.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),  # Convert to PyTorch Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Standardize
])

# Gradio interface
demo = gr.Interface(
    fn=classify_image,  # Classification function
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Mechanical Tools Classifier",
)

if __name__ == "__main__":
    demo.launch()  # Launch Gradio app