|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Layers for embedding."""
|
| import tensorflow as tf
|
|
|
| from layers import base_layers
|
| from layers import quantization_layers
|
|
|
|
|
class EmbeddingLayer(base_layers.BaseLayer):
  """Embedding lookup layer with quantized outputs.

  Looks up rows of a learned `embedding_table` by integer index and passes
  the looked-up values through an activation quantizer before returning them.
  """

  def __init__(self,
               shape,
               num_bits=8,
               initializer=None,
               trainable=True,
               **kwargs):
    """Initializes the layer.

    Args:
      shape: Shape of the embedding table, typically
        `[vocabulary_size, embedding_dim]`.
      num_bits: Number of bits used by the output activation quantizer.
      initializer: Optional initializer for the embedding table. Defaults to
        `tf.keras.initializers.GlorotUniform()` when `None`.
      trainable: Whether the embedding table is trainable. When `True`, a
        regularization loss is also registered for the table in `build`.
      **kwargs: Forwarded to both the quantizer and the base layer.
    """
    self.shape = shape
    self.quantizer = quantization_layers.ActivationQuantization(
        num_bits=num_bits, **kwargs)
    super(EmbeddingLayer, self).__init__(**kwargs)
    if initializer is None:
      initializer = tf.keras.initializers.GlorotUniform()
    self.initializer = initializer
    self.trainable = trainable

  def build(self, input_shapes):
    """Creates the embedding table variable.

    `input_shapes` is unused: the table shape is fixed at construction time.
    """
    self.embedding_table = self.add_weight(
        name="embedding_table",
        shape=self.shape,
        initializer=self.initializer,
        trainable=self.trainable,
        dtype=tf.float32)
    if self.trainable:
      self.add_reg_loss(self.embedding_table)

  def call(self, indices):
    """Looks up `indices` in the table and quantizes the result.

    Args:
      indices: Integer tensor (`tf.int32` or `tf.int64`) of row indices.

    Returns:
      The quantized embedding vectors for `indices`.

    Raises:
      ValueError: If `indices` has a non-integer dtype.
    """
    # Explicit raise instead of `assert`: asserts are stripped under `-O`,
    # which would silently allow invalid dtypes through.
    if indices.dtype not in (tf.int32, tf.int64):
      raise ValueError(
          "indices must be int32 or int64, got %s" % indices.dtype)
    outputs = tf.nn.embedding_lookup(self.embedding_table, indices)
    return self.quantizer(outputs)
|
|
|
|
|
class EmbeddingFullyConnected(EmbeddingLayer):
  """Uses embedding table as weights in a fully connected op."""

  def __init__(self, **kwargs):
    """Initializes the layer.

    Pops `shape` and `initializer` out of `kwargs` for the embedding table;
    the remaining kwargs configure both the output quantizer and the parent
    layer.
    """
    table_shape = kwargs.pop("shape", None)
    table_initializer = kwargs.pop("initializer", None)
    # Quantizes the fully-connected outputs produced by `fully_connected`.
    self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
    super(EmbeddingFullyConnected, self).__init__(
        shape=table_shape, initializer=table_initializer, **kwargs)

  def fully_connected(self, inputs, bias=None, weights_scale_factor=None):
    """Multiplies `inputs` by the transposed embedding table.

    Args:
      inputs: Rank-2 float tensor; checked via `_assert_rank_and_type`.
      bias: Optional bias vector added to the matmul result.
      weights_scale_factor: Optional scalar multiplier applied to the
        embedding table before the matmul.

    Returns:
      The quantized result of `inputs @ embedding_table^T` (plus bias).
    """
    self._assert_rank_and_type(inputs, 2)
    projection_weights = self.embedding_table
    if weights_scale_factor is not None:
      projection_weights = projection_weights * weights_scale_factor
    result = tf.matmul(inputs, projection_weights, transpose_b=True)
    if bias is not None:
      result = tf.nn.bias_add(result, bias)
    return self.qoutput(result)
|
|
|