Update iris/src/tokenizer.py
Browse files- iris/src/tokenizer.py +9 -8
iris/src/tokenizer.py
CHANGED
|
@@ -12,15 +12,16 @@ from .models.lpips import LPIPS
|
|
| 12 |
from .models.nets import Encoder, Decoder
|
| 13 |
|
| 14 |
class Tokenizer(nn.Module):
|
| 15 |
-
def __init__(self,
|
| 16 |
super().__init__()
|
| 17 |
-
self.vocab_size = vocab_size
|
| 18 |
-
self.
|
| 19 |
-
self.
|
| 20 |
-
self.
|
| 21 |
-
self.
|
| 22 |
-
self.
|
| 23 |
-
self.
|
|
|
|
| 24 |
self.lpips = LPIPS().eval() if with_lpips else None
|
| 25 |
|
| 26 |
def __repr__(self) -> str:
|
|
|
|
| 12 |
from .models.nets import Encoder, Decoder
|
| 13 |
|
| 14 |
class Tokenizer(nn.Module):
|
| 15 |
+
def __init__(self, config: dict, with_lpips: bool = True) -> None:
    """Build the tokenizer's sub-modules from a single configuration mapping.

    Args:
        config: Mapping that provides ``"vocab_size"`` and ``"embed_dim"``.
            The same object is forwarded unchanged to ``Encoder`` and
            ``Decoder``.
        with_lpips: When true, attach an ``LPIPS`` module switched to eval
            mode; otherwise ``self.lpips`` is ``None``.

    Raises:
        KeyError: If ``config`` lacks ``"vocab_size"`` or ``"embed_dim"``.
    """
    super().__init__()
    self.vocab_size = config["vocab_size"]
    self.embed_dim = config["embed_dim"]

    self.encoder = Encoder(config)
    # The constructor no longer receives `encoder`/`decoder` arguments, so the
    # channel counts must be read from the modules built here; the bare names
    # `encoder` / `decoder` would raise NameError.
    # NOTE(review): assumes Encoder/Decoder expose `.config.z_channels` as an
    # attribute even though `config` is indexed like a dict above - confirm.
    self.pre_quant_conv = torch.nn.Conv2d(self.encoder.config.z_channels, self.embed_dim, 1)
    self.embedding = nn.Embedding(self.vocab_size, self.embed_dim)

    # Build the decoder before the post-quantisation conv needs its channel
    # count, but register it afterwards so the sub-module order stays
    # encoder, pre_quant_conv, embedding, post_quant_conv, decoder.
    decoder = Decoder(config)
    self.post_quant_conv = torch.nn.Conv2d(self.embed_dim, decoder.config.z_channels, 1)
    self.decoder = decoder

    # Initialise codebook entries uniformly in [-1/vocab_size, 1/vocab_size].
    self.embedding.weight.data.uniform_(-1.0 / self.vocab_size, 1.0 / self.vocab_size)
    self.lpips = LPIPS().eval() if with_lpips else None
|
| 26 |
|
| 27 |
def __repr__(self) -> str:
|