ShaswatRobotics committed on
Commit
cf88ce4
·
verified ·
1 Parent(s): c67657d

Upload 8 files

Browse files
iris/src/models/__init__.py ADDED
File without changes
iris/src/models/kv_caching.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
class Cache:
    """Pre-allocated attention cache for one projection (keys or values).

    Storage is a single (num_samples, num_heads, max_tokens, head_dim) tensor;
    `_size` tracks how many token slots along dim 2 are currently filled.
    """

    def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        assert embed_dim % num_heads == 0
        self._n, self._cache, self._size = num_samples, None, None
        # Factory re-allocating an uninitialized buffer for n samples.
        self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device)  # (B, nh, T, hs)
        self.reset()

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        # Reported token length (3rd dim) is the filled size, not max_tokens.
        n, num_heads, _, head_dim = self._cache.shape
        return n, num_heads, self._size, head_dim

    def reset(self) -> None:
        # Re-allocates the buffer and logically empties the cache.
        self._cache = self._reset(self._n)
        self._size = 0

    def prune(self, mask: np.ndarray) -> None:
        # Keep only the batch entries selected by the 1-D boolean mask.
        assert mask.ndim == 1 and mask.shape[0] == self.shape[0]
        self._cache = self._cache[mask]
        self._n = self._cache.shape[0]

    def get(self) -> torch.Tensor:
        # View of the filled prefix along the token dimension.
        return self._cache[:, :, :self._size, :]

    def update(self, x: torch.Tensor) -> None:
        # Appends x of shape (B, nh, t, hs) after the current content. The write
        # goes through a custom autograd Function so the in-place assignment does
        # not trip autograd's version-counter check.
        assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)])
        assert self._size + x.size(2) <= self._cache.shape[2]
        self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + x.size(2))
        self._size += x.size(2)
36
+
37
+
38
class KVCache:
    """Paired key/value caches for a single attention layer."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
        self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device)

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        # Keys and values are always updated together, so either shape works.
        return self._k_cache.shape

    def reset(self) -> None:
        for cache in (self._k_cache, self._v_cache):
            cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        for cache in (self._k_cache, self._v_cache):
            cache.prune(mask)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._k_cache.get(), self._v_cache.get()

    def update(self, k: torch.Tensor, v: torch.Tensor):
        self._k_cache.update(k)
        self._v_cache.update(v)
61
+
62
+
63
class KeysValues:
    """Per-layer collection of KV caches for a whole transformer."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
        self._keys_values = tuple(KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers))

    def __getitem__(self, key: int) -> KVCache:
        return self._keys_values[key]

    def __len__(self):
        return len(self._keys_values)

    @property
    def size(self):
        # Number of tokens currently stored (token dim of the first layer's cache).
        return self._keys_values[0].shape[2]

    def reset(self) -> None:
        for layer_cache in self._keys_values:
            layer_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        for layer_cache in self._keys_values:
            layer_cache.prune(mask)
84
+
85
+
86
class AssignWithoutInplaceCheck(torch.autograd.Function):
    """
    Inspired from : https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
    Warning : do not use it to overwrite a slice twice.

    Writes `value` into `input[..., start:stop, ...]` along `dim` without
    triggering autograd's in-place version-counter check, while still giving
    both `input` and `value` correct gradients (valid only because each slice
    is written at most once).
    """

    @staticmethod
    def get_slice(dim: int, start: int, stop: int) -> Tuple[slice]:
        # Indexing tuple equivalent to [:, ..., :, start:stop] with the range at position `dim`.
        return tuple([slice(None), ] * dim + [slice(start, stop)])

    @staticmethod
    def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
        ctx.dim = dim
        ctx.start = start
        ctx.stop = stop
        # Assigning through .data bypasses the version counter on purpose.
        input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value
        return input

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor]:
        # Gradient w.r.t. `input` is the full grad, w.r.t. `value` the written
        # slice; dim/start/stop are non-tensor args, hence the trailing Nones.
        return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None
iris/src/models/lpips.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credits to https://github.com/CompVis/taming-transformers
3
+ """
4
+
5
+ from collections import namedtuple
6
+ import hashlib
7
+ import os
8
+ from pathlib import Path
9
+ import requests
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torchvision import models
14
+ from tqdm import tqdm
15
+
16
+
17
class LPIPS(nn.Module):
    """Learned perceptual metric (LPIPS) computed over frozen VGG16 features.

    All parameters (VGG backbone and the per-layer 1x1 linear heads) are loaded
    from a pretrained checkpoint and frozen, so the module acts as a fixed metric.
    """
    def __init__(self, use_dropout: bool = True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channel counts per tap
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # Freeze everything: LPIPS is used as a loss, never trained.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self) -> None:
        # Downloads the LPIPS linear-head checkpoint on first use, then loads it.
        ckpt = get_ckpt_path(name="vgg_lpips", root=Path.home() / ".cache/iris/tokenizer_pretrained_vgg")  # Download VGG if necessary
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Normalize both images to VGG statistics, extract features at 5 depths,
        # unit-normalize each feature map along channels, and sum the spatially
        # averaged, linearly weighted squared differences.
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for i in range(1, len(self.chns)):
            val += res[i]
        return val
51
+
52
+
53
class ScalingLayer(nn.Module):
    """Normalizes RGB inputs with the fixed channel-wise shift/scale used by LPIPS."""

    def __init__(self) -> None:
        super().__init__()
        # Constants broadcast over (N, C, H, W); stored as buffers so they
        # follow the module across devices without being trained.
        shift = torch.Tensor([-.030, -.088, -.188])[None, :, None, None]
        scale = torch.Tensor([.458, .448, .450])[None, :, None, None]
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return (inp - self.shift) / self.scale
61
+
62
+
63
class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv (optionally preceded by dropout). """

    def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) -> None:
        super(NetLinLayer, self).__init__()
        conv = nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)
        modules = ([nn.Dropout()] if use_dropout else []) + [conv]
        self.model = nn.Sequential(*modules)
70
+
71
+
72
class vgg16(torch.nn.Module):
    """VGG16 feature extractor returning activations at 5 intermediate depths.

    The five slices end at relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3 of
    torchvision's VGG16 feature stack (per the named outputs in `forward`).
    With `pretrained=True`, torchvision downloads its weights on first use.
    """
    def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Partition the pretrained feature layers into the five slices.
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            # Frozen backbone: used purely as a fixed feature extractor.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Run the slices sequentially, keeping each intermediate activation.
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out
110
+
111
+
112
def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Scale `x` to unit L2 norm along the channel dimension (dim=1)."""
    channel_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)
115
+
116
+
117
def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """Mean over the spatial (H, W) dimensions of an NCHW tensor."""
    return x.mean(dim=[2, 3], keepdim=keepdim)
119
+
120
+
121
# ********************************************************************
# *************** Utilities to download pretrained vgg ***************
# ********************************************************************


# For each pretrained model name: download URL, local filename, and the
# expected MD5 of the checkpoint (used by get_ckpt_path for verification).
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}


CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}


MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
139
+
140
+
141
def download(url: str, local_path: str, chunk_size: int = 1024) -> None:
    """Stream `url` to `local_path` with a byte-accurate progress bar.

    The destination directory is created if needed. The HTTP status is checked
    so a server error page is never silently written to disk, and the progress
    bar is advanced by the actual number of bytes received (a short final chunk
    previously made the bar overshoot its total).
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()  # fail loudly on HTTP errors instead of saving the error body
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        pbar.update(len(data))  # was chunk_size: overcounted the last partial chunk
151
+
152
+
153
def md5_hash(path: str) -> str:
    """Return the hex MD5 digest of the file at `path`.

    Hashes in fixed-size chunks so large checkpoints (e.g. the VGG weights)
    never have to fit in memory at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter(..., b"") yields 1 MiB chunks until EOF.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
157
+
158
+
159
def get_ckpt_path(name: str, root: str, check: bool = False) -> str:
    """Return the local checkpoint path for `name`, downloading it if missing.

    When `check` is True, an existing file whose MD5 does not match is
    re-downloaded. After any fresh download the digest is always verified.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    needs_download = not os.path.exists(path) or (check and md5_hash(path) != MD5_MAP[name])
    if needs_download:
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        assert md5 == MD5_MAP[name], md5
    return path
iris/src/models/nets.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credits to https://github.com/CompVis/taming-transformers
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
class Encoder(nn.Module):
    """Convolutional encoder (VQGAN-style): image -> latent feature map.

    Spatial size is halved at each of the first len(ch_mult)-1 resolution
    levels; each level stacks `num_res_blocks` ResnetBlocks, with AttnBlocks at
    the resolutions listed in config["attn_resolutions"]. A middle
    Resnet-Attn-Resnet block follows, then a projection to z_channels.
    """
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        self.num_resolutions = len(config["ch_mult"])
        temb_ch = 0 # timestep embedding #channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(config["in_channels"],
                                       config["ch"],
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = config["resolution"]
        in_ch_mult = (1,) + tuple(config["ch_mult"])
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = config["ch"] * in_ch_mult[i_level]
            block_out = config["ch"] * config["ch_mult"][i_level]
            for i_block in range(self.config["num_res_blocks"]):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=temb_ch,
                                         dropout=config["dropout"]))
                block_in = block_out
                if curr_res in config["attn_resolutions"]:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Halve the spatial resolution between levels (not after the last one).
                down.downsample = Downsample(block_in, with_conv=True)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=temb_ch,
                                       dropout=config["dropout"])
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=temb_ch,
                                       dropout=config["dropout"])

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        config["z_channels"],
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        temb = None # timestep embedding (unused: temb_channels is 0)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.config["num_res_blocks"]):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
92
+
93
+
94
class Decoder(nn.Module):
    """Convolutional decoder (VQGAN-style): latent feature map -> image.

    Mirrors Encoder: conv-in from z_channels, a middle Resnet-Attn-Resnet
    block, then len(ch_mult)-1 upsampling stages of ResnetBlocks (with
    AttnBlocks at the resolutions in config["attn_resolutions"]), and a final
    projection to out_ch.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        temb_ch = 0  # timestep embedding #channels (unused here)
        self.num_resolutions = len(config["ch_mult"])

        # compute block_in and curr_res at lowest res
        block_in = config["ch"] * config["ch_mult"][self.num_resolutions - 1]
        curr_res = config["resolution"] // 2 ** (self.num_resolutions - 1)
        # Fix: the original nested double quotes inside a double-quoted f-string,
        # which is a SyntaxError on Python < 3.12 (PEP 701 only allows it in 3.12+).
        print(f"Tokenizer : shape of latent is {(config['z_channels'], curr_res, curr_res)}.")

        # z to block_in
        self.conv_in = torch.nn.Conv2d(config["z_channels"],
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=temb_ch,
                                       dropout=config["dropout"])
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=temb_ch,
                                       dropout=config["dropout"])

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = config["ch"] * config["ch_mult"][i_level]
            for i_block in range(config["num_res_blocks"] + 1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=temb_ch,
                                         dropout=config["dropout"]))
                block_in = block_out
                if curr_res in config["attn_resolutions"]:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Double the spatial resolution between levels (not after level 0).
                up.upsample = Upsample(block_in, with_conv=True)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        config["out_ch"],
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        temb = None  # timestep embedding (unused: temb_channels is 0)

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.config["num_res_blocks"] + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
181
+
182
+
183
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
    """Swish / SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
186
+
187
+
188
def Normalize(in_channels: int) -> nn.Module:
    """GroupNorm with the taming-transformers defaults (32 groups, eps 1e-6, affine)."""
    return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
190
+
191
+
192
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels: int, with_conv: bool) -> None:
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
208
+
209
+
210
class Downsample(nn.Module):
    """2x spatial downsampling via a strided 3x3 conv (asymmetric pad) or 2x2 average pooling."""

    def __init__(self, in_channels: int, with_conv: bool) -> None:
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
230
+
231
+
232
class ResnetBlock(nn.Module):
    """Residual block of two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    When in/out channel counts differ, the skip is projected with a 3x3 conv
    (`conv_shortcut=True`) or a 1x1 conv (default). `temb_channels > 0` enables
    an additive timestep-embedding projection after the first conv; Encoder and
    Decoder in this file pass temb_channels=0 and temb=None, so that path is
    inactive there.
    """
    def __init__(self, *, in_channels: int, out_channels: int = None, conv_shortcut: bool = False,
                 dropout: float, temb_channels: int = 512) -> None:
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected embedding over the spatial dimensions.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            # Project the skip path so channel counts match before adding.
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
292
+
293
+
294
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a (B, C, H, W) map,
    with a residual connection and an output projection."""
    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        # 1x1 convs act as per-position linear projections for q, k, v.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention over all H*W positions, scaled by 1/sqrt(C)
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, h * w) # b,c,hw
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)      # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        # residual connection
        return x + h_
iris/src/models/slicer.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
class Slicer(nn.Module):
    """Precomputes which positions of a flattened block sequence are kept by a mask.

    `block_mask` marks the kept positions inside one block; the same pattern is
    tiled over `max_blocks` consecutive blocks and stored as absolute indices.
    """

    def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
        super().__init__()
        self.block_size = block_mask.size(0)
        self.num_kept_tokens = block_mask.sum().long().item()
        per_block = torch.where(block_mask)[0]
        tiled = per_block.repeat(max_blocks)
        block_offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
        self.register_buffer('indices', tiled + self.block_size * block_offsets)

    def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
        # Absolute kept indices falling inside [prev_steps, prev_steps + num_steps),
        # shifted so they index into the current chunk.
        total_steps = num_steps + prev_steps
        num_blocks = math.ceil(total_steps / self.block_size)
        candidate = self.indices[:num_blocks * self.num_kept_tokens]
        keep = torch.logical_and(prev_steps <= candidate, candidate < total_steps)
        return candidate[keep] - prev_steps

    def forward(self, *args, **kwargs):
        raise NotImplementedError
24
+
25
+
26
class Head(Slicer):
    """Applies `head_module` only to the time steps this slicer keeps."""

    def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
        super().__init__(max_blocks, block_mask)
        assert isinstance(head_module, nn.Module)
        self.head_module = head_module

    def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
        # x is (B, T, E); select the kept time steps, then run the head on them.
        kept = self.compute_slice(num_steps, prev_steps)
        return self.head_module(x[:, kept])
35
+
36
+
37
class Embedder(nn.Module):
    """Embeds a token sequence where different positions use different vocabularies.

    The block masks must partition each block's positions among the embedding
    tables, so every position is embedded by exactly one table.
    """

    def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
        super().__init__()
        assert len(block_masks) == len(embedding_tables)
        assert (sum(block_masks) == 1).all()  # block mask are a partition of a block
        self.embedding_dim = embedding_tables[0].embedding_dim
        assert all(table.embedding_dim == self.embedding_dim for table in embedding_tables)
        self.embedding_tables = embedding_tables
        self.slicers = [Slicer(max_blocks, mask) for mask in block_masks]

    def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
        assert tokens.ndim == 2  # tokens is (B, T)
        output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
        for slicer, table in zip(self.slicers, self.embedding_tables):
            positions = slicer.compute_slice(num_steps, prev_steps)
            output[:, positions] = table(tokens[:, positions])
        return output
iris/src/models/transformer.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credits to https://github.com/karpathy/minGPT
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ import math
7
+ from typing import Optional
8
+
9
+ from einops import rearrange
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+
14
+ from .kv_caching import KeysValues, KVCache
15
+
16
class Transformer(nn.Module):
    """Stack of pre-norm attention blocks with optional per-layer KV caching."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        # NOTE: this mutates the caller's config dict by adding "max_tokens".
        self.config["max_tokens"] = config["tokens_per_block"] * config["max_blocks"]
        self.drop = nn.Dropout(config["embed_pdrop"])
        self.blocks = nn.ModuleList(Block(config) for _ in range(config["num_layers"]))
        self.ln_f = nn.LayerNorm(config["embed_dim"])

    def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
        device = self.ln_f.weight.device  # Assumption that all submodules are on the same device
        return KeysValues(n, self.config["num_heads"], max_tokens, self.config["embed_dim"], self.config["num_layers"], device)

    def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None) -> torch.Tensor:
        assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
        x = self.drop(sequences)
        for layer_idx, block in enumerate(self.blocks):
            layer_cache = None if past_keys_values is None else past_keys_values[layer_idx]
            x = block(x, layer_cache)
        return self.ln_f(x)
37
+
38
+
39
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x-wide MLP, each with a residual."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        embed_dim = config["embed_dim"]
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(config["resid_pdrop"]),
        )

    def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None) -> torch.Tensor:
        x = x + self.attn(self.ln1(x), past_keys_values)
        return x + self.mlp(self.ln2(x))
57
+
58
+
59
class SelfAttention(nn.Module):
    """Multi-head self-attention with causal or block-causal masking and KV caching.

    With a `kv_cache`, the chunk's keys/values are appended first and attention
    is computed from the T query positions against all L + T cached positions.
    """
    def __init__(self, config: dict) -> None:
        super().__init__()
        assert config["embed_dim"] % config["num_heads"] == 0
        assert config["attention"] in ('causal', 'block_causal')
        self.num_heads = config["num_heads"]
        self.key = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.query = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.value = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.attn_drop = nn.Dropout(config["attn_pdrop"])
        self.resid_drop = nn.Dropout(config["resid_pdrop"])
        self.proj = nn.Linear(config["embed_dim"], config["embed_dim"])

        # 'block_causal' is the causal mask with each token additionally allowed
        # to attend to every token of its own block (element-wise max with a
        # block-diagonal mask).
        causal_mask = torch.tril(torch.ones(config["max_tokens"], config["max_tokens"]))
        block_causal_mask = torch.max(causal_mask, torch.block_diag(*[torch.ones(config["tokens_per_block"], config["tokens_per_block"]) for _ in range(config["max_blocks"])]))
        self.register_buffer('mask', causal_mask if config["attention"] == 'causal' else block_causal_mask)

    def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None) -> torch.Tensor:
        B, T, C = x.size()
        if kv_cache is not None:
            b, nh, L, c = kv_cache.shape  # L = number of tokens already cached
            assert nh == self.num_heads and b == B and c * nh == C
        else:
            L = 0

        q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)   # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs)

        if kv_cache is not None:
            # Append this chunk, then attend over the full cached sequence.
            kv_cache.update(k, v)
            k, v = kv_cache.get()

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Rows L:L+T of the precomputed mask correspond to the current query positions.
        att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = rearrange(y, 'b h t e -> b t (h e)')

        y = self.resid_drop(self.proj(y))

        return y
iris/src/tokenizer.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credits to https://github.com/CompVis/taming-transformers
3
+ """
4
+
5
+ from typing import Tuple
6
+
7
+ from einops import rearrange
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from models.lpips import LPIPS
12
+ from models.nets import Encoder, Decoder
13
+
14
class Tokenizer(nn.Module):
    """VQ-VAE style tokenizer mapping images to discrete tokens and back.

    `forward` uses the straight-through estimator so gradients pass through the
    (non-differentiable) quantization step to the encoder.
    """

    def __init__(self, vocab_size: int, embed_dim: int, encoder: Encoder, decoder: Decoder, with_lpips: bool = True) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.encoder = encoder
        # Encoder/Decoder keep their settings in a plain dict config, so
        # z_channels must be read by key (attribute access raised AttributeError).
        self.pre_quant_conv = torch.nn.Conv2d(encoder.config["z_channels"], embed_dim, 1)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder.config["z_channels"], 1)
        self.decoder = decoder
        self.embedding.weight.data.uniform_(-1.0 / vocab_size, 1.0 / vocab_size)
        self.lpips = LPIPS().eval() if with_lpips else None

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> Tuple[torch.Tensor]:
        """Encode, quantize (straight-through) and decode; returns (z, z_quantized, reconstructions)."""
        outputs = self.encode(x, should_preprocess)
        # `encode` returns a dict: key access, not attribute access.
        z, z_quantized = outputs["z"], outputs["z_quantized"]
        # Straight-through estimator: forward value is z_quantized, gradient is identity w.r.t. z.
        decoder_input = z + (z_quantized - z).detach()
        reconstructions = self.decode(decoder_input, should_postprocess)
        return z, z_quantized, reconstructions

    def encode(self, x: torch.Tensor, should_preprocess: bool = False) -> dict:
        """Encode images to continuous latents, nearest-codebook latents and token ids.

        Accepts arbitrary leading batch dims: input is flattened to (N, C, H, W)
        for the encoder and the leading shape is restored before returning.
        """
        if should_preprocess:
            x = self.preprocess_input(x)
        shape = x.shape  # (..., C, H, W)
        x = x.view(-1, *shape[-3:])
        z = self.encoder(x)
        z = self.pre_quant_conv(z)
        b, e, h, w = z.shape
        z_flattened = rearrange(z, 'b e h w -> (b h w) e')
        # Squared L2 distance of each spatial latent to every codebook entry,
        # expanded as |z|^2 + |c|^2 - 2 z.c.
        dist_to_embeddings = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1) - 2 * torch.matmul(z_flattened, self.embedding.weight.t())

        tokens = dist_to_embeddings.argmin(dim=-1)
        z_q = rearrange(self.embedding(tokens), '(b h w) e -> b e h w', b=b, e=e, h=h, w=w).contiguous()

        # Reshape to original
        z = z.reshape(*shape[:-3], *z.shape[1:])
        z_q = z_q.reshape(*shape[:-3], *z_q.shape[1:])
        tokens = tokens.reshape(*shape[:-3], -1)

        return {
            "z": z,
            "z_quantized": z_q,
            "tokens": tokens
        }

    def decode(self, z_q: torch.Tensor, should_postprocess: bool = False) -> torch.Tensor:
        """Decode quantized latents of shape (..., E, h, w) back to images."""
        shape = z_q.shape  # (..., E, h, w)
        z_q = z_q.view(-1, *shape[-3:])
        z_q = self.post_quant_conv(z_q)
        rec = self.decoder(z_q)
        rec = rec.reshape(*shape[:-3], *rec.shape[1:])
        if should_postprocess:
            rec = self.postprocess_output(rec)
        return rec

    @torch.no_grad()
    def encode_decode(self, x: torch.Tensor, should_preprocess: bool = False, should_postprocess: bool = False) -> torch.Tensor:
        """Round-trip an image through encode + decode without gradients."""
        z_q = self.encode(x, should_preprocess)["z_quantized"]  # dict key access (was attribute access)
        return self.decode(z_q, should_postprocess)

    def preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """x is supposed to be channels first and in [0, 1]"""
        return x.mul(2).sub(1)

    def postprocess_output(self, y: torch.Tensor) -> torch.Tensor:
        """y is supposed to be channels first and in [-1, 1]"""
        return y.add(1).div(2)
iris/src/world_model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Tuple
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from models.kv_caching import KeysValues
9
+ from models.slicer import Embedder, Head
10
+ from models.transformer import Transformer
11
+
12
class WorldModel(nn.Module):
    """Autoregressive transformer world model over token blocks.

    Each block of ``tokens_per_block`` tokens consists of observation tokens
    followed by a single action token in the last slot (see the slot masks
    built in __init__). Three heads read the transformer outputs at the
    appropriate slots and predict next-observation-token logits, 3-way
    reward-class logits (for rewards clipped to {-1, 0, 1}), and 2-way
    episode-end logits.
    """

    def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: dict) -> None:
        """Build embedders, positional embeddings, transformer, and heads.

        Args:
            obs_vocab_size: size of the observation-token vocabulary.
            act_vocab_size: size of the action-token vocabulary.
            config: dict with at least "tokens_per_block", "max_tokens",
                "max_blocks", "embed_dim", plus whatever Transformer needs.
        """
        super().__init__()
        self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size
        self.config = config
        self.transformer = Transformer(config)

        # Slot masks within one block:
        #   - the action token occupies the last slot,
        #   - observation tokens occupy all other slots,
        #   - observation predictions are read from every slot except the
        #     last observation slot (index -2), whose successor is an action
        #     token rather than an observation token.
        all_but_last_obs_tokens_pattern = torch.ones(config["tokens_per_block"])
        all_but_last_obs_tokens_pattern[-2] = 0
        act_tokens_pattern = torch.zeros(self.config["tokens_per_block"])
        act_tokens_pattern[-1] = 1
        obs_tokens_pattern = 1 - act_tokens_pattern

        self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])

        # Separate embedding tables for action tokens and observation tokens,
        # selected per slot by the corresponding block mask.
        self.embedder = Embedder(
            max_blocks=config["max_blocks"],
            block_masks=[act_tokens_pattern, obs_tokens_pattern],
            embedding_tables=nn.ModuleList([
                nn.Embedding(act_vocab_size, config["embed_dim"]),
                nn.Embedding(obs_vocab_size, config["embed_dim"]),
            ])
        )

        def mlp(out_features: int) -> nn.Sequential:
            # Shared two-layer MLP architecture used by all prediction heads.
            return nn.Sequential(
                nn.Linear(config["embed_dim"], config["embed_dim"]),
                nn.ReLU(),
                nn.Linear(config["embed_dim"], out_features),
            )

        self.head_observations = Head(
            max_blocks=config["max_blocks"],
            block_mask=all_but_last_obs_tokens_pattern,
            head_module=mlp(obs_vocab_size),
        )

        self.head_rewards = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=mlp(3),  # reward classes {-1, 0, +1}
        )

        self.head_ends = Head(
            max_blocks=config["max_blocks"],
            block_mask=act_tokens_pattern,
            head_module=mlp(2),  # episode continues / episode ends
        )

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
        """Run the transformer over a (possibly incremental) token sequence.

        Args:
            tokens: (B, T) token ids.
            past_keys_values: optional KV cache for incremental decoding;
                its ``size`` gives the number of steps already processed.

        Returns:
            dict with "output_sequence" (transformer hidden states) and the
            three per-head logits tensors.
        """
        num_steps = tokens.size(1)  # (B, T)
        assert num_steps <= self.config["max_tokens"]
        prev_steps = 0 if past_keys_values is None else past_keys_values.size

        # Token embeddings plus positional embeddings offset by the number of
        # steps already present in the KV cache.
        positions = prev_steps + torch.arange(num_steps, device=tokens.device)
        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(positions)

        x = self.transformer(sequences, past_keys_values)

        return {
            "output_sequence": x,
            "logits_observations": self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps),
            "logits_rewards": self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps),
            "logits_ends": self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)
        }

    def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build flattened training targets, masking padded steps with -100.

        Bug fix: the return annotation previously claimed a 4-tuple while the
        method returns 3 tensors; corrected to a 3-tuple.

        Args:
            obs_tokens: (B, T, K) observation token ids.
            rewards: (B, T) rewards; clipped to {-1, 0, 1} via sign().
            ends: (B, T) episode-end indicators (at most one per sequence).
            mask_padding: (B, T) True where the step is valid (not padding).

        Returns:
            (labels_observations, labels_rewards, labels_ends), each 1-D,
            with -100 at ignored positions (the standard ignore_index of
            cross-entropy losses).
        """
        assert torch.all(ends.sum(dim=1) <= 1)  # at most 1 done per sequence
        mask_fill = torch.logical_not(mask_padding)
        # Flatten (T, K) and drop the first token: targets are next-token,
        # shifted by one relative to the model outputs.
        labels_observations = obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100).flatten(1)[:, 1:]
        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # Rewards clipped to {-1, 0, 1}
        labels_ends = ends.masked_fill(mask_fill, -100)
        return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)