# Image_to_emotion_model / inference.py
# Last commit 857d62b by viserion999: "Update inference.py" (verified)
import os
from pathlib import Path

import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms
# ResNet-50 backbone built without pretrained weights; all parameters come
# from the fine-tuned checkpoint loaded below.
model = models.resnet50()
# Swap the 1000-way ImageNet head for a 6-way classifier so parameter names and
# shapes match the checkpoint's state_dict.
model.fc = torch.nn.Linear(model.fc.in_features, 6)
model_path = "resnet50_c3_lr3e-04_bs32_aug_heavy_opt_adam_drop0.5_ls0.1_6class.pth"

# Look for the checkpoint next to this file first, so importing the module from
# a different working directory still works; otherwise fall back to the
# original CWD-relative path.
_module_dir = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
_checkpoint_path = _module_dir / model_path
if not _checkpoint_path.is_file():
    _checkpoint_path = Path(model_path)

# torch.load is pickle-based; weights_only=True restricts unpickling to tensors,
# primitive types and plain containers, so a tampered .pth file cannot run
# arbitrary code on load. A state_dict needs nothing more.
# (Requires torch >= 1.13; it is the default from torch 2.6.)
model.load_state_dict(
    torch.load(_checkpoint_path, map_location="cpu", weights_only=True)
)
# Evaluation mode: BatchNorm layers use their running statistics instead of
# per-batch statistics.
model.eval()
# Class names indexed by output-logit position: `predict` returns labels[i]
# for the argmax index i.
# NOTE(review): this order must match the class indices used when the
# checkpoint was trained. torchvision's ImageFolder assigns indices in sorted
# folder-name order (angry, fear, happy, neutral, sad, surprise), which differs
# from the order below for the last three classes. Confirm against the
# training script.
labels = [
    "angry",
    "fear",
    "happy",
    "sad",
    "surprise",
    "neutral",
]

# Preprocessing pipeline:
# - resize to exactly 224x224 (aspect ratio is not preserved);
# - convert the PIL image to a float tensor of shape (C, H, W) in [0, 1].
# NOTE(review): no transforms.Normalize step is applied. If training used
# ImageNet mean/std normalization, inference inputs are distributed
# differently. Confirm against the training pipeline.
_RESIZE_TO = (224, 224)
transform = transforms.Compose(
    [
        transforms.Resize(_RESIZE_TO),
        transforms.ToTensor(),
    ]
)
def _as_rgb_image(image):
    """Return *image* as a 3-channel RGB PIL image.

    Accepts a PIL image, a filesystem path (str or os.PathLike) to an image
    file, or an array accepted by ``PIL.Image.fromarray`` (e.g. a uint8
    HxWxC NumPy array).
    """
    if isinstance(image, (str, os.PathLike)):
        # The context manager closes the file handle; convert() decodes the
        # pixels into a new image that does not depend on the open file.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    if not isinstance(image, Image.Image):
        # transforms.Resize rejects raw arrays, so wrap them as PIL first.
        image = Image.fromarray(image)
    # ResNet-50's first convolution expects 3 input channels. Grayscale ("L"),
    # palette ("P") or RGBA images would otherwise become 1- or 4-channel
    # tensors and crash. This mirrors torchvision's default image loader.
    return image.convert("RGB")


def predict(image):
    """Classify the emotion shown in *image*.

    Args:
        image: A PIL image (any mode), a path to an image file, or an array
            accepted by ``PIL.Image.fromarray``.

    Returns:
        dict: ``{"emotion": name}`` where ``name`` is the entry of the
        module-level ``labels`` list at the highest-scoring logit index.
    """
    rgb = _as_rgb_image(image)
    # Add a leading batch dimension of size 1 for the model.
    batch = transform(rgb).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    pred = torch.argmax(logits, 1).item()
    return {"emotion": labels[pred]}