import gradio as gr
from maevit import ViTForEmotionClassificationMLP, MAEViT
import torch
from torchvision import transforms
from PIL import Image

IMAGE_SIZE = 224

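# Paths to the pretrained MAE checkpoint (pt) and the fine-tuned emotion classifier checkpoint (ft)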
pt_model_path = 'MAE1.bin'
ft_model_path = 'EmotionClassifier1.bin'

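# Preprocessing: center-crop, resize to the ViT input size, and normalize with ImageNet statistics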
transform = transforms.Compose([
    transforms.CenterCrop(1024),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


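# Rebuild the MAE with the same hyperparameters used at pretraining time, then load its weights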
mae_model = MAEViT(
    image_size=224,
    patch_size=16,
    embed_dim=128, #originally 128
    encoder_layers=2, #originally 2
    encoder_heads=4, #originally 4
    mlp_ratio=2.0, #originally 2.0
    mask_ratio=0.75,
    decoder_embed_dim=64, #originally 64
    decoder_layers=2, #originally 2
    decoder_heads=4, #originally 4
    dropout=0.1 #originally 0.1
)
mae_model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))
mae_model.eval()
mae_model.to(device)

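# Classification model: ViT encoder with an MLP head, fine-tuned for the 9 emotion classes listed below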
ft_model = ViTForEmotionClassificationMLP(
    image_size=224,
    patch_size=16,
    embed_dim=128,
    encoder_layers=2,
    encoder_heads=4,
    mlp_ratio=2.0,
    dropout=0.1,
    num_classes=9,
)  # TODO: verify these hyperparameters match the fine-tuned checkpoint

ft_model.load_state_dict(torch.load(ft_model_path, map_location='cpu'))
ft_model.eval()
ft_model.to(device)


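# Class index -> emotion label mapping for the classifier outputs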
yolo_mapping = {
    0: "Angry",
    1: "Contempt",
    2: "Disgust",
    3: "Fear",
    4: "Happy",
    5: "Natural",
    6: "Sad",
    7: "Sleepy",
    8: "Surprised"
}

def mae_reconstruct(image: Image.Image, mask_ratio=0.75):
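    """Mask the input at the given ratio, reconstruct it with the MAE, and return the masked image and the hybrid reconstruction as numpy arrays."""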
    img = transform(image).unsqueeze(0)
    img = img.to(device)
    mae_model.mask_ratio = mask_ratio
    with torch.no_grad():
        x_enc, mask, ids_restore = mae_model.forward_encoder(img)
        x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
    
    recon_patches = x_rec_patches[:, 1:, :]  # exclude CLS token
    orig_patches = mae_model.patchify(img)              # [1, num_patches, patch_dim]

    mask_bool = mask.unsqueeze(-1).to(torch.bool)  # True means that patch was masked

    # hybrid: unmasked keep original, masked use reconstruction
    hybrid_patches = torch.where(mask_bool, recon_patches, orig_patches)  # [1, N, patch_dim]

    img_rec = mae_model.unpatchify(recon_patches)      # reconstruction only (masked+unmasked predicted)
    img_hybrid = mae_model.unpatchify(hybrid_patches)  # original with masked parts filled in from reconstruction
    img_masked = mae_model.unpatchify(orig_patches.masked_fill(mask_bool, 0))  # masked input

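    # Invert the ImageNet normalization so the tensors can be displayed as images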
    inv_normalize = transforms.Normalize(
        mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
        std=[1 / s for s in (0.229, 0.224, 0.225)]
    )
    def to_img(tensor):
        img = tensor.squeeze(0).cpu()
        img = inv_normalize(img)
        img = img.permute(1,2,0).clamp(0,1).numpy()
        return img

    orig_np     = to_img(img)
    masked_np   = to_img(img_masked)
    recon_np    = to_img(img_rec)
    hybrid_np   = to_img(img_hybrid)
    
    return masked_np, hybrid_np

def classify(image: Image.Image):
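    """Run the fine-tuned classifier on a PIL image and return the predicted emotion label."""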
    img = transform(image).unsqueeze(0)
    img = img.to(device)

    with torch.no_grad():
        logits = ft_model(img)
        probs = logits.softmax(dim=-1)
        predicted_class = probs.argmax(dim=-1).item()
        predicted_labels = yolo_mapping[int(predicted_class)]

    return predicted_labels

def predict(mask_ratio: str, image: Image.Image):
    """
    Takes the masking ratio (as entered in the textbox) and a PIL image; returns the
    randomly masked image, the reconstructed image, and the predicted emotion label.
    """
    mask_ratio = float(mask_ratio)  # the ratio arrives as a string from the Gradio Textbox
    masked_image, re_image = mae_reconstruct(image, mask_ratio=mask_ratio)
    predicted_labels = classify(image)

    return masked_image, re_image, predicted_labels


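# Gradio UI: masking-ratio textbox and input image in; masked image, reconstruction, and emotion label out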
demo = gr.Interface(
    fn=predict,
    inputs=[gr.Textbox(value='0.75', label='Masking Ratio'), gr.Image(type='pil', label='Input Image')],
    outputs=[
        gr.Image(type='numpy', label='Randomly Masked Image'),
        gr.Image(type='numpy', label='Reconstructed Image'),
        gr.Textbox(label='Predicted Emotion')
    ],
    title="Emotion Recognition and MAE Reconstruction",
    description="Upload an image to see the reconstruction produced by the MAE and the predicted emotion label. The masking ratio must be a decimal number greater than or equal to 0.00 and less than 1.00."
)

demo.launch(debug=True)