Spaces:
Sleeping
Sleeping
Adds a textbox for user to determine mask ratio; also includes center crop before resizing.
Browse files
app.py
CHANGED
|
@@ -11,6 +11,7 @@ pt_model_path = 'MAE1.bin'
|
|
| 11 |
ft_model_path='EmotionClassifier1.bin'
|
| 12 |
|
| 13 |
transform = transforms.Compose([
|
|
|
|
| 14 |
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 15 |
transforms.ToTensor(),
|
| 16 |
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
|
@@ -64,10 +65,10 @@ yolo_mapping = {
|
|
| 64 |
8: "Surprised"
|
| 65 |
}
|
| 66 |
|
| 67 |
-
def mae_reconstruct(image:Image, figure_name='figure/demo_temp.png'):
|
| 68 |
img = transform(image).unsqueeze(0)
|
| 69 |
img = img.to(device)
|
| 70 |
-
|
| 71 |
with torch.no_grad():
|
| 72 |
x_enc, mask, ids_restore = mae_model.forward_encoder(img)
|
| 73 |
x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
|
|
@@ -125,11 +126,11 @@ def classify(image:Image):
|
|
| 125 |
|
| 126 |
return predicted_labels
|
| 127 |
|
| 128 |
-
def predict(image:Image):
|
| 129 |
"""
|
| 130 |
takes PIL image and return reconstructed image and predicted emotion label
|
| 131 |
"""
|
| 132 |
-
|
| 133 |
masked_image, re_image = mae_reconstruct(image, figure_name='figure/demo_temp.png')
|
| 134 |
predicted_labels = classify(image)
|
| 135 |
|
|
@@ -145,7 +146,7 @@ demo = gr.Interface(
|
|
| 145 |
gr.Textbox(label='Predicted Emotion')
|
| 146 |
],
|
| 147 |
title="Emotion Recognition and MAE Reconstruction",
|
| 148 |
-
description="Upload an image to see the reconstructed image (by MAE) and the predicted emotion label."
|
| 149 |
)
|
| 150 |
|
| 151 |
demo.launch(debug=True)
|
|
|
|
| 11 |
ft_model_path='EmotionClassifier1.bin'
|
| 12 |
|
| 13 |
transform = transforms.Compose([
|
| 14 |
+
transforms.CenterCrop(1024),
|
| 15 |
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
|
| 16 |
transforms.ToTensor(),
|
| 17 |
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
|
|
|
| 65 |
8: "Surprised"
|
| 66 |
}
|
| 67 |
|
| 68 |
+
def mae_reconstruct(image:Image, figure_name='figure/demo_temp.png', mask_ratio=0.75):
|
| 69 |
img = transform(image).unsqueeze(0)
|
| 70 |
img = img.to(device)
|
| 71 |
+
mae_model.mask_ratio = mask_ratio
|
| 72 |
with torch.no_grad():
|
| 73 |
x_enc, mask, ids_restore = mae_model.forward_encoder(img)
|
| 74 |
x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
|
|
|
|
| 126 |
|
| 127 |
return predicted_labels
|
| 128 |
|
| 129 |
+
def predict(mask_ratio:float, image:Image):
|
| 130 |
"""
|
| 131 |
takes PIL image and return reconstructed image and predicted emotion label
|
| 132 |
"""
|
| 133 |
+
mask_ratio = float(mask_ratio)
|
| 134 |
masked_image, re_image = mae_reconstruct(image, figure_name='figure/demo_temp.png', mask_ratio=mask_ratio)
|
| 135 |
predicted_labels = classify(image)
|
| 136 |
|
|
|
|
| 146 |
gr.Textbox(label='Predicted Emotion')
|
| 147 |
],
|
| 148 |
title="Emotion Recognition and MAE Reconstruction",
|
| 149 |
+
description="Upload an image to see the reconstructed image (by MAE) and the predicted emotion label. Please only enter a decimal number greater than or equal to 0.00 and less than 1.00."
|
| 150 |
)
|
| 151 |
|
| 152 |
demo.launch(debug=True)
|