prekshyam committed on
Commit 3114d75 · verified · 1 Parent(s): 6389c8a
Files changed (1)
  1. app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
+ import gradio as gr
+ from maevit import ViTForEmotionClassificationMLP, MAEViT
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+
+ IMAGE_SIZE = 224
+
+ # checkpoint paths: pretrained MAE and fine-tuned emotion classifier
+ pt_model_path = 'MAE1.bin'
+ ft_model_path = 'EmotionClassifier1.bin'
+
+ # resize to 224x224, convert to tensor, normalize with ImageNet statistics
+ transform = transforms.Compose([
+     transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+ ])
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # masked autoencoder (MAE) used to reconstruct randomly masked patches
+ mae_model = MAEViT(
+     image_size=224,
+     patch_size=16,
+     embed_dim=128,
+     encoder_layers=2,
+     encoder_heads=4,
+     mlp_ratio=2.0,
+     mask_ratio=0.75,
+     decoder_embed_dim=64,
+     decoder_layers=2,
+     decoder_heads=4,
+     dropout=0.1
+ )
+ mae_model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))
+ mae_model.eval()
+ mae_model.to(device)
+
+ # ViT encoder with an MLP classification head, fine-tuned for 9 emotion classes
+ ft_model = ViTForEmotionClassificationMLP(
+     image_size=224,
+     patch_size=16,
+     embed_dim=128,
+     encoder_layers=2,
+     encoder_heads=4,
+     mlp_ratio=2.0,
+     dropout=0.1,
+     num_classes=9,
+ )
+
+ ft_model.load_state_dict(torch.load(ft_model_path, map_location='cpu'))
+ ft_model.eval()
+ ft_model.to(device)
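+
+ # both models stay in eval mode, so the dropout layers configured above are
+ # inactive at inference time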
+
+ # class index -> predicted emotion label
+ yolo_mapping = {
+     0: "Angry",
+     1: "Contempt",
+     2: "Disgust",
+     3: "Fear",
+     4: "Happy",
+     5: "Natural",
+     6: "Sad",
+     7: "Sleepy",
+     8: "Surprised"
+ }
+
+ def mae_reconstruct(image: Image.Image):
+     """Mask the input image, reconstruct it with the MAE, and return both views as numpy arrays."""
+     img = transform(image).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         x_enc, mask, ids_restore = mae_model.forward_encoder(img)
+         x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
+
+     img_rec = mae_model.unpatchify(x_rec_patches[:, 1:, :])  # exclude CLS token -> [1, 3, 224, 224]
+     img_patches = mae_model.patchify(img)  # [1, num_patches, patch_dim]
+
+     # zero out the patches hidden from the encoder (mask == 1 marks a masked patch)
+     mask = mask.unsqueeze(-1).to(torch.bool)  # [1, num_patches, 1]
+     masked_patches = img_patches.masked_fill(mask, 0)
+     img_masked = mae_model.unpatchify(masked_patches)  # [1, 3, 224, 224]
+
+     # invert the ImageNet normalization for display
+     inv_normalize = transforms.Normalize(
+         mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
+         std=[1 / s for s in (0.229, 0.224, 0.225)]
+     )
+
+     def to_img(tensor):
+         img = inv_normalize(tensor.squeeze(0).cpu())
+         return img.permute(1, 2, 0).clamp(0, 1).numpy()
+
+     return to_img(img_masked), to_img(img_rec)
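+
+ # Note: the reconstruction returned above is the raw decoder output for all
+ # patches (visible patches are not pasted back in), assuming the usual MAE
+ # convention that mask == 1 marks patches hidden from the encoder.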
+
+ def classify(image: Image.Image):
+     """Return the predicted emotion label for a PIL image."""
+     img = transform(image).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         logits = ft_model(img)
+         probs = logits.softmax(dim=-1)
+         predicted_class = probs.argmax(dim=-1).item()
+
+     return yolo_mapping[predicted_class]
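+
+ # usage sketch (hypothetical file path):
+ #   classify(Image.open('face.jpg').convert('RGB'))  # -> e.g. "Happy"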
+
+ def predict(image: Image.Image):
+     """Take a PIL image; return the masked view, the MAE reconstruction, and the predicted emotion label."""
+     masked_image, re_image = mae_reconstruct(image)
+     predicted_label = classify(image)
+     return masked_image, re_image, predicted_label
+
+
+ gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type='pil', label='Input Image'),
+     outputs=[
+         gr.Image(type='numpy', label='Randomly Masked Image'),
+         gr.Image(type='numpy', label='Reconstructed Image'),
+         gr.Textbox(label='Predicted Emotion')
+     ],
+     title="Emotion Recognition and MAE Reconstruction",
+     description="Upload an image to see the reconstructed image (by MAE) and the predicted emotion label."
+ ).launch(share=True, debug=True)
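+ # share=True asks Gradio for a temporary public URL; debug=True keeps server
+ # errors visible in the console while the demo runs.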