prekshyam committed on
Commit
ca3b12a
·
verified ·
1 Parent(s): e7311eb

Adds a textbox for the user to set the mask ratio; also includes a center crop before resizing.

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -11,6 +11,7 @@ pt_model_path = 'MAE1.bin'
11
  ft_model_path='EmotionClassifier1.bin'
12
 
13
  transform = transforms.Compose([
 
14
  transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
15
  transforms.ToTensor(),
16
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
@@ -64,10 +65,10 @@ yolo_mapping = {
64
  8: "Surprised"
65
  }
66
 
67
- def mae_reconstruct(image:Image, figure_name='figure/demo_temp.png'):
68
  img = transform(image).unsqueeze(0)
69
  img = img.to(device)
70
-
71
  with torch.no_grad():
72
  x_enc, mask, ids_restore = mae_model.forward_encoder(img)
73
  x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
@@ -125,11 +126,11 @@ def classify(image:Image):
125
 
126
  return predicted_labels
127
 
128
- def predict(image:Image):
129
  """
130
  takes PIL image and return reconstructed image and predicted emotion label
131
  """
132
-
133
  masked_image, re_image = mae_reconstruct(image, figure_name='figure/demo_temp.png')
134
  predicted_labels = classify(image)
135
 
@@ -145,7 +146,7 @@ demo = gr.Interface(
145
  gr.Textbox(label='Predicted Emotion')
146
  ],
147
  title="Emotion Recognition and MAE Reconstruction",
148
- description="Upload an image to see the reconstructed image (by MAE) and the predicted emotion label."
149
  )
150
 
151
  demo.launch(debug=True)
 
11
  ft_model_path='EmotionClassifier1.bin'
12
 
13
  transform = transforms.Compose([
14
+ transforms.CenterCrop(1024)
15
  transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
16
  transforms.ToTensor(),
17
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 
65
  8: "Surprised"
66
  }
67
 
68
+ def mae_reconstruct(image:Image, figure_name='figure/demo_temp.png', mask_ratio=0.75):
69
  img = transform(image).unsqueeze(0)
70
  img = img.to(device)
71
+ mae_model.mask_ratio = mask_ratio
72
  with torch.no_grad():
73
  x_enc, mask, ids_restore = mae_model.forward_encoder(img)
74
  x_rec_patches = mae_model.forward_decoder(x_enc, ids_restore)
 
126
 
127
  return predicted_labels
128
 
129
+ def predict(mask_ratio:float, image:Image):
130
  """
131
  takes PIL image and return reconstructed image and predicted emotion label
132
  """
133
+ mask_ratio = float(mask_ratio)
134
  masked_image, re_image = mae_reconstruct(image, figure_name='figure/demo_temp.png')
135
  predicted_labels = classify(image)
136
 
 
146
  gr.Textbox(label='Predicted Emotion')
147
  ],
148
  title="Emotion Recognition and MAE Reconstruction",
149
+ description="Upload an image to see the reconstructed image (by MAE) and the predicted emotion label. Please only enter a decimal number greater than or equal to 0.00 and less than 1.00."
150
  )
151
 
152
  demo.launch(debug=True)