# Hugging Face Space by emilyseong — "Update app.py", commit 89ac582 (verified).
import gradio as gr
from PIL import Image, ImageOps
import numpy as np
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
class LargeNet(nn.Module):
    """CNN classifier for 3x128x128 RGB images producing 8 class logits.

    Architecture: two conv+pool stages followed by two fully connected
    layers.  Spatial sizes: 128 -> conv1(5x5) -> 124 -> pool -> 62 ->
    conv2(5x5) -> 58 -> pool -> 29, hence the 10 * 29 * 29 flattened size.
    """

    def __init__(self):
        super(LargeNet, self).__init__()
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        """Return raw class logits of shape [batch_size, 8]."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 10 * 29 * 29)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # NOTE: the original called x.squeeze(1) here, commented as
        # "Flatten to [batch_size]" — but dim 1 has size 8 (the class
        # count), so squeeze(1) was a no-op and the comment was wrong.
        # Removed; output stays [batch_size, 8] as before.
        return x
def preprocess_image(image, target_size=(128, 128)):
    """Convert a PIL image to a normalized float array for the model.

    Args:
        image: a PIL.Image in any mode; converted to RGB internally.
        target_size: (width, height) the image is fit to, default (128, 128).

    Returns:
        numpy float array of shape (target_size[1], target_size[0], 3)
        with pixel values scaled to [0, 1].
    """
    image = image.convert("RGB")
    # Crop-and-resize so the aspect ratio is preserved (center crop,
    # bicubic resample) instead of distorting the image.
    image = ImageOps.fit(image, target_size, method=Image.BICUBIC, centering=(0.5, 0.5))
    # Scale 8-bit pixel values into [0, 1].
    return np.array(image) / 255.0
# Instantiate the network and load the trained checkpoint.
model = LargeNet()
# map_location="cpu" lets a checkpoint saved on a GPU machine load on
# CPU-only hosts (e.g. a free Hugging Face Space); without it torch.load
# would try to deserialize tensors onto the original CUDA device.
model.load_state_dict(torch.load("./model_large_bs64_lr0.001_epoch29", map_location="cpu"))
# Inference mode: fixes batch-norm statistics and disables dropout.
model.eval()
def classify_image(image_path):
    """Predict which of eight mechanical-tool classes an image shows.

    Args:
        image_path: a PIL.Image — despite the name, Gradio passes the
            already-loaded image object, not a filesystem path (the name
            is kept for caller compatibility).

    Returns:
        The predicted class name as a string.
    """
    classes = ["Gasoline_Can", "Hammer", "Pebbels", "pliers",
               "Rope", "Screw_Driver", "Toolbox", "Wrench"]
    image = preprocess_image(image_path)
    # HWC float64 array -> CHW float32 tensor with a leading batch dim.
    image_tensor = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
    with torch.no_grad():
        outputs = model(image_tensor)
        # Index of the largest logit along the class dimension.
        _, predicted_class = torch.max(outputs, 1)
    return classes[predicted_class.item()]
# NOTE(review): this training-time augmentation pipeline (random flips,
# rotation, color jitter, ImageNet normalization) is defined but never used
# anywhere in this file — inference goes through preprocess_image() instead.
# Kept for reference; confirm it is unneeded before deleting.
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.ToTensor(), # Convert to PyTorch Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Standardize
])
# Gradio UI wiring: one PIL-image input mapped through the classifier to a
# plain-text output box.
image_input = gr.Image(type="pil")
label_output = gr.Textbox()

demo = gr.Interface(
    fn=classify_image,          # classification callback
    inputs=image_input,
    outputs=label_output,
    title="Mechanical Tools Classifier",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()