# classification_cdp/core/predict.py
# Author: Kalp Kanungo
# Commit 9ce11ed: "First commit to huggingface"
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import cv2
import os
class Customcnn(nn.Module):
    """Small four-stage CNN classifier for square 3-channel images.

    The feature extractor is four ``Conv2d(3x3, stride 1, padding 1) ->
    BatchNorm2d -> ReLU -> MaxPool2d(2, 2)`` stages (3 -> 32 -> 64 -> 128 ->
    256 channels). It is followed by a three-layer fully connected head
    (``toLinear -> 512 -> 128 -> num_classes``) that returns raw,
    unnormalized class scores.

    Args:
        input_dim: Height and width of the square input images. It is used at
            construction time to size the first fully connected layer.
        num_classes: Number of output scores produced per sample.
    """

    def __init__(self, input_dim, num_classes):
        super(Customcnn, self).__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Number of features per sample after self.conv is flattened;
        # filled in by get_cov_output() before the head is built.
        self.toLinear = None
        self.get_cov_output(self.input_dim)
        self.fc = nn.Sequential(
            nn.Linear(self.toLinear, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, self.num_classes),
        )

    def get_cov_output(self, input_dim):
        """Compute and store the flattened size of the conv feature map.

        Pushes a zero tensor of shape ``(1, 3, input_dim, input_dim)`` through
        ``self.conv`` and records the per-sample feature count in
        ``self.toLinear``.

        The probe runs with the conv stack temporarily in eval mode. In
        training mode, BatchNorm would fold the dummy batch's statistics into
        its running_mean / running_var buffers (``torch.no_grad`` does not
        prevent that). The previous train/eval mode is restored afterwards.

        Args:
            input_dim: Height and width of the probe input.

        Returns:
            The flattened feature size (also stored on ``self.toLinear``).
        """
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 3, input_dim, input_dim)
                output = self.conv(dummy)
        finally:
            self.conv.train(was_training)
        self.toLinear = output.view(1, -1).size(1)
        return self.toLinear

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``.

        Assumes ``x`` is ``(batch, 3, input_dim, input_dim)``; other spatial
        sizes will not match the first Linear layer.
        """
        x = self.conv(x)
        # Flatten everything except the batch dimension; unlike view(), this
        # also works if the feature map is not contiguous.
        x = torch.flatten(x, 1)
        return self.fc(x)
class ImageClassifier:
    """Classify an image with a trained ``Customcnn`` and save a labeled copy.

    Args:
        model_path: Path to a saved ``state_dict`` for
            ``Customcnn(input_dim=128, num_classes=3)``.
    """

    def __init__(self, model_path):
        # Use the Apple MPS backend when PyTorch reports it available, else CPU.
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        self.model = Customcnn(input_dim=128, num_classes=3).to(self.device)
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources, or pass
        # weights_only=True on torch versions that support it.
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()
        # Output index -> label. Presumably this must match the label order
        # used at training time; confirm against the training script.
        self.class_name = {0: "Cat", 1: "Dog", 2: "person"}
        # Resize to the 128x128 input the model was built for, then scale
        # [0, 1] pixels to [-1, 1]. Assumed to mirror training preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def predict(self, image_path, output_path="labeled_image.jpeg"):
        """Classify ``image_path`` and write a copy with the label drawn on it.

        Args:
            image_path: Path of the image to classify.
            output_path: Where to write the annotated image. Relative paths
                are resolved against the current working directory.

        Returns:
            ``(label, absolute_output_path)``.

        Raises:
            OSError: If OpenCV fails to write the annotated image.
        """
        # Close the file handle deterministically; convert() loads the pixels.
        with Image.open(image_path) as pil_image:
            image = pil_image.convert("RGB")
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(image_tensor)
        predicted = int(logits.argmax(dim=1).item())
        label = self.class_name[predicted]

        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread returns None (it does not raise) for files it cannot
            # decode; fall back to the image PIL already decoded, as BGR.
            import numpy as np
            img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        cv2.putText(img, label, (10, 30), cv2.FONT_HERSHEY_COMPLEX, 1, (255, 0, 0), 2)
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(output_path, img):
            raise OSError(f"Failed to write labeled image to {output_path!r}")
        return label, os.path.abspath(output_path)