# EmpathNet / app.py
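# Gradio demo: mask an uploaded image with a ViT-based masked autoencoder (MAE),
# show the reconstruction, and classify the emotion with a fine-tuned ViT.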
import gradio as gr
from maevit import ViTForEmotionClassificationMLP, MAEViT
import torch
from torchvision import transforms
from PIL import Image
IMAGE_SIZE = 224
pt_model_path = 'MAE1.bin'  # pretrained MAE weights
ft_model_path = 'EmotionClassifier1.bin'  # fine-tuned classifier weights
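# Preprocessing: center-crop to 1024 px, resize to the model input size, and
# normalize with the standard ImageNet mean/std.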
transform = transforms.Compose([
    transforms.CenterCrop(1024),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
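# Masked autoencoder (MAE) ViT used for reconstruction; the hyperparameters
# must match those used at pretraining time for the checkpoint to load.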
mae_model = MAEViT(
    image_size=224,
    patch_size=16,
    embed_dim=128,         # originally 128
    encoder_layers=2,      # originally 2
    encoder_heads=4,       # originally 4
    mlp_ratio=2.0,         # originally 2.0
    mask_ratio=0.75,
    decoder_embed_dim=64,  # originally 64
    decoder_layers=2,      # originally 2
    decoder_heads=4,       # originally 4
    dropout=0.1            # originally 0.1
)
mae_model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))
mae_model.eval()
mae_model.to(device)
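# ViT encoder with an MLP head, fine-tuned to classify 9 emotions.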
ft_model = ViTForEmotionClassificationMLP(
    image_size=224,
    patch_size=16,
    embed_dim=128,
    encoder_layers=2,
    encoder_heads=4,
    mlp_ratio=2.0,
    dropout=0.1,
    num_classes=9,
)  # TODO check
ft_model.load_state_dict(torch.load(ft_model_path, map_location='cpu'))
ft_model.eval()
ft_model.to(device)
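# Maps the classifier's output index to a human-readable emotion label.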
yolo_mapping = {
    0: "Angry",
    1: "Contempt",
    2: "Disgust",
    3: "Fear",
    4: "Happy",
    5: "Natural",
    6: "Sad",
    7: "Sleepy",
    8: "Surprised"
}
def mae_reconstruct(image: Image.Image, mask_ratio=0.75):
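    """
    Randomly mask a fraction of the image's patches, reconstruct them with the
    MAE, and return (masked image, hybrid image) as HWC numpy arrays in [0, 1].
    The hybrid keeps the original unmasked patches and fills the masked ones
    with the reconstruction.
    """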
    img = transform(image).unsqueeze(0)
    img = img.to(device)
    mae_model.mask_ratio = mask_ratio
    with torch.no_grad():
        x_enc, mask, ids_restore = mae_model.forward_encoder(img)
        x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
    recon_patches = x_rec_patches[:, 1:, :]  # exclude the CLS token
    orig_patches = mae_model.patchify(img)  # [1, num_patches, patch_dim]
    mask_bool = mask.unsqueeze(-1).to(torch.bool)  # True where the patch was masked
    # hybrid: keep original unmasked patches, use the reconstruction for masked ones
    hybrid_patches = torch.where(mask_bool, recon_patches, orig_patches)  # [1, N, patch_dim]
    img_hybrid = mae_model.unpatchify(hybrid_patches)  # original with masked parts filled in
    img_masked = mae_model.unpatchify(orig_patches.masked_fill(mask_bool, 0))  # masked input
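    # Undo the ImageNet normalization for display: x = x_norm * std + mean,
    # implemented as Normalize(mean=-m/s, std=1/s).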
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
        std=[1 / s for s in (0.229, 0.224, 0.225)]
    )
    def to_img(tensor):
        # drop the batch dim, denormalize, and convert to an HWC array in [0, 1]
        img = tensor.squeeze(0).cpu()
        img = inv_normalize(img)
        img = img.permute(1, 2, 0).clamp(0, 1).numpy()
        return img
    masked_np = to_img(img_masked)
    hybrid_np = to_img(img_hybrid)
    return masked_np, hybrid_np
def classify(image: Image.Image):
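    """Run the fine-tuned classifier on a PIL image and return the emotion label."""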
    img = transform(image).unsqueeze(0)
    img = img.to(device)
    with torch.no_grad():
        logits = ft_model(img)
    probs = logits.softmax(dim=-1)
    predicted_class = probs.argmax(dim=-1).item()
    predicted_label = yolo_mapping[int(predicted_class)]
    return predicted_label
def predict(mask_ratio: float, image: Image.Image):
    """
    Take a masking ratio and a PIL image; return the masked image, the MAE
    reconstruction, and the predicted emotion label.
    """
    mask_ratio = float(mask_ratio)
    masked_image, re_image = mae_reconstruct(image, mask_ratio=mask_ratio)
    predicted_label = classify(image)
    return masked_image, re_image, predicted_label
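# Gradio UI: a text box for the masking ratio, an image upload, and three
# outputs (masked image, reconstruction, predicted emotion).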
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(value='0.75', label='masking ratio'),
        gr.Image(type='pil', label='Input Image')
    ],
    outputs=[
        gr.Image(type='numpy', label='Randomly Masked Image'),
        gr.Image(type='numpy', label='Reconstructed Image'),
        gr.Textbox(label='Predicted Emotion')
    ],
    title="Emotion Recognition and MAE Reconstruction",
    description="Upload an image to see the MAE reconstruction and the predicted emotion label. The masking ratio must be a decimal greater than or equal to 0.00 and less than 1.00."
)
demo.launch(debug=True)