Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from torch import nn | |
| import torchvision | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from configs import path_ckpt_fairface | |
| # code adapted from https://github.com/dchen236/FairFace | |
def init_fair_model(device, path_ckpt=None):
    """Build the FairFace ResNet-34 classifier and load its trained weights.

    Args:
        device: Device the model is moved to after loading (anything
            accepted by ``nn.Module.to``).
        path_ckpt: Path to a ``state_dict`` checkpoint. When ``None``, the
            project default ``path_ckpt_fairface`` from ``configs`` is used.

    Returns:
        The ResNet-34 model with an 18-output final layer, placed on
        ``device`` and switched to eval mode.
    """
    if path_ckpt is None:
        path_ckpt = path_ckpt_fairface

    # Every weight is overwritten by the checkpoint, so no pretrained weights
    # are requested. Calling with no arguments gives random init on both old
    # (``pretrained=``) and new (``weights=``) torchvision APIs without a
    # deprecation warning.
    model_fair_7 = torchvision.models.resnet34()
    # Replace the ImageNet head with an 18-way head; predict_race slices the
    # 18 outputs into 7 + 2 + 9 groups.
    model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)

    # Load onto CPU first so a checkpoint saved from a GPU process can still
    # be restored on a CPU-only host; the model is moved to `device` below.
    state_dict = torch.load(path_ckpt, map_location="cpu")
    model_fair_7.load_state_dict(state_dict)

    model_fair_7 = model_fair_7.to(device)
    model_fair_7.eval()
    return model_fair_7
# Per-channel statistics used to normalise inputs (the standard ImageNet
# values already used by the original code).
_FAIR_MEAN = (0.485, 0.456, 0.406)
_FAIR_STD = (0.229, 0.224, 0.225)
# Index -> label for the first 7 model outputs.
_FAIR_RACE_LABELS = (
    'White', 'Black', 'Latino_Hispanic', 'East Asian',
    'Southeast Asian', 'Indian', 'Middle Eastern',
)


def predict_race(model_fair_7, path_img, device):
    """Classify a single face image with the FairFace model.

    Args:
        model_fair_7: Callable model returning 18 logits for a
            ``(1, 3, 224, 224)`` input (see ``init_fair_model``).
        path_img: Either a ``str`` path to an image file, or a
            ``torch.Tensor`` of shape ``(1, 3, H, W)``.
            NOTE(review): the tensor is rescaled with ``x * 0.5 + 0.5``, so
            it is presumably expected in ``[-1, 1]`` (e.g. generator
            output) -- confirm against callers.
        device: Device the input is moved to before the forward pass.

    Returns:
        ``(race_label, race_pred, gender_pred, age_pred)``: the race label
        string followed by the arg-max indices of the race (outputs 0-6),
        gender (outputs 7-8) and age (outputs 9-17) groups.

    Raises:
        TypeError: If ``path_img`` is neither a ``str`` nor a tensor.
        ValueError: If a tensor input is not shaped ``(1, 3, H, W)``.
    """
    if isinstance(path_img, str):
        trans = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(_FAIR_MEAN), std=list(_FAIR_STD)),
        ])
        # Close the file handle promptly and force 3 channels so grayscale
        # or RGBA files do not break the (1, 3, 224, 224) model input.
        with Image.open(path_img) as img:
            image = trans(img.convert('RGB'))
        image = image.unsqueeze(0)  # add the batch dimension
    elif isinstance(path_img, torch.Tensor):
        if path_img.dim() != 4 or tuple(path_img.shape[:2]) != (1, 3):
            raise ValueError(
                'expected a tensor of shape (1, 3, H, W), got '
                + str(tuple(path_img.shape))
            )
        image = F.interpolate(path_img, (224, 224))
        image = image * 0.5 + 0.5
        # Normalise with explicit broadcasting so the result does not depend
        # on torchvision's Normalize accepting batched tensors.
        mean = torch.as_tensor(
            _FAIR_MEAN, dtype=image.dtype, device=image.device
        ).view(1, 3, 1, 1)
        std = torch.as_tensor(
            _FAIR_STD, dtype=image.dtype, device=image.device
        ).view(1, 3, 1, 1)
        image = (image - mean) / std
    else:
        raise TypeError(
            'path_img must be a str path or a torch.Tensor, got '
            + type(path_img).__name__
        )

    image = image.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model_fair_7(image)
    outputs = np.squeeze(outputs.detach().cpu().numpy())

    # Softmax is monotonic, so the arg-max of the raw logits equals the
    # arg-max of the probabilities -- and avoids exp() overflowing to
    # inf/nan for large logits, which would corrupt the result.
    race_pred = np.argmax(outputs[:7])
    gender_pred = np.argmax(outputs[7:9])
    age_pred = np.argmax(outputs[9:18])

    return _FAIR_RACE_LABELS[race_pred], race_pred, gender_pred, age_pred