| import tensorflow as tf | |
| from tensorflow import keras | |
| from transformer import TransformerDecoder | |
| import tensorflow_probability as tfp | |
class TFIng(keras.Model):
    """Image-conditioned token decoder: InceptionV3 encoder + TransformerDecoder.

    Image features from an ImageNet-pretrained InceptionV3 (without its
    classification head) are projected to ``embed_dim`` channels by a 1x1
    convolution. They are then flattened into a sequence of spatial feature
    vectors and passed to ``TransformerDecoder`` together with the target
    token ids. A dense layer produces ``target_vocab_size`` logits per decoder
    output. Logits of tokens already present earlier in ``targets`` are pushed
    to -inf (see ``get_replacement_mask``).
    """

    def __init__(self, crop_size, embed_dim, num_layers, seq_length, hidden_dim, num_heads,
                 target_vocab_size, dropout_rate=0.1):
        """Build the encoder, projection, decoder and output layers.

        Args:
            crop_size: Spatial size of the input images. It is extended with
                3 channels to form InceptionV3's ``input_shape``. Any
                two-element sequence (tuple or list) is accepted.
            embed_dim: Output channels of the 1x1 projection; also passed to
                ``TransformerDecoder``.
            num_layers, seq_length, hidden_dim, num_heads: Passed through to
                ``TransformerDecoder`` (their meaning is defined there).
            target_vocab_size: Number of logits produced per decoder output;
                also passed to ``TransformerDecoder``.
            dropout_rate: Passed through to ``TransformerDecoder``.
        """
        super().__init__()
        self.target_vocab_size = target_vocab_size
        self.encoder = keras.applications.InceptionV3(
            include_top=False,
            weights="imagenet",
            # tuple() so a list crop_size also works (list + tuple raises TypeError).
            input_shape=tuple(crop_size) + (3,),
        )
        self.conv = keras.layers.Conv2D(embed_dim, 1)
        self.decoder = TransformerDecoder(num_layers, seq_length, embed_dim, hidden_dim, num_heads,
                                          target_vocab_size, dropout_rate=dropout_rate)
        self.linear = keras.layers.Dense(target_vocab_size, activation=None)

    def call(self, inputs, training=False):
        """Run the model on an ``(images, target token ids)`` pair.

        Args:
            inputs: Pair ``(encoder_inputs, targets)``. ``encoder_inputs`` must
                match the encoder's ``input_shape``. ``targets`` holds integer
                token ids of shape (batch, seq_len).
            training: Forwarded to the encoder, the projection and the decoder.

        Returns:
            Logits from ``self.linear`` plus the additive mask from
            ``get_replacement_mask`` (shape (batch, seq_len, target_vocab_size)).
            NOTE(review): assumes ``TransformerDecoder`` returns one vector per
            target position so the shapes line up; confirm in transformer.py.
        """
        encoder_inputs, targets = inputs
        encoder_out = self.encoder(encoder_inputs, training=training)
        encoder_out = self.conv(encoder_out, training=training)
        # Flatten the 2-D feature grid into a sequence of feature vectors:
        # (batch, h, w, channels) -> (batch, h * w, channels).
        encoder_out = tf.reshape(encoder_out, (tf.shape(encoder_out)[0], -1, tf.shape(encoder_out)[3]))
        decoder_outputs = self.decoder(targets, encoder_out, training=training)
        output = self.linear(decoder_outputs)
        # Build the mask in the logits' dtype so the addition also works when
        # the model runs under a mixed-precision (e.g. float16) policy.
        return output + self.get_replacement_mask(targets, dtype=output.dtype)

    def get_replacement_mask(self, targets, dtype=tf.float32):
        """Return an additive mask that forbids tokens already present in ``targets``.

        ``mask[b, i, v]`` is -inf when token id ``v`` occurs anywhere in
        ``targets[b, :i + 1]`` (positions 0..i inclusive), and 0 otherwise.
        Token id 0 is never masked (presumably the padding id; confirm with
        the tokenizer). Ids outside ``[0, target_vocab_size)`` contribute
        nothing. Tokens that repeat within a row are handled.

        The computation is fully dense and uses only dynamic shapes, so it can
        be traced by ``tf.function``.

        Args:
            targets: Integer token ids of shape (batch, seq_len).
            dtype: Floating dtype of the returned mask; defaults to float32.

        Returns:
            Tensor of shape (batch, seq_len, target_vocab_size) with values in
            {0, -inf}.
        """
        targets = tf.cast(targets, tf.int32)
        # occurrences[b, j, v] == 1 iff targets[b, j] == v and v != 0.
        # tf.one_hot produces an all-zero row for out-of-range ids.
        occurrences = tf.one_hot(targets, depth=self.target_vocab_size, dtype=tf.int32)
        occurrences *= tf.cast(targets != 0, tf.int32)[..., tf.newaxis]
        # Running count over positions 0..i. A count > 0 means token v was
        # seen at or before position i. Duplicates just increase the count,
        # so there are no duplicate-index problems.
        seen = tf.cumsum(occurrences, axis=1) > 0
        return tf.where(
            seen,
            tf.constant(float("-inf"), dtype=dtype),
            tf.constant(0, dtype=dtype),
        )