ShaswatRobotics commited on
Commit
cf870d6
·
verified ·
1 Parent(s): 768091b

Update iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/src/tokenizer.py +2 -2
iris/src/tokenizer.py CHANGED
@@ -18,9 +18,9 @@ class Tokenizer(nn.Module):
18
  self.embed_dim = config["embed_dim"]
19
  self.encoder = Encoder(config["encoder"])
20
  self.decoder = Decoder(config["decoder"])
21
- self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config.z_channels, self.embed_dim, 1)
22
  self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
23
- self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config.z_channels, 1)
24
  self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
25
  self.lpips = LPIPS().eval() if with_lpips else None
26
 
 
18
  self.embed_dim = config["embed_dim"]
19
  self.encoder = Encoder(config["encoder"])
20
  self.decoder = Decoder(config["decoder"])
21
+ self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config["z_channels"], self.embed_dim, 1)
22
  self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
23
+ self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config["z_channels"], 1)
24
  self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
25
  self.lpips = LPIPS().eval() if with_lpips else None
26