ShaswatRobotics committed on
Commit
f3b1b83
·
verified ·
1 Parent(s): f04cbea

Update iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/src/tokenizer.py +2 -2
iris/src/tokenizer.py CHANGED
@@ -16,11 +16,11 @@ class Tokenizer(nn.Module):
16
  super().__init__()
17
  self.vocab_size = config["vocab_size"]
18
  self.embed_dim = config["embed_dim"]
19
- self.encoder = Encoder(config)
20
  self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, self.embed_dim, 1)
21
  self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
22
  self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, decoder.config.z_channels, 1)
23
- self.decoder = Decoder(config)
24
  self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
25
  self.lpips = LPIPS().eval() if with_lpips else None
26
 
 
16
  super().__init__()
17
  self.vocab_size = config["vocab_size"]
18
  self.embed_dim = config["embed_dim"]
19
+ self.encoder = Encoder(config["encoder"])
20
  self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, self.embed_dim, 1)
21
  self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
22
  self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, decoder.config.z_channels, 1)
23
+ self.decoder = Decoder(config["decoder"])
24
  self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
25
  self.lpips = LPIPS().eval() if with_lpips else None
26