import gradio as gr
import torch
from maevit import ViTForEmotionClassificationMLP, MAEViT
from PIL import Image
from torchvision import transforms

IMAGE_SIZE = 224
pt_model_path = 'MAE1.bin'
ft_model_path = 'EmotionClassifier1.bin'

# Center-crop, resize to the ViT input size, and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.CenterCrop(1024),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mae_model = MAEViT(
    image_size=224,
    patch_size=16,
    embed_dim=128,           # originally 128
    encoder_layers=2,        # originally 2
    encoder_heads=4,         # originally 4
    mlp_ratio=2.0,           # originally 2.0
    mask_ratio=0.75,
    decoder_embed_dim=64,    # originally 64
    decoder_layers=2,        # originally 2
    decoder_heads=4,         # originally 4
    dropout=0.1              # originally 0.1
)
mae_model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))
mae_model.eval()
mae_model.to(device)

ft_model = ViTForEmotionClassificationMLP(
    image_size=224,
    patch_size=16,
    embed_dim=128,
    encoder_layers=2,
    encoder_heads=4,
    mlp_ratio=2.0,
    dropout=0.1,
    num_classes=9,
)  # TODO check
ft_model.load_state_dict(torch.load(ft_model_path, map_location='cpu'))
ft_model.eval()
ft_model.to(device)

yolo_mapping = {
    0: "Angry", 1: "Contempt", 2: "Disgust", 3: "Fear", 4: "Happy",
    5: "Natural", 6: "Sad", 7: "Sleepy", 8: "Surprised"
}

# Undo the ImageNet normalization so tensors can be displayed as images.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
    std=[1 / s for s in (0.229, 0.224, 0.225)]
)


def mae_reconstruct(image: Image.Image, mask_ratio: float = 0.75):
    img = transform(image).unsqueeze(0).to(device)
    mae_model.mask_ratio = mask_ratio  # override the ratio used by the encoder's random masking
    with torch.no_grad():
        x_enc, mask, ids_restore = mae_model.forward_encoder(img)
        x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
    recon_patches = x_rec_patches[:, 1:, :]        # exclude CLS token
    orig_patches = mae_model.patchify(img)         # [1, num_patches, patch_dim]
    mask_bool = mask.unsqueeze(-1).to(torch.bool)  # True means that patch was masked
    # Hybrid: unmasked patches keep the original pixels, masked patches use the reconstruction.
    hybrid_patches = torch.where(mask_bool, recon_patches, orig_patches)  # [1, N, patch_dim]
    # Original image with masked patches filled in from the reconstruction.
    img_hybrid = mae_model.unpatchify(hybrid_patches)
    # The masked input: original patches with masked positions zeroed out.
    img_masked = mae_model.unpatchify(orig_patches.masked_fill(mask_bool, 0))

    def to_img(tensor):
        img = tensor.squeeze(0).cpu()
        img = inv_normalize(img)
        return img.permute(1, 2, 0).clamp(0, 1).numpy()

    return to_img(img_masked), to_img(img_hybrid)


def classify(image: Image.Image):
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = ft_model(img)
    probs = logits.softmax(dim=-1)
    predicted_class = probs.argmax(dim=-1).item()
    return yolo_mapping[predicted_class]


def predict(mask_ratio: float, image: Image.Image):
    """Take a masking ratio and a PIL image; return the masked image,
    the MAE reconstruction, and the predicted emotion label."""
    mask_ratio = float(mask_ratio)
    masked_image, re_image = mae_reconstruct(image, mask_ratio=mask_ratio)
    predicted_label = classify(image)
    return masked_image, re_image, predicted_label


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(value='0.75', label='masking ratio'),
        gr.Image(type='pil', label='Input Image')
    ],
    outputs=[
        gr.Image(type='numpy', label='Randomly Masked Image'),
        gr.Image(type='numpy', label='Reconstructed Image'),
        gr.Textbox(label='Predicted Emotion')
    ],
    title="Emotion Recognition and MAE Reconstruction",
    description="Upload an image to see the reconstructed image (by MAE) and the predicted "
                "emotion label. The masking ratio must be a decimal number greater than or "
                "equal to 0.00 and less than 1.00."
)
demo.launch(debug=True)
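# A minimal headless sketch (hypothetical image path 'face.jpg', not part of the
# original demo): the same pipeline can be exercised without the Gradio UI by
# calling predict() directly on a PIL image. Left commented out because
# demo.launch() above blocks until the server is shut down.
#
#   sample = Image.open('face.jpg').convert('RGB')
#   masked_np, hybrid_np, label = predict(0.5, sample)
#   print(label)  # one of the nine classes in yolo_mapping, e.g. "Happy"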