ShaswatRobotics commited on
Commit
768091b
·
verified ·
1 Parent(s): f3b1b83

Update iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/src/tokenizer.py +3 -3
iris/src/tokenizer.py CHANGED
@@ -17,10 +17,10 @@ class Tokenizer(nn.Module):
17
  self.vocab_size = config["vocab_size"]
18
  self.embed_dim = config["embed_dim"]
19
  self.encoder = Encoder(config["encoder"])
20
- self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, self.embed_dim, 1)
21
- self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
22
- self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, decoder.config.z_channels, 1)
23
  self.decoder = Decoder(config["decoder"])
 
 
 
24
  self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
25
  self.lpips = LPIPS().eval() if with_lpips else None
26
 
 
17
  self.vocab_size = config["vocab_size"]
18
  self.embed_dim = config["embed_dim"]
19
  self.encoder = Encoder(config["encoder"])
 
 
 
20
  self.decoder = Decoder(config["decoder"])
21
+ self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config.z_channels, self.embed_dim, 1)
22
+ self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
23
+ self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config.z_channels, 1)
24
  self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
25
  self.lpips = LPIPS().eval() if with_lpips else None
26