import gradio as gr
from maevit import ViTForEmotionClassificationMLP, MAEViT
import torch
from torchvision import transforms
from PIL import Image

IMAGE_SIZE = 224
pt_model_path = 'MAE1.bin'                 # pretrained MAE weights
ft_model_path = 'EmotionClassifier1.bin'   # fine-tuned classifier weights

transform = transforms.Compose([
    transforms.CenterCrop(1024),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # ImageNet mean/std
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mae_model = MAEViT(
    image_size=224,
    patch_size=16,
    embed_dim=128,          # originally 128
    encoder_layers=2,       # originally 2
    encoder_heads=4,        # originally 4
    mlp_ratio=2.0,          # originally 2.0
    mask_ratio=0.75,
    decoder_embed_dim=64,   # originally 64
    decoder_layers=2,       # originally 2
    decoder_heads=4,        # originally 4
    dropout=0.1             # originally 0.1
)
mae_model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))
mae_model.eval()
mae_model.to(device)

ft_model = ViTForEmotionClassificationMLP(
    image_size=224,
    patch_size=16,
    embed_dim=128,
    encoder_layers=2,
    encoder_heads=4,
    mlp_ratio=2.0,
    dropout=0.1,
    num_classes=9,
)  # TODO check
ft_model.load_state_dict(torch.load(ft_model_path, map_location='cpu'))
ft_model.eval()
ft_model.to(device)

# Class index -> emotion label (9 classes)
yolo_mapping = {
    0: "Angry",
    1: "Contempt",
    2: "Disgust",
    3: "Fear",
    4: "Happy",
    5: "Natural",
    6: "Sad",
    7: "Sleepy",
    8: "Surprised"
}

def mae_reconstruct(image: Image.Image, mask_ratio=0.75):
    img = transform(image).unsqueeze(0).to(device)
    mae_model.mask_ratio = mask_ratio
    with torch.no_grad():
        x_enc, mask, ids_restore = mae_model.forward_encoder(img)
        x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
    recon_patches = x_rec_patches[:, 1:, :]   # exclude the CLS token
    orig_patches = mae_model.patchify(img)    # [1, num_patches, patch_dim]
    mask_bool = mask.unsqueeze(-1).to(torch.bool)  # True where the patch was masked
    # Hybrid: keep the original pixels for unmasked patches, use the reconstruction for masked ones
    hybrid_patches = torch.where(mask_bool, recon_patches, orig_patches)  # [1, N, patch_dim]
    img_hybrid = mae_model.unpatchify(hybrid_patches)  # original with masked patches filled in
    img_masked = mae_model.unpatchify(orig_patches.masked_fill(mask_bool, 0))  # input with masked patches zeroed out
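    # Undo the Normalize transform for display: Normalize maps each channel
    # x -> (x - mean) / std, so its inverse is x -> x * std + mean, which can
    # itself be expressed as a Normalize with mean' = -mean/std and std' = 1/std.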
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
        std=[1 / s for s in (0.229, 0.224, 0.225)]
    )

    def to_img(tensor):
        img = tensor.squeeze(0).cpu()
        img = inv_normalize(img)
        return img.permute(1, 2, 0).clamp(0, 1).numpy()

    masked_np = to_img(img_masked)
    hybrid_np = to_img(img_hybrid)
    return masked_np, hybrid_np

def classify(image: Image.Image):
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = ft_model(img)
        probs = logits.softmax(dim=-1)
    predicted_class = probs.argmax(dim=-1).item()
    return yolo_mapping[int(predicted_class)]

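# The masking ratio arrives as free text from the gr.Textbox below, so a
# malformed entry would crash the app. This guard is a minimal sketch
# (parse_mask_ratio is a hypothetical helper, not part of the original Space):
# it falls back to the default on parse errors and clamps the value into
# [0.0, 1.0), matching the constraint stated in the UI description.
def parse_mask_ratio(raw, default=0.75):
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return default
    # Clamp: at least 0.0 and strictly below 1.0
    return min(max(value, 0.0), 0.99)
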
def predict(mask_ratio: str, image: Image.Image):
    """
    Takes a PIL image and a masking ratio; returns the masked image, the
    reconstructed image, and the predicted emotion label.
    """
    mask_ratio = parse_mask_ratio(mask_ratio)
    masked_image, re_image = mae_reconstruct(image, mask_ratio=mask_ratio)
    predicted_label = classify(image)
    return masked_image, re_image, predicted_label

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(value='0.75', label='Masking Ratio'),
        gr.Image(type='pil', label='Input Image')
    ],
    outputs=[
        gr.Image(type='numpy', label='Randomly Masked Image'),
        gr.Image(type='numpy', label='Reconstructed Image'),
        gr.Textbox(label='Predicted Emotion')
    ],
    title="Emotion Recognition and MAE Reconstruction",
    description="Upload an image to see the MAE reconstruction and the predicted emotion label. The masking ratio must be a decimal between 0.00 (inclusive) and 1.00 (exclusive)."
)

demo.launch(debug=True)