# Hugging Face Spaces app: mechanical-tools image classifier (Gradio + PyTorch).
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
class LargeNet(nn.Module):
    """CNN classifier producing logits for 8 tool classes from 128x128 RGB input.

    Shape math for a 128x128 input: conv1 (5x5) -> 124, pool -> 62,
    conv2 (5x5) -> 58, pool -> 29; hence the 10 * 29 * 29 flattened size.
    """

    def __init__(self):
        super(LargeNet, self).__init__()
        self.name = "large"  # identifier used in checkpoint naming
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        """Return class logits of shape [batch_size, 8] for x of shape [batch_size, 3, 128, 128]."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 10 * 29 * 29)  # flatten conv features per sample
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # BUGFIX: removed `x.squeeze(1)` — dim 1 has size 8 (num classes), never 1,
        # so the squeeze was a no-op and its "flatten to [batch_size]" comment false.
        return x
def preprocess_image(image, target_size=(128, 128)):
    """Convert a PIL image to a normalized HWC float array for the model.

    Args:
        image: PIL.Image.Image in any mode; converted to RGB.
        target_size: (width, height) to crop-and-resize to.

    Returns:
        np.ndarray of shape (target_size[1], target_size[0], 3) with
        float values in [0, 1].
    """
    image = image.convert("RGB")
    # Crop-and-resize to target_size, preserving aspect ratio (center crop).
    image = ImageOps.fit(image, target_size, method=Image.BICUBIC, centering=(0.5, 0.5))
    # Scale pixel values to [0, 1].
    # NOTE(review): the unused `transform` pipeline below standardizes with
    # ImageNet mean/std instead — confirm which preprocessing the checkpoint
    # was actually trained with.
    return np.array(image) / 255.0
# Load trained weights. map_location="cpu" makes a GPU-saved checkpoint load
# on a CPU-only Space (the original bare torch.load would crash there);
# weights_only=True refuses to unpickle arbitrary objects from the file.
model = LargeNet()
model.load_state_dict(
    torch.load(
        "./model_large_bs64_lr0.001_epoch29",
        map_location="cpu",
        weights_only=True,
    )
)
model.eval()  # inference mode: fixes dropout/batch-norm behavior
def classify_image(image_path):
    """Classify an image into one of 8 mechanical-tool classes.

    Args:
        image_path: despite the name, a PIL.Image.Image — the Gradio Image
            component is configured with type="pil". (Name kept for
            interface compatibility.)

    Returns:
        The predicted class name (str).
    """
    # Label order must match the index order used during training
    # ("Pebbels"/"pliers" spellings kept exactly as trained).
    classes = ["Gasoline_Can", "Hammer", "Pebbels", "pliers",
               "Rope", "Screw_Driver", "Toolbox", "Wrench"]
    image = preprocess_image(image_path)
    # HWC float array -> CHW float tensor with a leading batch dimension.
    image_tensor = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
    with torch.no_grad():  # pure inference; no gradient tracking
        outputs = model(image_tensor)
    _, predicted_class = torch.max(outputs, 1)  # argmax over class logits
    return classes[predicted_class.item()]
# NOTE(review): this augmentation/normalization pipeline is never applied in
# this file — inference goes through preprocess_image(), which only scales to
# [0, 1]. It looks like a leftover from training code; confirm before removal.
# Its ImageNet mean/std standardization also disagrees with preprocess_image.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),        # training-time augmentation
    transforms.RandomRotation(15),            # training-time augmentation
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),  # Convert to PyTorch Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standardize
])
| # classify_image('rope1.jpeg') | |
| # Gradio interface | |
# Gradio UI: one image upload in, predicted class name out.
demo = gr.Interface(
    fn=classify_image,  # Classification function
    inputs=gr.Image(type="pil"),  # deliver uploads as PIL images, not file paths
    outputs=gr.Textbox(),
    title="Mechanical Tools Classifier"
)

if __name__ == "__main__":
    demo.launch()  # Launch Gradio app