ShaswatRobotics committed on
Commit
702745a
·
verified ·
1 Parent(s): b2ade59

Delete iris/tokenizer.py

Browse files
Files changed (1) hide show
  1. iris/tokenizer.py +0 -81
iris/tokenizer.py DELETED
@@ -1,81 +0,0 @@
1
- """
2
- Credits to https://github.com/CompVis/taming-transformers
3
- """
4
-
5
- from typing import Tuple
6
-
7
- from einops import rearrange
8
- import torch
9
- import torch.nn as nn
10
-
11
- from models.lpips import LPIPS
12
- from models.nets import Encoder, Decoder
13
-
14
class Tokenizer(nn.Module):
    """VQ-VAE image tokenizer: encodes images to discrete codebook tokens and decodes back.

    Credits to https://github.com/CompVis/taming-transformers
    """

    def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True) -> None:
        """
        Args:
            vocab_size: number of entries in the discrete codebook.
            embed_dim: dimensionality of each codebook embedding vector.
            encoder: network producing a latent map with `encoder.config.z_channels` channels.
            decoder: network reconstructing images from a latent map with
                `decoder.config.z_channels` channels.
            with_lpips: if True, instantiate a frozen LPIPS network (perceptual loss).
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.encoder = encoder
        # 1x1 convolutions adapt between encoder/decoder channel width and codebook dim.
        self.pre_quant_conv = torch.nn.Conv2d(encoder.config.z_channels, embed_dim, 1)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder.config.z_channels, 1)
        self.decoder = decoder
        # Standard VQ-VAE codebook init: small uniform values scaled by vocab size.
        self.embedding.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
        self.lpips = LPIPS().eval() if with_lpips else None

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, quantize (with straight-through gradients), and reconstruct.

        Returns:
            (z, z_quantized, reconstructions).
        """
        outputs = self.encode(x, should_preprocess)
        # BUG FIX: `encode` returns a dict, so attribute access (`outputs.z`)
        # raised AttributeError; use key access instead.
        # Straight-through estimator: forward pass uses z_quantized, backward
        # pass routes gradients through the continuous z.
        z, z_q = outputs["z"], outputs["z_quantized"]
        decoder_input = z + (z_q - z).detach()
        reconstructions = self.decode(decoder_input, should_postprocess)
        return z, z_q, reconstructions

    def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
        """Encode `x` into continuous latents, quantized latents, and token ids.

        Any leading batch-like dims are flattened for the conv encoder and
        restored on the outputs.

        Returns:
            dict with keys "z", "z_quantized" (both shaped (..., E, h, w))
            and "tokens" (shaped (..., h*w)).
        """
        if should_preprocess:
            x = self.preprocess_input(x)
        shape = x.shape  # (..., C, H, W)
        x = x.view(-1, *shape[-3:])
        z = self.encoder(x)
        z = self.pre_quant_conv(z)
        b, e, h, w = z.shape
        z_flattened = rearrange(z, 'b e h w -> (b h w) e')
        # Squared Euclidean distance to every codebook vector, expanded as
        # ||z||^2 + ||e||^2 - 2*z.e to avoid materializing pairwise differences.
        dist_to_embeddings = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())

        # Nearest codebook entry per spatial position.
        tokens = dist_to_embeddings.argmin(dim=-1)
        z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous()

        # Restore the original leading dims.
        z = z.reshape(*shape[:-3], *z.shape[1:])
        z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:])
        tokens = tokens.reshape(*shape[:-3], -1)

        return {
            "z": z,
            "z_quantized": z_q,
            "tokens": tokens,
        }

    def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor:
        """Reconstruct images from (possibly quantized) latents shaped (..., E, h, w)."""
        shape = z_q.shape  # (..., E, h, w)
        z_q = z_q.view(-1, *shape[-3:])
        z_q = self.post_quant_conv(z_q)
        rec = self.decoder(z_q)
        rec = rec.reshape(*shape[:-3], *rec.shape[1:])
        if should_postprocess:
            rec = self.postprocess_output(rec)
        return rec

    @torch.no_grad()
    def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
        """Round-trip `x` through the codebook (no gradients)."""
        # BUG FIX: key access on the dict returned by `encode` (was `.z_quantized`).
        z_q = self.encode(x, should_preprocess)["z_quantized"]
        return self.decode(z_q, should_postprocess)

    def preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """x is supposed to be channels first and in [0, 1]; maps to [-1, 1]."""
        return x.mul(2).sub(1)

    def postprocess_output(self, y: torch.Tensor) -> torch.Tensor:
        """y is supposed to be channels first and in [-1, 1]; maps to [0, 1]."""
        return y.add(1).div(2)