small fixes
Browse files
- model/decoders.py +3 -3
- model/encoders.py +1 -1
- model/t5_vae.py +2 -3
- model/vae.py +2 -2
model/decoders.py
CHANGED
|
@@ -12,10 +12,10 @@ class Decoder(nn.Module):
|
|
| 12 |
n_latent_tokens: int
|
| 13 |
|
| 14 |
@nn.compact
|
| 15 |
-    def __call__(self, latent_code):
|
| 16 |
-
raw_latent_tokens = nn.
|
| 17 |
latent_tokens = nn.LayerNorm()(raw_latent_tokens)
|
| 18 |
-
return latent_tokens # (batch,
|
| 19 |
|
| 20 |
|
| 21 |
VAE_DECODER_MODELS = {
|
|
|
|
| 12 |
n_latent_tokens: int
|
| 13 |
|
| 14 |
@nn.compact
|
| 15 |
+    def __call__(self, latent_code): # (batch, latent_tokens_per_sequence, latent_token_dim)
|
| 16 |
+        raw_latent_tokens = nn.Dense(self.dim_model)(latent_code)
|
| 17 |
latent_tokens = nn.LayerNorm()(raw_latent_tokens)
|
| 18 |
+        return latent_tokens # (batch, latent_tokens_per_sequence, dim_model)
|
| 19 |
|
| 20 |
|
| 21 |
VAE_DECODER_MODELS = {
|
model/encoders.py
CHANGED
|
@@ -13,7 +13,7 @@ class Encoder(nn.Module):
|
|
| 13 |
|
| 14 |
@nn.compact
|
| 15 |
def __call__(self, encoding):
|
| 16 |
-
latent_tokens = nn.
|
| 17 |
raw_latent_code = latent_tokens[:, : self.n_tokens, :]
|
| 18 |
latent_code = nn.Tanh()(raw_latent_code)
|
| 19 |
return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
|
|
|
|
| 13 |
|
| 14 |
@nn.compact
|
| 15 |
def __call__(self, encoding):
|
| 16 |
+        latent_tokens = nn.Dense(self.latent_size)(encoding)
|
| 17 |
raw_latent_code = latent_tokens[:, : self.n_tokens, :]
|
| 18 |
latent_code = nn.Tanh()(raw_latent_code)
|
| 19 |
return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
|
model/t5_vae.py
CHANGED
|
@@ -28,8 +28,7 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
|
|
| 28 |
return self.t5.decoder
|
| 29 |
|
| 30 |
def setup(self):
|
| 31 |
-
self.
|
| 32 |
-        self.t5 = FlaxT5ForConditionalGenerationModule(self.config)
|
| 33 |
self.vae = VAE(self.config)
|
| 34 |
|
| 35 |
def __call__(
|
|
@@ -79,7 +78,7 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
|
|
| 79 |
if self.config.tie_word_embeddings:
|
| 80 |
# Rescale output before projecting on vocab
|
| 81 |
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
| 82 |
-
sequence_output = sequence_output * (self.
|
| 83 |
|
| 84 |
if self.config.tie_word_embeddings:
|
| 85 |
shared_embedding = self.shared.variables["params"]["embedding"]
|
|
|
|
| 28 |
return self.t5.decoder
|
| 29 |
|
| 30 |
def setup(self):
|
| 31 |
+        self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
|
|
|
|
| 32 |
self.vae = VAE(self.config)
|
| 33 |
|
| 34 |
def __call__(
|
|
|
|
| 78 |
if self.config.tie_word_embeddings:
|
| 79 |
# Rescale output before projecting on vocab
|
| 80 |
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
| 81 |
+        sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)
|
| 82 |
|
| 83 |
if self.config.tie_word_embeddings:
|
| 84 |
shared_embedding = self.shared.variables["params"]["embedding"]
|
model/vae.py
CHANGED
|
@@ -17,8 +17,8 @@ class VAE(nn.Module):
|
|
| 17 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 18 |
|
| 19 |
def setup(self):
|
| 20 |
-
self.encoder = VAE_ENCODER_MODELS[self.config.
|
| 21 |
-
self.decoder = VAE_DECODER_MODELS[self.config.
|
| 22 |
|
| 23 |
def __call__(self, encoding=None, latent_codes=None):
|
| 24 |
if latent_codes is None:
|
|
|
|
| 17 |
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
| 18 |
|
| 19 |
def setup(self):
|
| 20 |
+        self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_size, self.config.n_latent_tokens)
|
| 21 |
+        self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
|
| 22 |
|
| 23 |
def __call__(self, encoding=None, latent_codes=None):
|
| 24 |
if latent_codes is None:
|