Spaces:
Build error
Build error
File size: 2,567 Bytes
0f69ec0 89ac582 0f69ec0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
class LargeNet(nn.Module):
    """CNN for 8-way mechanical-tool classification on 128x128 RGB images.

    Two conv+pool stages followed by two fully-connected layers. With
    128x128 inputs the feature map entering fc1 is 10 channels x 29 x 29
    (128 -> conv5 -> 124 -> pool -> 62 -> conv5 -> 58 -> pool -> 29).
    """

    def __init__(self):
        super(LargeNet, self).__init__()
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        """Return raw class logits of shape [batch_size, 8]."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension intact, unlike
        # view(-1, ...) which silently merges a mis-sized batch.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # NOTE(review): the original ended with x.squeeze(1), a no-op here
        # (dim 1 has size 8, not 1) whose "[batch_size]" comment was wrong;
        # removed.
        return x
def preprocess_image(image, target_size=(128, 128)):
    """Convert a PIL image to a normalized float array for the model.

    Args:
        image: PIL.Image.Image in any mode; converted to RGB first.
        target_size: (width, height) the image is fitted to.

    Returns:
        numpy.ndarray of shape (target_size[1], target_size[0], 3)
        with float values in [0, 1].
    """
    image = image.convert("RGB")
    # Crop-and-resize to target_size while preserving aspect ratio,
    # centered on the middle of the image.
    image = ImageOps.fit(image, target_size, method=Image.BICUBIC, centering=(0.5, 0.5))
    # Scale 8-bit pixel values into [0, 1].
    return np.array(image) / 255.0
# Instantiate the network and load the trained weights once at startup.
model = LargeNet()
# map_location="cpu" lets a GPU-trained checkpoint load on CPU-only hosts
# (e.g. a hosted demo); the file is expected to hold a plain state_dict.
model.load_state_dict(
    torch.load("./model_large_bs64_lr0.001_epoch29", map_location=torch.device("cpu"))
)
# Inference mode: freezes dropout/batch-norm behavior.
model.eval()
def classify_image(image_path):
    """Return the predicted tool-class name for an input image.

    Args:
        image_path: a PIL.Image.Image — despite the name, the Gradio input
            is configured with type="pil", so this is an image object,
            not a filesystem path.

    Returns:
        str: one of the eight class names.
    """
    # Order must match the label indices used at training time; the
    # "Pebbels" spelling is kept to match the trained checkpoint's labels.
    classes = ["Gasoline_Can", "Hammer", "Pebbels", "pliers",
               "Rope", "Screw_Driver", "Toolbox", "Wrench"]
    image = preprocess_image(image_path)
    # HWC float array -> CHW float tensor with a leading batch dimension.
    image_tensor = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
    with torch.no_grad():
        outputs = model(image_tensor)
        predicted_class = torch.argmax(outputs, dim=1)
    return classes[predicted_class.item()]
# Inference-time preprocessing pipeline.
# NOTE(review): this pipeline is currently unused — classify_image calls
# preprocess_image instead. The original version also applied training-time
# random augmentations (horizontal flip, rotation, color jitter), which are
# incorrect at inference and have been removed.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),  # PIL -> float CHW tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])
# Gradio wiring: a PIL image goes in, the predicted class name comes out.
image_input = gr.Image(type="pil")
label_output = gr.Textbox()

demo = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    title="Mechanical Tools Classifier",
)

if __name__ == "__main__":
    # Start the local Gradio server.
    demo.launch()