Update iris/src/tokenizer.py
Browse files- iris/src/tokenizer.py +2 -2
iris/src/tokenizer.py
CHANGED
|
@@ -18,9 +18,9 @@ class Tokenizer(nn.Module):
|
|
| 18 |
self.embed_dim = config["embed_dim"]
|
| 19 |
self.encoder = Encoder(config["encoder"])
|
| 20 |
self.decoder = Decoder(config["decoder"])
|
| 21 |
-
self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config
|
| 22 |
self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
|
| 23 |
-
self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config
|
| 24 |
self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
|
| 25 |
self.lpips = LPIPS().eval() if with_lpips else None
|
| 26 |
|
|
|
|
| 18 |
self.embed_dim = config["embed_dim"]
|
| 19 |
self.encoder = Encoder(config["encoder"])
|
| 20 |
self.decoder = Decoder(config["decoder"])
|
| 21 |
+
self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config["z_channels"], self.embed_dim, 1)
|
| 22 |
self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
|
| 23 |
+
self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, self.decoder.config["z_channels"], 1)
|
| 24 |
self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
|
| 25 |
self.lpips = LPIPS().eval() if with_lpips else None
|
| 26 |
|