Spaces:
Runtime error
Runtime error
| # import the necessary packages | |
| from tensorflow import keras | |
| import tensorflow as tf | |
| # Patch conv | |
class PatchConvNet(keras.Model):
    """Compose a stem, a trunk, and an attention-pooling head into one model.

    The three sub-components are supplied by the caller and are applied in
    sequence by ``call``. Nothing in this class constrains their types beyond
    being callable; they are presumably Keras layers or models -- confirm
    against the code that constructs them.

    Args:
        stem: Callable applied first, directly to the input images.
        trunk: Callable applied to the output of ``stem``.
        attention_pooling: Callable applied to the output of ``trunk``. It
            must return exactly two values, because ``call`` unpacks them.
        **kwargs: Forwarded unchanged to ``keras.Model.__init__``.
    """

    def __init__(self, stem, trunk, attention_pooling, **kwargs):
        super().__init__(**kwargs)
        # Store the sub-components as attributes so call() can chain them.
        self.stem = stem
        self.trunk = trunk
        self.attention_pooling = attention_pooling

    def call(self, images):
        """Run ``images`` through stem -> trunk -> attention pooling.

        Args:
            images: Input passed to ``self.stem``; its shape and dtype are
                whatever the stem expects (not fixed by this class).

        Returns:
            A ``(predictions, viz_weights)`` tuple holding the two values
            produced by the attention-pooling block. ``viz_weights`` is
            presumably the attention map used for visualisation -- confirm
            against the pooling block's implementation.
        """
        trunk_features = self.trunk(self.stem(images))
        # The pooling head yields two results; unpack and re-pack them so the
        # return value is always a 2-tuple.
        predictions, viz_weights = self.attention_pooling(trunk_features)
        return predictions, viz_weights