ShaswatRobotics committed on
Commit
b2ade59
·
verified ·
1 Parent(s): 4a0bd96

Delete iris/models

Browse files
iris/models/__init__.py DELETED
File without changes
iris/models/kv_caching.py DELETED
@@ -1,106 +0,0 @@
1
- from typing import Tuple
2
-
3
- import numpy as np
4
- import torch
5
-
6
-
7
class Cache:
    """Fixed-capacity tensor buffer holding one half (keys or values) of a KV cache.

    Layout is (batch, num_heads, max_tokens, head_dim); only the first `_size`
    positions along the token dimension are considered filled.
    """

    def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        assert embed_dim % num_heads == 0
        self._n, self._cache, self._size = num_samples, None, None
        # Allocation is deferred to a factory so reset() can rebuild the buffer
        # (e.g. after prune() has shrunk the batch dimension).
        self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device)  # (B, nh, T, hs)
        self.reset()

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        # The reported token dimension is the number of tokens written so far,
        # not the allocated capacity.
        n, num_heads, _, head_dim = self._cache.shape
        return n, num_heads, self._size, head_dim

    def reset(self) -> None:
        """Re-allocate an empty buffer and mark it as containing zero tokens."""
        self._cache = self._reset(self._n)
        self._size = 0

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the batch entries selected by the 1-D boolean `mask`."""
        assert mask.ndim == 1 and mask.shape[0] == self.shape[0]
        self._cache = self._cache[mask]
        self._n = self._cache.shape[0]

    def get(self) -> torch.Tensor:
        # View of the filled portion only; shares storage with the buffer.
        return self._cache[:, :, :self._size, :]

    def update(self, x: torch.Tensor) -> None:
        """Append `x` along the token dimension, bypassing autograd's in-place version check."""
        assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)])
        assert self._size + x.size(2) <= self._cache.shape[2]
        self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + x.size(2))
        self._size += x.size(2)
36
-
37
-
38
class KVCache:
    """Pair of Cache buffers (keys and values) for a single attention layer."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
        self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device)

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        # Keys and values always stay in sync, so either cache's shape is representative.
        return self._k_cache.shape

    def reset(self) -> None:
        """Empty both buffers."""
        for cache in (self._k_cache, self._v_cache):
            cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Drop the batch entries not selected by `mask` from both buffers."""
        for cache in (self._k_cache, self._v_cache):
            cache.prune(mask)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (keys, values) written so far."""
        return self._k_cache.get(), self._v_cache.get()

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """Append new keys and values along the token dimension."""
        self._k_cache.update(k)
        self._v_cache.update(v)
61
-
62
-
63
class KeysValues:
    """Tuple of per-layer KVCache objects covering a whole transformer stack."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
        self._keys_values = tuple(KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers))

    def __getitem__(self, key: int) -> KVCache:
        return self._keys_values[key]

    def __len__(self):
        return len(self._keys_values)

    @property
    def size(self):
        # Tokens currently stored; identical across layers, so read it off layer 0.
        return self._keys_values[0].shape[2]

    def reset(self) -> None:
        """Empty every layer's cache."""
        for kv_cache in self._keys_values:
            kv_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Apply batch-entry pruning to every layer's cache."""
        for kv_cache in self._keys_values:
            kv_cache.prune(mask)
84
-
85
-
86
class AssignWithoutInplaceCheck(torch.autograd.Function):
    """
    Inspired from : https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
    Warning : do not use it to overwrite a slice twice.
    """

    @staticmethod
    def get_slice(dim: int, start: int, stop: int) -> Tuple[slice]:
        # Builds an index like [:, :, start:stop] with the slice placed at position `dim`.
        return tuple([slice(None), ] * dim + [slice(start, stop)])

    @staticmethod
    def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
        ctx.dim = dim
        ctx.start = start
        ctx.stop = stop
        # Writing through .data hides the mutation from autograd's version counter.
        input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value
        return input

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor]:
        # Gradient w.r.t. `input` is the incoming gradient; w.r.t. `value` it is the
        # written slice. The three int arguments (dim, start, stop) get no gradient.
        return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
iris/models/lpips.py DELETED
@@ -1,167 +0,0 @@
1
- """
2
- Credits to https://github.com/CompVis/taming-transformers
3
- """
4
-
5
- from collections import namedtuple
6
- import hashlib
7
- import os
8
- from pathlib import Path
9
- import requests
10
-
11
- import torch
12
- import torch.nn as nn
13
- from torchvision import models
14
- from tqdm import tqdm
15
-
16
-
17
class LPIPS(nn.Module):
    """Learned Perceptual Image Patch Similarity metric over frozen VGG16 features.

    Compares two images via squared differences of channel-normalized VGG
    activations, weighted per channel by learned 1x1 convolutions, averaged
    spatially and summed over the five feature stages. All parameters are
    frozen after the pretrained weights are loaded.
    """
    def __init__(self, use_dropout: bool = True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # VGG16 feature channels per stage
        self.net = vgg16(pretrained=True, requires_grad=False)
        # One learned linear (1x1 conv) head per VGG feature stage.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # Freeze everything: LPIPS is used as a fixed metric, not trained.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self) -> None:
        """Download (if necessary) and load the pretrained LPIPS linear-head weights."""
        ckpt = get_ckpt_path(name="vgg_lpips", root=Path.home() / ".cache/iris/tokenizer_pretrained_vgg")  # Download VGG if necessary
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the perceptual distance between `input` and `target` (lower is more similar)."""
        in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        feats0, feats1, diffs = {}, {}, {}
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
        for kk in range(len(self.chns)):
            # Unit-normalize each stage's features along channels before comparing.
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
        val = res[0]
        for i in range(1, len(self.chns)):
            val += res[i]
        return val
51
-
52
-
53
class ScalingLayer(nn.Module):
    """Normalize an RGB image with the channel statistics expected by the LPIPS VGG net."""

    def __init__(self) -> None:
        super(ScalingLayer, self).__init__()
        # Broadcastable (1, 3, 1, 1) buffers, registered so they move with the module.
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        shifted = inp - self.shift
        return shifted / self.scale
61
-
62
-
63
class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv """

    def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) -> None:
        super(NetLinLayer, self).__init__()
        modules = []
        if use_dropout:
            modules.append(nn.Dropout())
        # Bias-free 1x1 convolution acts as a per-channel linear weighting.
        modules.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*modules)
70
-
71
-
72
class vgg16(torch.nn.Module):
    """VGG16 backbone split into its five relu stages, as used by LPIPS.

    NOTE(review): forward() returns a 5-field namedtuple of feature maps, not a
    single tensor, despite the torch.Tensor annotation — confirm before tightening.
    """
    def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None:
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Layer-index boundaries split torchvision's vgg16.features at the
        # relu1_2 / relu2_2 / relu3_3 / relu4_3 / relu5_3 activations.
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            # Freeze the backbone when used purely as a feature extractor.
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Run the slices sequentially, capturing each stage's output.
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
        return out
110
-
111
-
112
def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Scale `x` to unit L2 norm along the channel dimension (dim=1).

    `eps` guards against division by zero for all-zero channels.
    """
    channel_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)
115
-
116
-
117
def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """Average a (B, C, H, W) tensor over its two spatial dimensions."""
    return torch.mean(x, dim=[2, 3], keepdim=keepdim)
119
-
120
-
121
- # ********************************************************************
122
- # *************** Utilities to download pretrained vgg ***************
123
- # ********************************************************************
124
-
125
-
126
# Download source for the pretrained LPIPS VGG weights.
URL_MAP = {
    "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}


# Local filename each checkpoint is stored under (relative to the cache root).
CKPT_MAP = {
    "vgg_lpips": "vgg.pth"
}


# Expected MD5 checksum used to validate the downloaded file.
MD5_MAP = {
    "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}
139
-
140
-
141
def download(url: str, local_path: str, chunk_size: int = 1024) -> None:
    """Stream `url` to `local_path` with a progress bar.

    Args:
        url: Source URL.
        local_path: Destination file path; parent directories are created if missing.
        chunk_size: Number of bytes requested per read.
    """
    os.makedirs(os.path.split(local_path)[0], exist_ok=True)
    with requests.get(url, stream=True) as r:
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
            with open(local_path, "wb") as f:
                for data in r.iter_content(chunk_size=chunk_size):
                    if data:
                        f.write(data)
                        # Advance by the bytes actually received: the final chunk is
                        # usually shorter than chunk_size, so the previous
                        # `pbar.update(chunk_size)` over-counted past total_size.
                        pbar.update(len(data))
151
-
152
-
153
def md5_hash(path: str) -> str:
    """Return the hex MD5 digest of the file at `path`."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        digest.update(f.read())
    return digest.hexdigest()
157
-
158
-
159
def get_ckpt_path(name: str, root: str, check: bool = False) -> str:
    """Return the local path of checkpoint `name`, downloading it first if missing.

    Args:
        name: Key into URL_MAP / CKPT_MAP / MD5_MAP.
        root: Cache directory the checkpoint lives in.
        check: If True, also re-download when the existing file's MD5 mismatches.
    """
    assert name in URL_MAP
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        # Fail loudly on a corrupted download.
        assert md5 == MD5_MAP[name], md5
    return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
iris/models/nets.py DELETED
@@ -1,345 +0,0 @@
1
- """
2
- Credits to https://github.com/CompVis/taming-transformers
3
- """
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
class Encoder(nn.Module):
    """Convolutional encoder (taming-transformers style): downsampling ResNet stages
    with optional attention, a middle block, and a 3x3 projection to z_channels.

    Expects `config` with keys: in_channels, ch, ch_mult, num_res_blocks,
    resolution, attn_resolutions, dropout, z_channels.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        self.num_resolutions = len(config["ch_mult"])
        temb_ch = 0  # timestep embedding #channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(config["in_channels"], config["ch"], kernel_size=3, stride=1, padding=1)

        curr_res = config["resolution"]
        in_ch_mult = (1,) + tuple(config["ch_mult"])
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = config["ch"] * in_ch_mult[i_level]
            block_out = config["ch"] * config["ch_mult"][i_level]
            for i_block in range(self.config["num_res_blocks"]):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch, dropout=config["dropout"]))
                block_in = block_out
                # Attention is added only at the configured spatial resolutions.
                if curr_res in config["attn_resolutions"]:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                # Halve the spatial resolution between all but the last stage.
                down.downsample = Downsample(block_in, with_conv=True)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch, dropout=config["dropout"])
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch, dropout=config["dropout"])

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, config["z_channels"], kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        temb = None  # timestep embedding

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.config["num_res_blocks"]):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
92
-
93
-
94
class Decoder(nn.Module):
    """Convolutional decoder (taming-transformers style): maps a latent z back to an
    image through a middle block and upsampling ResNet stages with optional attention.

    Expects `config` with keys: ch, out_ch, ch_mult, num_res_blocks, resolution,
    attn_resolutions, dropout, z_channels.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        temb_ch = 0  # timestep embedding #channels (unused here)
        self.num_resolutions = len(config["ch_mult"])

        # compute block_in and curr_res at lowest res
        block_in = config["ch"] * config["ch_mult"][self.num_resolutions - 1]
        curr_res = config["resolution"] // 2 ** (self.num_resolutions - 1)
        # BUG FIX: the f-string previously reused double quotes for the dict key
        # (config["z_channels"] inside a "..." f-string), a SyntaxError on Python < 3.12.
        print(f"Tokenizer : shape of latent is {config['z_channels'], curr_res, curr_res}.")

        # z to block_in
        self.conv_in = torch.nn.Conv2d(config["z_channels"], block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch, dropout=config["dropout"])
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=temb_ch, dropout=config["dropout"])

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = config["ch"] * config["ch_mult"][i_level]
            for i_block in range(config["num_res_blocks"] + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=temb_ch, dropout=config["dropout"]))
                block_in = block_out
                # Attention is added only at the configured spatial resolutions.
                if curr_res in config["attn_resolutions"]:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, with_conv=True)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, config["out_ch"], kernel_size=3, stride=1, padding=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        temb = None  # timestep embedding

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.config["num_res_blocks"] + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
181
-
182
-
183
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
    """Swish / SiLU activation: x * sigmoid(x)."""
    return torch.sigmoid(x) * x
186
-
187
-
188
def Normalize(in_channels: int) -> nn.Module:
    """GroupNorm with 32 groups, as used throughout these encoder/decoder nets."""
    return torch.nn.GroupNorm(
        num_groups=32,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
    )
190
-
191
-
192
class Upsample(nn.Module):
    """Nearest-neighbour 2x spatial upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels: int, with_conv: bool) -> None:
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
208
-
209
-
210
class Downsample(nn.Module):
    """2x spatial downsampling: strided 3x3 conv with asymmetric padding, or average pooling."""

    def __init__(self, in_channels: int, with_conv: bool) -> None:
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.with_conv:
            return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # Pad one pixel on the right and bottom only, then apply the strided conv.
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
230
-
231
-
232
class ResnetBlock(nn.Module):
    """Pre-activation ResNet block: GroupNorm + swish + conv twice, with optional
    timestep-embedding injection and a learned shortcut when channel counts differ."""

    def __init__(self, *, in_channels: int, out_channels: int = None, conv_shortcut: bool = False,
                 dropout: float, temb_channels: int = 512) -> None:
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            # Projection used to add the timestep embedding between the two convs.
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                # 1x1 "network-in-network" shortcut.
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected embedding over the spatial dimensions.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            # Match the shortcut's channel count to the residual branch.
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
292
-
293
-
294
class AttnBlock(nn.Module):
    """Single-head spatial self-attention over a (B, C, H, W) feature map,
    with a residual connection and 1x1-conv projections for q, k, v and the output."""

    def __init__(self, in_channels: int) -> None:
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)   # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))  # scale by 1/sqrt(channels)
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        # Residual connection around the attention operation.
        return x + h_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
iris/models/slicer.py DELETED
@@ -1,53 +0,0 @@
1
- import math
2
- from typing import List
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
class Slicer(nn.Module):
    """Precomputes, for `max_blocks` fixed-size blocks, the absolute indices of
    tokens kept by `block_mask`, enabling cheap slicing of arbitrary step windows."""

    def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
        super().__init__()
        self.block_size = block_mask.size(0)
        self.num_kept_tokens = block_mask.sum().long().item()
        # Absolute index of every kept token across all blocks; registered as a
        # buffer so it follows the module across devices.
        per_block = torch.where(block_mask)[0].repeat(max_blocks)
        block_offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
        self.register_buffer('indices', per_block + self.block_size * block_offsets)

    def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
        """Indices (relative to `prev_steps`) of the kept tokens inside the window
        [prev_steps, prev_steps + num_steps)."""
        total_steps = num_steps + prev_steps
        num_blocks = math.ceil(total_steps / self.block_size)
        candidates = self.indices[:num_blocks * self.num_kept_tokens]
        in_window = torch.logical_and(prev_steps <= candidates, candidates < total_steps)
        return candidates[in_window] - prev_steps

    def forward(self, *args, **kwargs):
        raise NotImplementedError
24
-
25
-
26
class Head(Slicer):
    """Applies `head_module` only to the time steps kept by the block mask."""

    def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
        super().__init__(max_blocks, block_mask)
        assert isinstance(head_module, nn.Module)
        self.head_module = head_module

    def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
        # x is (B, T, E); select the kept time steps before running the head.
        kept = self.compute_slice(num_steps, prev_steps)
        return self.head_module(x[:, kept])
35
-
36
-
37
class Embedder(nn.Module):
    """Embeds a token sequence with several embedding tables, each table owning the
    positions selected by its block mask (the masks partition every block).

    Args:
        max_blocks: Maximum number of blocks in a sequence.
        block_masks: One boolean mask per table; together they cover each block exactly once.
        embedding_tables: One nn.Embedding per mask, all with the same embedding_dim.
    """

    def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
        super().__init__()
        assert len(block_masks) == len(embedding_tables)
        assert (sum(block_masks) == 1).all()  # block mask are a partition of a block
        self.embedding_dim = embedding_tables[0].embedding_dim
        assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables])
        # BUG FIX: a plain Python list hides the tables from nn.Module, so their
        # weights were missing from parameters(), state_dict() and .to(); wrap in
        # nn.ModuleList to register them (indexing/iteration behave the same).
        self.embedding_tables = nn.ModuleList(embedding_tables)
        self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks]

    def forward(self, tokens: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
        assert tokens.ndim == 2  # x is (B, T)
        output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
        for slicer, emb in zip(self.slicers, self.embedding_tables):
            # Each table fills only the positions its mask selects.
            s = slicer.compute_slice(num_steps, prev_steps)
            output[:, s] = emb(tokens[:, s])
        return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
iris/models/transformer.py DELETED
@@ -1,101 +0,0 @@
1
- """
2
- Credits to https://github.com/karpathy/minGPT
3
- """
4
-
5
- from dataclasses import dataclass
6
- import math
7
- from typing import Optional
8
-
9
- from einops import rearrange
10
- import torch
11
- import torch.nn as nn
12
- from torch.nn import functional as F
13
-
14
- from .kv_caching import KeysValues, KVCache
15
-
16
class Transformer(nn.Module):
    """Stack of pre-LN GPT blocks with a final LayerNorm (minGPT-style)."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        self.config["max_tokens"] = config["tokens_per_block"] * config["max_blocks"]
        self.drop = nn.Dropout(config["embed_pdrop"])
        self.blocks = nn.ModuleList(Block(config) for _ in range(config["num_layers"]))
        self.ln_f = nn.LayerNorm(config["embed_dim"])

    def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
        """Allocate fresh, empty KV caches for every layer."""
        device = self.ln_f.weight.device  # Assumption that all submodules are on the same device
        return KeysValues(n, self.config["num_heads"], max_tokens, self.config["embed_dim"], self.config["num_layers"], device)

    def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None) -> torch.Tensor:
        assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
        x = self.drop(sequences)
        for i, block in enumerate(self.blocks):
            layer_cache = None if past_keys_values is None else past_keys_values[i]
            x = block(x, layer_cache)
        return self.ln_f(x)
37
-
38
-
39
class Block(nn.Module):
    """Pre-LN transformer block: self-attention and a 4x MLP, each with a residual."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        embed_dim = config["embed_dim"]
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(config["resid_pdrop"]),
        )

    def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None) -> torch.Tensor:
        x = x + self.attn(self.ln1(x), past_keys_values)
        return x + self.mlp(self.ln2(x))
57
-
58
-
59
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fixed causal (or block-causal) mask and
    optional incremental decoding through an external KVCache."""

    def __init__(self, config: dict) -> None:
        super().__init__()
        assert config["embed_dim"] % config["num_heads"] == 0
        assert config["attention"] in ('causal', 'block_causal')
        self.num_heads = config["num_heads"]
        self.key = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.query = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.value = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.attn_drop = nn.Dropout(config["attn_pdrop"])
        self.resid_drop = nn.Dropout(config["resid_pdrop"])
        self.proj = nn.Linear(config["embed_dim"], config["embed_dim"])

        # Lower-triangular causal mask; the block-causal variant additionally
        # allows full attention within each block of `tokens_per_block` tokens.
        causal_mask = torch.tril(torch.ones(config["max_tokens"], config["max_tokens"]))
        block_causal_mask = torch.max(causal_mask, torch.block_diag(*[torch.ones(config["tokens_per_block"], config["tokens_per_block"]) for _ in range(config["max_blocks"])]))
        self.register_buffer('mask', causal_mask if config["attention"] == 'causal' else block_causal_mask)

    def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None) -> torch.Tensor:
        B, T, C = x.size()
        if kv_cache is not None:
            # L = number of tokens already cached from previous forward passes.
            b, nh, L, c = kv_cache.shape
            assert nh == self.num_heads and b == B and c * nh == C
        else:
            L = 0

        q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)    # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)  # (B, nh, T, hs)

        if kv_cache is not None:
            # Append the new keys/values, then attend over the whole cached history.
            kv_cache.update(k, v)
            k, v = kv_cache.get()

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Offset the mask rows by L so the new queries see all L cached tokens
        # plus the causally allowed portion of the current T tokens.
        att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v
        y = rearrange(y, 'b h t e -> b t (h e)')

        y = self.resid_drop(self.proj(y))

        return y