Spaces:
Runtime error
Runtime error
| # import the necessary packages | |
| from tensorflow.keras import layers | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| import math | |
class PlotAttention:
    """Plot an image beside the same image overlaid with the model's attention map.

    The wrapped ``model`` must expose ``stem``, ``trunk`` and
    ``attention_pooling`` callables. ``attention_pooling`` must return a
    2-tuple whose second element holds the attention weights that are drawn.
    """

    def __init__(self, model):
        """Store the model whose attention-pooling weights will be visualized."""
        self.model = model

    @staticmethod
    def _grid_side(num_patches):
        """Return the side length of the square grid holding *num_patches* cells.

        Raises:
            ValueError: if ``num_patches`` is not a positive perfect square, since
                the weights could then not be laid out as a square heatmap.
        """
        num_patches = int(num_patches)
        if num_patches <= 0:
            raise ValueError(
                f"expected a positive number of attention weights, got {num_patches}"
            )
        # round() rather than int() so a float result like 6.9999999 still maps to 7.
        side = int(round(math.sqrt(num_patches)))
        if side * side != num_patches:
            raise ValueError(
                f"{num_patches} attention weights cannot be arranged in a square grid"
            )
        return side

    @staticmethod
    def _preprocess(image):
        """Convert *image* to float32, resize it to 224x224 and prepend a batch axis."""
        # convert_image_dtype rescales integer images into [0, 1]; float input is kept as-is.
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize(image, (224, 224))
        return image[tf.newaxis, ...]

    def _attention_map(self, batch):
        """Run *batch* through stem -> trunk -> attention pooling; return square weight maps.

        The result is reshaped to ``(-1, side, side)``, where ``side * side`` equals
        the size of the weights' last axis.
        """
        features = self.model.trunk(self.model.stem(batch))
        _, weights = self.model.attention_pooling(features)
        # NOTE(review): this extra leading axis is kept from the original code; it is
        # harmless for the reshape below but its intent is unclear — confirm.
        weights = weights[tf.newaxis, ...]
        side = self._grid_side(tf.shape(weights)[-1])
        # Equivalent to keras.layers.Reshape((side, side)) whenever that reshape would
        # succeed, without building a new layer object on every call.
        return tf.reshape(weights, (-1, side, side))

    def __call__(self, image):
        """Draw the original image and its attention overlay side by side.

        Args:
            image: image tensor or array accepted by ``tf.image.convert_image_dtype``;
                presumably (height, width, channels) — TODO confirm with callers.

        Returns:
            The ``matplotlib.pyplot`` module, kept for backward compatibility. The
            drawn figure is its current figure (``plt.gcf()``).

        Raises:
            ValueError: if the attention weights cannot be arranged in a square grid.
        """
        batch = self._preprocess(image)
        attention = self._attention_map(batch)

        # Only the first image / first attention map is visualized.
        selected_image = batch[0]
        selected_weight = attention[0]

        _, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
        ax[0].imshow(selected_image)
        ax[0].set_title("Original")
        ax[0].axis("off")

        img = ax[1].imshow(selected_image)
        # Reusing the image's extent stretches the coarse patch grid over the full image.
        ax[1].imshow(
            selected_weight, cmap="inferno", alpha=0.6, extent=img.get_extent()
        )
        ax[1].set_title("Attended")
        ax[1].axis("off")
        return plt