emilyseong committed on
Commit
0f69ec0
·
verified ·
1 Parent(s): 266f9b3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageOps
from torchvision import transforms
9
class LargeNet(nn.Module):
    """Small CNN for 8-way image classification.

    Two conv + max-pool stages feed two fully-connected layers; the
    flattened feature size (10 * 29 * 29) corresponds to 128x128 RGB input.
    """

    def __init__(self):
        super().__init__()
        self.name = "large"
        # 3-channel input -> 5 feature maps, then 5 -> 10, both with 5x5 kernels.
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        """Return class logits of shape [batch, 8] for input x of shape [batch, 3, 128, 128]."""
        out = self.pool(F.relu(self.conv1(x)))
        out = self.pool(F.relu(self.conv2(out)))
        out = out.view(-1, 10 * 29 * 29)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        # squeeze(1) is a no-op here (dim 1 has size 8, not 1); kept for
        # exact parity with the original implementation.
        return out.squeeze(1)
28
+
29
def preprocess_image(image, target_size=(128, 128)):
    """Prepare a PIL image for the classifier.

    Args:
        image: a PIL.Image in any mode; converted to RGB first.
        target_size: (width, height) of the output image, default (128, 128).

    Returns:
        numpy.ndarray of shape (height, width, 3), float values in [0, 1].
    """
    rgb = image.convert("RGB")
    # Crop-and-resize: ImageOps.fit fills target_size while keeping the
    # aspect ratio by center-cropping (it does not pad).
    fitted = ImageOps.fit(rgb, target_size, method=Image.BICUBIC, centering=(0.5, 0.5))
    # Scale 8-bit pixel values down to [0, 1].
    return np.array(fitted) / 255.0
41
+
42
# Weights location: the baked-in default is a developer-local absolute path,
# so deployments should override it with the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    "/Users/seong-eunseon/Library/Mobile Documents/com~apple~CloudDocs/Seong/1. Project/UT/수업/강화학습/model_large_bs64_lr0.001_epoch29",
)

model = LargeNet()
# map_location="cpu" lets a checkpoint saved on a GPU machine load on CPU-only hosts.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # inference mode: fixes batch-norm/dropout behavior
46
def classify_image(image_path):
    """Classify an image into one of eight mechanical-tool categories.

    Args:
        image_path: a PIL.Image supplied by the Gradio Image component
            (the name is historical — it receives an image object, not a path).

    Returns:
        The predicted class name as a string.
    """
    # Label order must match the logit order the network was trained with.
    classes = ["Gasoline_Can", "Hammer", "Pebbels", "pliers",
               "Rope", "Screw_Driver", "Toolbox", "Wrench"]
    image = preprocess_image(image_path)
    # HWC float array -> CHW tensor with a leading batch dimension.
    image_tensor = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).float()
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(image_tensor)
        _, predicted_class = torch.max(outputs, 1)
    return classes[predicted_class.item()]
57
+
58
+
59
# NOTE(review): this pipeline is defined but never used anywhere in this file —
# classify_image normalizes via preprocess_image instead. It also mixes in
# train-time augmentation (random flip/rotation/color jitter), which would be
# inappropriate at inference time; confirm whether this is leftover training
# code before wiring it in.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),  # Convert to PyTorch Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standardize
])
67
+
68
+
69
+
70
# classify_image('rope1.jpeg')
# Gradio interface: image in, predicted tool name out.
demo = gr.Interface(
    fn=classify_image,  # Classification function
    inputs=gr.Image(type="pil"),  # delivers a PIL.Image to classify_image
    outputs=gr.Textbox(),  # displays the predicted class name
    title="Mechanical Tools Classifier"
)
78
+
79
if __name__ == "__main__":
    demo.launch()  # Launch Gradio app (blocks, serving HTTP until interrupted)