Spaces:
Runtime error
Runtime error
chore: house cleaning
Browse files- app.py +39 -9
- utilities/model.py +0 -30
- utilities/visualization.py +0 -45
app.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
# import the necessary packages
|
| 2 |
from utilities import config
|
| 3 |
from utilities import model
|
| 4 |
-
from utilities import visualization
|
| 5 |
from tensorflow import keras
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
import gradio as gr
|
| 7 |
|
| 8 |
# load the models from disk
|
|
@@ -19,15 +22,42 @@ conv_attn = keras.models.load_model(
|
|
| 19 |
compile=False
|
| 20 |
)
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
# get the plot attention function
|
| 30 |
-
plot_attention = visualization.PlotAttention(model=patch_conv_net)
|
| 31 |
iface = gr.Interface(
|
| 32 |
fn=plot_attention,
|
| 33 |
inputs=[gr.inputs.Image(label="Input Image")],
|
|
|
|
| 1 |
# import the necessary packages
|
| 2 |
from utilities import config
|
| 3 |
from utilities import model
|
|
|
|
| 4 |
from tensorflow import keras
|
| 5 |
+
from tensorflow.keras import layers
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import math
|
| 9 |
import gradio as gr
|
| 10 |
|
| 11 |
# load the models from disk
|
|
|
|
| 22 |
compile=False
|
| 23 |
)
|
| 24 |
|
| 25 |
+
def plot_attention(image):
|
| 26 |
+
# resize the image to a 224, 224 dim
|
| 27 |
+
image = tf.image.convert_image_dtype(image, tf.float32)
|
| 28 |
+
image = tf.image.resize(image, (224, 224))
|
| 29 |
+
image = image[tf.newaxis, ...]
|
| 30 |
+
|
| 31 |
+
# pass through the stem
|
| 32 |
+
test_x = conv_stem(image)
|
| 33 |
+
# pass through the trunk
|
| 34 |
+
test_x = conv_trunk(test_x)
|
| 35 |
+
# pass through the attention pooling block
|
| 36 |
+
_, test_viz_weights = conv_attn(test_x)
|
| 37 |
+
test_viz_weights = test_viz_weights[tf.newaxis, ...]
|
| 38 |
+
|
| 39 |
+
# reshape the vizualization weights
|
| 40 |
+
num_patches = tf.shape(test_viz_weights)[-1]
|
| 41 |
+
height = width = int(math.sqrt(num_patches))
|
| 42 |
+
test_viz_weights = layers.Reshape((height, width))(test_viz_weights)
|
| 43 |
+
|
| 44 |
+
index = 0
|
| 45 |
+
selected_image = image[index]
|
| 46 |
+
selected_weight = test_viz_weights[index]
|
| 47 |
+
|
| 48 |
+
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
|
| 49 |
+
ax[0].imshow(selected_image)
|
| 50 |
+
ax[0].set_title(f"Original")
|
| 51 |
+
ax[0].axis("off")
|
| 52 |
+
|
| 53 |
+
img = ax[1].imshow(selected_image)
|
| 54 |
+
ax[1].imshow(selected_weight, cmap='inferno', alpha=0.6, extent=img.get_extent())
|
| 55 |
+
ax[1].set_title(f"Attended")
|
| 56 |
+
ax[1].axis("off")
|
| 57 |
+
|
| 58 |
+
plt.axis("off")
|
| 59 |
+
return plt
|
| 60 |
|
|
|
|
|
|
|
| 61 |
iface = gr.Interface(
|
| 62 |
fn=plot_attention,
|
| 63 |
inputs=[gr.inputs.Image(label="Input Image")],
|
utilities/model.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
# import the necessary packages
|
| 2 |
-
from tensorflow import keras
|
| 3 |
-
import tensorflow as tf
|
| 4 |
-
|
| 5 |
-
# Patch conv
|
| 6 |
-
class PatchConvNet(keras.Model):
|
| 7 |
-
def __init__(
|
| 8 |
-
self,
|
| 9 |
-
stem,
|
| 10 |
-
trunk,
|
| 11 |
-
attention_pooling,
|
| 12 |
-
**kwargs,
|
| 13 |
-
):
|
| 14 |
-
super().__init__(**kwargs)
|
| 15 |
-
self.stem = stem
|
| 16 |
-
self.trunk = trunk
|
| 17 |
-
self.attention_pooling = attention_pooling
|
| 18 |
-
|
| 19 |
-
@tf.function(
|
| 20 |
-
input_signature=[
|
| 21 |
-
tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8)
|
| 22 |
-
])
|
| 23 |
-
def call(self, images):
|
| 24 |
-
# pass through the stem
|
| 25 |
-
x = self.stem(images)
|
| 26 |
-
# pass through the trunk
|
| 27 |
-
x = self.trunk(x)
|
| 28 |
-
# pass through the attention pooling block
|
| 29 |
-
predictions, viz_weights = self.attention_pooling(x)
|
| 30 |
-
return predictions, viz_weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utilities/visualization.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
| 1 |
-
# import the necessary packages
|
| 2 |
-
from tensorflow.keras import layers
|
| 3 |
-
import tensorflow as tf
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import math
|
| 6 |
-
|
| 7 |
-
class PlotAttention(object):
|
| 8 |
-
def __init__(self, model):
|
| 9 |
-
self.model = model
|
| 10 |
-
|
| 11 |
-
def __call__(self, image):
|
| 12 |
-
# resize the image to a 224, 224 dim
|
| 13 |
-
image = tf.image.convert_image_dtype(image, tf.float32)
|
| 14 |
-
image = tf.image.resize(image, (224, 224))
|
| 15 |
-
image = image[tf.newaxis, ...]
|
| 16 |
-
|
| 17 |
-
# pass through the stem
|
| 18 |
-
test_x = self.model.stem(image)
|
| 19 |
-
# pass through the trunk
|
| 20 |
-
test_x = self.model.trunk(test_x)
|
| 21 |
-
# pass through the attention pooling block
|
| 22 |
-
_, test_viz_weights = self.model.attention_pooling(test_x)
|
| 23 |
-
test_viz_weights = test_viz_weights[tf.newaxis, ...]
|
| 24 |
-
|
| 25 |
-
# reshape the vizualization weights
|
| 26 |
-
num_patches = tf.shape(test_viz_weights)[-1]
|
| 27 |
-
height = width = int(math.sqrt(num_patches))
|
| 28 |
-
test_viz_weights = layers.Reshape((height, width))(test_viz_weights)
|
| 29 |
-
|
| 30 |
-
index = 0
|
| 31 |
-
selected_image = image[index]
|
| 32 |
-
selected_weight = test_viz_weights[index]
|
| 33 |
-
|
| 34 |
-
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
|
| 35 |
-
ax[0].imshow(selected_image)
|
| 36 |
-
ax[0].set_title(f"Original")
|
| 37 |
-
ax[0].axis("off")
|
| 38 |
-
|
| 39 |
-
img = ax[1].imshow(selected_image)
|
| 40 |
-
ax[1].imshow(selected_weight, cmap='inferno', alpha=0.6, extent=img.get_extent())
|
| 41 |
-
ax[1].set_title(f"Attended")
|
| 42 |
-
ax[1].axis("off")
|
| 43 |
-
|
| 44 |
-
plt.axis("off")
|
| 45 |
-
return plt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|