ShaswatRobotics commited on
Commit
f04cbea
·
verified ·
1 Parent(s): cb06601

Update iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/src/tokenizer.py +9 -8
iris/src/tokenizer.py CHANGED
@@ -12,15 +12,16 @@ from .models.lpips import LPIPS
12
  from .models.nets import Encoder, Decoder
13
 
14
  class Tokenizer(nn.Module):
15
- def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True) -> None:
16
  super().__init__()
17
- self.vocab_size = vocab_size
18
- self.encoder = encoder
19
- self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, embed_dim, 1)
20
- self.embedding = nn.Embedding(vocab_size, embed_dim)
21
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder.config.z_channels, 1)
22
- self.decoder = decoder
23
- self.embedding.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
 
24
  self.lpips = LPIPS().eval() if with_lpips else None
25
 
26
  def __repr__(self) -> str:
 
12
  from .models.nets import Encoder, Decoder
13
 
14
  class Tokenizer(nn.Module):
15
+ def __init__(self, config: dict, with_lpips: bool = True) -> None:
16
  super().__init__()
17
+ self.vocab_size = config["vocab_size"]
18
+ self.embed_dim = config["embed_dim"]
19
+ self.encoder = Encoder(config)
20
+ self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, self.embed_dim, 1)
21
+ self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)
22
+ self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, decoder.config.z_channels, 1)
23
+ self.decoder = Decoder(config)
24
+ self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
25
  self.lpips = LPIPS().eval() if with_lpips else None
26
 
27
  def __repr__(self) -> str: