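"""Gradio demo for the MAE emotion-recognition Space.

Loads a pretrained masked autoencoder (MAEViT) and a fine-tuned ViT emotion
classifier from the local `maevit` module, then serves an interface that shows
a randomly masked view of the uploaded image, the MAE reconstruction, and the
predicted emotion label. The checkpoints `MAE1.bin` and `EmotionClassifier1.bin`
must sit alongside this script.
"""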
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

from maevit import ViTForEmotionClassificationMLP, MAEViT

IMAGE_SIZE = 224
pt_model_path = 'MAE1.bin'                  # pretrained MAE checkpoint
ft_model_path = 'EmotionClassifier1.bin'    # fine-tuned classifier checkpoint
transform = transforms.Compose([
    transforms.CenterCrop(1024),                  # fixed square crop (pads smaller images) before resizing
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # ImageNet statistics
])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
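# NOTE: the constructor arguments below must match the configuration the
# checkpoint was trained with (the "# originally ..." notes record those
# values); otherwise load_state_dict will fail on shape mismatches.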
mae_model = MAEViT(
    image_size=224,
    patch_size=16,
    embed_dim=128,           # originally 128
    encoder_layers=2,        # originally 2
    encoder_heads=4,         # originally 4
    mlp_ratio=2.0,           # originally 2.0
    mask_ratio=0.75,
    decoder_embed_dim=64,    # originally 64
    decoder_layers=2,        # originally 2
    decoder_heads=4,         # originally 4
    dropout=0.1              # originally 0.1
)
mae_model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))
mae_model.eval()
mae_model.to(device)
ft_model = ViTForEmotionClassificationMLP(
    image_size=224,
    patch_size=16,
    embed_dim=128,
    encoder_layers=2,
    encoder_heads=4,
    mlp_ratio=2.0,
    dropout=0.1,
    num_classes=9,
)
ft_model.load_state_dict(torch.load(ft_model_path, map_location='cpu'))
ft_model.eval()
ft_model.to(device)
yolo_mapping = {
    0: "Angry",
    1: "Contempt",
    2: "Disgust",
    3: "Fear",
    4: "Happy",
    5: "Natural",
    6: "Sad",
    7: "Sleepy",
    8: "Surprised"
}
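# How the reconstruction below works (standard MAE convention, which the
# `maevit` module appears to follow): forward_encoder randomly drops
# `mask_ratio` of the patches and encodes only the visible ones, returning a
# binary `mask` (1 = patch was masked) and `ids_restore` to undo the shuffle;
# forward_decoder fills the masked slots with a learned mask token and
# predicts pixel values for every patch.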
def mae_reconstruct(image: Image.Image, mask_ratio=0.75):
    """Run the MAE on one image; return (masked view, hybrid reconstruction) as HWC numpy arrays."""
    img = transform(image).unsqueeze(0).to(device)
    mae_model.mask_ratio = mask_ratio
    with torch.no_grad():
        x_enc, mask, ids_restore = mae_model.forward_encoder(img)
        x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
    recon_patches = x_rec_patches[:, 1:, :]        # exclude CLS token
    orig_patches = mae_model.patchify(img)         # [1, num_patches, patch_dim]
    mask_bool = mask.unsqueeze(-1).to(torch.bool)  # True means that patch was masked
    # hybrid: keep the original pixels for visible patches, use the reconstruction for masked ones
    hybrid_patches = torch.where(mask_bool, recon_patches, orig_patches)  # [1, N, patch_dim]
    img_hybrid = mae_model.unpatchify(hybrid_patches)
    img_masked = mae_model.unpatchify(orig_patches.masked_fill(mask_bool, 0))  # input with masked patches zeroed
    # (a pure-reconstruction view, mae_model.unpatchify(recon_patches), could be shown as well)

    # undo the ImageNet normalization for display
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
        std=[1 / s for s in (0.229, 0.224, 0.225)]
    )

    def to_img(tensor):
        img = tensor.squeeze(0).cpu()
        img = inv_normalize(img)
        return img.permute(1, 2, 0).clamp(0, 1).numpy()

    return to_img(img_masked), to_img(img_hybrid)
def classify(image: Image.Image):
    """Predict an emotion label for one PIL image with the fine-tuned classifier."""
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = ft_model(img)
    probs = logits.softmax(dim=-1)
    predicted_class = probs.argmax(dim=-1).item()
    return yolo_mapping[int(predicted_class)]
def predict(mask_ratio: float, image: Image.Image):
    """Take a masking ratio and a PIL image; return the masked image,
    the MAE reconstruction, and the predicted emotion label."""
    mask_ratio = min(max(float(mask_ratio), 0.0), 0.99)  # keep within [0.00, 1.00) as the UI requests
    masked_image, re_image = mae_reconstruct(image, mask_ratio=mask_ratio)
    predicted_label = classify(image)
    return masked_image, re_image, predicted_label
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(value='0.75', label='masking ratio'),
        gr.Image(type='pil', label='Input Image')
    ],
    outputs=[
        gr.Image(type='numpy', label='Randomly Masked Image'),
        gr.Image(type='numpy', label='Reconstructed Image'),
        gr.Textbox(label='Predicted Emotion')
    ],
    title="Emotion Recognition and MAE Reconstruction",
    description="Upload an image to see the MAE reconstruction and the predicted emotion label. "
                "Enter the masking ratio as a decimal greater than or equal to 0.00 and less than 1.00."
)

demo.launch(debug=True)
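# A minimal remote-usage sketch (an assumption, not part of the original app):
# with the Space running, the `gradio_client` package can call the same
# endpoint. Exact file handling varies by version (newer releases wrap local
# paths with gradio_client.handle_file):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   masked, recon, label = client.predict("0.75", "face.jpg", api_name="/predict")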