Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import cv2 | |
| import os | |
| class Customcnn(nn.Module): | |
| def __init__(self,input_dim,num_classes): | |
| super(Customcnn,self).__init__() | |
| self.input_dim=input_dim | |
| self.num_classes=num_classes | |
| self.conv=nn.Sequential( | |
| nn.Conv2d(3,32,kernel_size=3,stride=1,padding=1), | |
| nn.BatchNorm2d(32), | |
| nn.ReLU(), | |
| nn.MaxPool2d(kernel_size=2,stride=2), | |
| nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1), | |
| nn.BatchNorm2d(64), | |
| nn.ReLU(), | |
| nn.MaxPool2d(kernel_size=2,stride=2), | |
| nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1), | |
| nn.BatchNorm2d(128), | |
| nn.ReLU(), | |
| nn.MaxPool2d(kernel_size=2,stride=2), | |
| nn.Conv2d(128,256,kernel_size=3,stride=1,padding=1), | |
| nn.BatchNorm2d(256), | |
| nn.ReLU(), | |
| nn.MaxPool2d(kernel_size=2,stride=2), | |
| ) | |
| self.toLinear= None | |
| self.get_cov_output(self.input_dim) | |
| self.fc=nn.Sequential( | |
| nn.Linear(self.toLinear,512), | |
| nn.ReLU(), | |
| nn.Linear(512,128), | |
| nn.ReLU(), | |
| nn.Linear(128,self.num_classes) | |
| ) | |
| def get_cov_output(self,input_dim): | |
| with torch.no_grad(): | |
| dummy=torch.zeros(1,3,input_dim,input_dim) | |
| output=self.conv(dummy) | |
| self.toLinear=output.view(1,-1).size(1) | |
| def forward(self,x): | |
| x=self.conv(x) | |
| x=x.view(x.size(0),-1) | |
| x=self.fc(x) | |
| return x | |
| class ImageClassifier(): | |
| def __init__(self,model_path): | |
| self.device=torch.device("mps" if torch.backends.mps.is_available() else "cpu") | |
| self.model=Customcnn(input_dim=128,num_classes=3).to(self.device) | |
| self.model.load_state_dict(torch.load(model_path,map_location=self.device)) | |
| self.model.eval() | |
| self.class_name={0:"Cat",1:"Dog",2:"person"} | |
| self.transform=transforms.Compose([ | |
| transforms.Resize((128,128)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) | |
| ] | |
| ) | |
| def predict(self,image_path): | |
| image=Image.open(image_path).convert("RGB") | |
| image_tensor=self.transform(image).unsqueeze(0).to(self.device) | |
| with torch.no_grad(): | |
| output=self.model(image_tensor) | |
| _,predicted=torch.max(output,1) | |
| label=self.class_name[predicted.item()] | |
| img=cv2.imread(image_path) | |
| cv2.putText(img,label,(10,30),cv2.FONT_HERSHEY_COMPLEX,1,(255,0,0),2) | |
| output_path="labeled_image.jpeg" | |
| cv2.imwrite(output_path,img) | |
| cwd=os.getcwd() | |
| output_path=os.path.join(cwd,output_path) | |
| return label,output_path |