Fraser committed on
Commit
6f4a0d9
·
1 Parent(s): 3497606

small fixes

Browse files
Files changed (4) hide show
  1. model/decoders.py +3 -3
  2. model/encoders.py +1 -1
  3. model/t5_vae.py +2 -3
  4. model/vae.py +2 -2
model/decoders.py CHANGED
@@ -12,10 +12,10 @@ class Decoder(nn.Module):
12
  n_latent_tokens: int
13
 
14
  @nn.compact
15
- def __call__(self, latent_code):
16
- raw_latent_tokens = nn.Linear(self.dim_model)(latent_code)
17
  latent_tokens = nn.LayerNorm()(raw_latent_tokens)
18
- return latent_tokens # (batch, n_latent_tokens, dim_model)
19
 
20
 
21
  VAE_DECODER_MODELS = {
 
12
  n_latent_tokens: int
13
 
14
  @nn.compact
15
+ def __call__(self, latent_code): # (batch, latent_tokens_per_sequence, latent_token_dim)
16
+ raw_latent_tokens = nn.Dense(self.dim_model)(latent_code)
17
  latent_tokens = nn.LayerNorm()(raw_latent_tokens)
18
+ return latent_tokens # (batch, latent_tokens_per_sequence, dim_model)
19
 
20
 
21
  VAE_DECODER_MODELS = {
model/encoders.py CHANGED
@@ -13,7 +13,7 @@ class Encoder(nn.Module):
13
 
14
  @nn.compact
15
  def __call__(self, encoding):
16
- latent_tokens = nn.Linear(self.latent_size)(encoding)
17
  raw_latent_code = latent_tokens[:, : self.n_tokens, :]
18
  latent_code = nn.Tanh()(raw_latent_code)
19
  return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
 
13
 
14
  @nn.compact
15
  def __call__(self, encoding):
16
+ latent_tokens = nn.Dense(self.latent_size)(encoding)
17
  raw_latent_code = latent_tokens[:, : self.n_tokens, :]
18
  latent_code = nn.Tanh()(raw_latent_code)
19
  return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
model/t5_vae.py CHANGED
@@ -28,8 +28,7 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
28
  return self.t5.decoder
29
 
30
  def setup(self):
31
- self.model_dim = self.config.t5.d_model
32
- self.t5 = FlaxT5ForConditionalGenerationModule(self.config)
33
  self.vae = VAE(self.config)
34
 
35
  def __call__(
@@ -79,7 +78,7 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
79
  if self.config.tie_word_embeddings:
80
  # Rescale output before projecting on vocab
81
  # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
82
- sequence_output = sequence_output * (self.model_dim ** -0.5)
83
 
84
  if self.config.tie_word_embeddings:
85
  shared_embedding = self.shared.variables["params"]["embedding"]
 
28
  return self.t5.decoder
29
 
30
  def setup(self):
31
+ self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
 
32
  self.vae = VAE(self.config)
33
 
34
  def __call__(
 
78
  if self.config.tie_word_embeddings:
79
  # Rescale output before projecting on vocab
80
  # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
81
+ sequence_output = sequence_output * (self.config.t5.d_model ** -0.5)
82
 
83
  if self.config.tie_word_embeddings:
84
  shared_embedding = self.shared.variables["params"]["embedding"]
model/vae.py CHANGED
@@ -17,8 +17,8 @@ class VAE(nn.Module):
17
  dtype: jnp.dtype = jnp.float32 # the dtype of the computation
18
 
19
  def setup(self):
20
- self.encoder = VAE_ENCODER_MODELS[self.config.encoder](self.config.latent_size, self.config.n_latent_tokens)
21
- self.decoder = VAE_DECODER_MODELS[self.config.decoder](self.config.dim_models, self.config.n_latent_tokens)
22
 
23
  def __call__(self, encoding=None, latent_codes=None):
24
  if latent_codes is None:
 
17
  dtype: jnp.dtype = jnp.float32 # the dtype of the computation
18
 
19
  def setup(self):
20
+ self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_size, self.config.n_latent_tokens)
21
+ self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
22
 
23
  def __call__(self, encoding=None, latent_codes=None):
24
  if latent_codes is None: