ShaswatRobotics committed on
Commit
23bc32f
·
verified ·
1 Parent(s): 3e82cab

Upload 35 files

Browse files
Files changed (35) hide show
  1. delta-iris/src/data/__init__.py +7 -0
  2. delta-iris/src/data/__pycache__/__init__.cpython-310.pyc +0 -0
  3. delta-iris/src/data/__pycache__/batch.cpython-310.pyc +0 -0
  4. delta-iris/src/data/__pycache__/dataset.cpython-310.pyc +0 -0
  5. delta-iris/src/data/__pycache__/episode.cpython-310.pyc +0 -0
  6. delta-iris/src/data/__pycache__/episode_count.cpython-310.pyc +0 -0
  7. delta-iris/src/data/__pycache__/sampler.cpython-310.pyc +0 -0
  8. delta-iris/src/data/__pycache__/segment.cpython-310.pyc +0 -0
  9. delta-iris/src/data/__pycache__/utils.cpython-310.pyc +0 -0
  10. delta-iris/src/data/batch.py +24 -0
  11. delta-iris/src/data/dataset.py +104 -0
  12. delta-iris/src/data/episode.py +41 -0
  13. delta-iris/src/data/episode_count.py +41 -0
  14. delta-iris/src/data/sampler.py +42 -0
  15. delta-iris/src/data/segment.py +25 -0
  16. delta-iris/src/data/utils.py +69 -0
  17. delta-iris/src/models/__init__.py +1 -0
  18. delta-iris/src/models/__pycache__/__init__.cpython-310.pyc +0 -0
  19. delta-iris/src/models/__pycache__/convnet.cpython-310.pyc +0 -0
  20. delta-iris/src/models/__pycache__/kv_caching.cpython-310.pyc +0 -0
  21. delta-iris/src/models/__pycache__/slicer.cpython-310.pyc +0 -0
  22. delta-iris/src/models/__pycache__/transformer.cpython-310.pyc +0 -0
  23. delta-iris/src/models/__pycache__/world_model.cpython-310.pyc +0 -0
  24. delta-iris/src/models/convnet.py +114 -0
  25. delta-iris/src/models/kv_caching.py +106 -0
  26. delta-iris/src/models/slicer.py +55 -0
  27. delta-iris/src/models/tokenizer/__init__.py +1 -0
  28. delta-iris/src/models/tokenizer/__pycache__/__init__.cpython-310.pyc +0 -0
  29. delta-iris/src/models/tokenizer/__pycache__/quantizer.cpython-310.pyc +0 -0
  30. delta-iris/src/models/tokenizer/__pycache__/tokenizer.cpython-310.pyc +0 -0
  31. delta-iris/src/models/tokenizer/quantizer.py +95 -0
  32. delta-iris/src/models/transformer.py +157 -0
  33. delta-iris/src/models/utils.py +198 -0
  34. delta-iris/src/tokenizer.py +115 -0
  35. delta-iris/src/world_model.py +139 -0
delta-iris/src/data/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .batch import Batch
2
+ from .dataset import EpisodeDataset
3
+ from .episode import Episode
4
+ from .episode_count import EpisodeCountManager
5
+ from .sampler import BatchSampler
6
+ from .segment import SegmentId
7
+ from .utils import collate_segments_to_batch, DatasetTraverser, make_segment
delta-iris/src/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (543 Bytes). View file
 
delta-iris/src/data/__pycache__/batch.cpython-310.pyc ADDED
Binary file (1.46 kB). View file
 
delta-iris/src/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (4.9 kB). View file
 
delta-iris/src/data/__pycache__/episode.cpython-310.pyc ADDED
Binary file (1.8 kB). View file
 
delta-iris/src/data/__pycache__/episode_count.cpython-310.pyc ADDED
Binary file (2.78 kB). View file
 
delta-iris/src/data/__pycache__/sampler.cpython-310.pyc ADDED
Binary file (1.96 kB). View file
 
delta-iris/src/data/__pycache__/segment.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
delta-iris/src/data/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.98 kB). View file
 
delta-iris/src/data/batch.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import List
4
+
5
+ import torch
6
+
7
+ from .segment import SegmentId
8
+
9
+
10
@dataclass
class Batch:
    """A stacked batch of segments.

    All fields except `segment_ids` are tensors with a leading batch
    dimension; `segment_ids` records where each row came from and is
    passed through untouched by the device/pinning helpers.
    """
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    mask_padding: torch.BoolTensor
    segment_ids: List[SegmentId]

    def _map_tensors(self, fn) -> Batch:
        # Apply fn to every tensor field, leaving segment_ids as-is.
        return Batch(**{name: value if name == 'segment_ids' else fn(value) for name, value in self.__dict__.items()})

    def pin_memory(self) -> Batch:
        return self._map_tensors(lambda t: t.pin_memory())

    def to(self, device: torch.device) -> Batch:
        return self._map_tensors(lambda t: t.to(device))
24
+
delta-iris/src/data/dataset.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import shutil
3
+ from typing import Dict, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from .episode import Episode
9
+ from .segment import Segment, SegmentId
10
+ from .utils import make_segment
11
+
12
+
13
class EpisodeDataset(torch.utils.data.Dataset):
    """Disk-backed store of episodes, addressed by SegmentId.

    One file per episode under a 3-level digit-based folder hierarchy, plus
    an `info.pt` file persisting the bookkeeping arrays (`start_idx`,
    `lengths`) that map between global step indices and
    (episode_id, timestep) coordinates.
    """

    def __init__(self, directory: Path, name: str) -> None:
        super().__init__()
        self.name = name
        self.directory = Path(directory)
        self.num_episodes, self.num_steps, self.start_idx, self.lengths = None, None, None, None

        # Fresh directory -> start empty; otherwise restore bookkeeping from info.pt.
        if not self.directory.is_dir():
            self._init_empty()
        else:
            self._load_info()
        print(f'({name}) {self.num_episodes} episodes, {self.num_steps} steps.')

    @property
    def info_path(self) -> Path:
        return self.directory / 'info.pt'

    @property
    def info(self) -> Dict[str, Union[int, np.ndarray]]:
        """Bookkeeping snapshot that gets persisted to info.pt."""
        return {'num_episodes': self.num_episodes, 'num_steps': self.num_steps, 'start_idx': self.start_idx, 'lengths': self.lengths}

    def __len__(self) -> int:
        # Length in steps, not in episodes.
        return self.num_steps

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        return self._load_segment(segment_id)

    def _init_empty(self) -> None:
        """Create the directory and zeroed bookkeeping (fails if it already exists)."""
        self.directory.mkdir(parents=True, exist_ok=False)
        self.num_episodes = 0
        self.num_steps = 0
        self.start_idx = np.array([], dtype=np.int64)
        self.lengths = np.array([], dtype=np.int64)
        self.save_info()

    def _load_info(self) -> None:
        """Restore bookkeeping arrays from info.pt."""
        info = torch.load(self.info_path)
        self.num_steps = info['num_steps']
        self.num_episodes = info['num_episodes']
        self.start_idx = info['start_idx']
        self.lengths = info['lengths']

    def save_info(self) -> None:
        torch.save(self.info, self.info_path)

    def clear(self) -> None:
        """Delete everything on disk and start over empty."""
        shutil.rmtree(self.directory)
        self._init_empty()

    def _get_episode_path(self, episode_id: int) -> Path:
        # Spread files over nested folders keyed by the id's low decimal
        # digits, e.g. id 1234 -> .../200/30/4/1234.pt, keeping directories small.
        n = 3  # number of hierarchies
        powers = np.arange(n)
        subfolders = list(map(int, np.floor((episode_id % 10 ** (1 + powers)) / 10 ** powers) * 10 ** powers))[::-1]
        return self.directory / '/'.join(list(map(lambda x: f'{x[1]:0{n - x[0]}d}', enumerate(subfolders)))) / f'{episode_id}.pt'

    def _load_segment(self, segment_id: SegmentId, should_pad: bool = True) -> Segment:
        episode = self.load_episode(segment_id.episode_id)
        return make_segment(episode, segment_id, should_pad)

    def load_episode(self, episode_id: int) -> Episode:
        return Episode(**torch.load(self._get_episode_path(episode_id)))

    def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
        """Append a new episode, or merge into an existing one.

        Returns the id of the written episode. Bookkeeping is updated in
        memory only; call save_info() to persist it.
        """
        if episode_id is None:
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
            self.num_steps += len(episode)
            self.num_episodes += 1

        else:
            assert episode_id < self.num_episodes
            old_episode = self.load_episode(episode_id)
            episode = old_episode.merge(episode)
            incr_num_steps = len(episode) - len(old_episode)
            self.lengths[episode_id] = len(episode)
            # All later episodes shift right by the number of newly added steps.
            self.start_idx[episode_id + 1:] += incr_num_steps
            self.num_steps += incr_num_steps

        episode_path = self._get_episode_path(episode_id)
        episode_path.parent.mkdir(parents=True, exist_ok=True)
        # Write-then-rename so the episode file is replaced atomically.
        torch.save(episode.__dict__, episode_path.with_suffix('.tmp'))
        episode_path.with_suffix('.tmp').rename(episode_path)

        return episode_id

    def get_episode_id_from_global_idx(self, global_idx: np.ndarray) -> np.ndarray:
        # First start index strictly greater than idx, minus one, gives the
        # containing episode; the modulo maps the no-match case (argmax 0 -> -1)
        # onto the last episode.
        return (np.argmax(self.start_idx.reshape(-1, 1) > global_idx, axis=0) - 1) % self.num_episodes

    def get_global_idx_from_segment_id(self, segment_id: SegmentId) -> np.ndarray:
        """Global step indices covered by the segment's [start, stop) range."""
        start_idx = self.start_idx[segment_id.episode_id]
        return np.arange(start_idx + segment_id.start, start_idx + segment_id.stop)
delta-iris/src/data/episode.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+
6
+
7
@dataclass
class EpisodeMetrics:
    """Summary statistics of a single episode."""
    episode_length: int
    episode_return: float


@dataclass
class Episode:
    """A trajectory of observations/actions/rewards/ends.

    All four tensors share the same leading (time) dimension. `ends` holds
    0/1 flags; on construction, everything after the first terminal step is
    dropped so an episode never extends past its end.
    """
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor

    def __post_init__(self) -> None:
        assert len(self.observations) == len(self.actions) == len(self.rewards) == len(self.ends)
        if self.ends.sum() > 0:
            # Truncate right after the first terminal flag.
            idx_end = torch.argmax(self.ends) + 1
            self.observations = self.observations[:idx_end]
            self.actions = self.actions[:idx_end]
            self.rewards = self.rewards[:idx_end]
            self.ends = self.ends[:idx_end]

    def __len__(self) -> int:
        return self.observations.size(0)

    def merge(self, other: Episode) -> Episode:
        """Concatenate `other` after this episode along the time dimension.

        The result passes through __post_init__, so it is re-truncated at the
        first terminal step if one is present.
        """
        return Episode(
            torch.cat((self.observations, other.observations), dim=0),
            torch.cat((self.actions, other.actions), dim=0),
            torch.cat((self.rewards, other.rewards), dim=0),
            torch.cat((self.ends, other.ends), dim=0),
        )

    def compute_metrics(self) -> EpisodeMetrics:
        # .item() so episode_return is a plain float, matching the declared
        # field type (previously a 0-dim tensor leaked into EpisodeMetrics).
        return EpisodeMetrics(len(self), self.rewards.sum().item())
delta-iris/src/data/episode_count.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from .dataset import EpisodeDataset
8
+
9
+
10
class EpisodeCountManager:
    """Tracks, per registered key, how many times each episode was sampled.

    Counters are numpy int64 arrays of shape (num_episodes,), kept in sync
    with the dataset via add_episode and turned into inverse-frequency
    sampling probabilities by compute_probabilities.
    """

    def __init__(self, dataset: EpisodeDataset) -> None:
        self.dataset = dataset
        # key -> per-episode sample counts
        self.all_counts = dict()

    def load(self, path_to_checkpoint: Path) -> None:
        """Restore all counters; each must match the dataset's episode count."""
        self.all_counts = torch.load(path_to_checkpoint)
        assert all([counts.shape[0] == self.dataset.num_episodes for counts in self.all_counts.values()])

    def save(self, path_to_checkpoint: Path) -> None:
        torch.save(self.all_counts, path_to_checkpoint)

    def register(self, *keys: str) -> None:
        """Create a zeroed counter for each new key (keys must not already exist)."""
        assert all([key not in self.all_counts for key in keys])
        self.all_counts.update({key: np.zeros(self.dataset.num_episodes, dtype=np.int64) for key in keys})

    def add_episode(self, episode_id: int) -> None:
        """Grow every counter by one slot when a brand-new episode is appended."""
        for key, counts in self.all_counts.items():
            assert episode_id <= counts.shape[0]
            if episode_id == counts.shape[0]:
                # New episode: append a zero count for it.
                self.all_counts[key] = np.concatenate((counts, np.zeros(1, dtype=np.int64)))
            assert self.all_counts[key].shape[0] == self.dataset.num_episodes

    def increment_episode_count(self, key: str, episode_id: int) -> None:
        assert key in self.all_counts
        self.all_counts[key][episode_id] += 1

    def compute_probabilities(self, key: str, alpha: float) -> np.ndarray:
        """Normalized inverse-count probabilities; alpha sharpens/flattens them."""
        assert key in self.all_counts
        inverse_counts = 1 / (1 + self.all_counts[key])
        p = inverse_counts ** alpha
        return p / p.sum()
delta-iris/src/data/sampler.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Generator, List
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .dataset import EpisodeDataset
7
+ from .segment import SegmentId
8
+
9
+
10
class BatchSampler(torch.utils.data.Sampler):
    """Yields lists of random SegmentIds, one list (batch) per step.

    Episodes are drawn according to `probabilities` (uniform when None),
    then a window of `sequence_length` steps is placed around a uniformly
    random anchor timestep inside each chosen episode.
    """

    def __init__(self, dataset: EpisodeDataset, num_steps_per_epoch: int, batch_size: int, sequence_length: int, can_sample_beyond_end: bool) -> None:
        super().__init__(dataset)
        self.dataset = dataset
        # Optional per-episode sampling distribution; None means uniform.
        self.probabilities = None
        self.num_steps_per_epoch = num_steps_per_epoch
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.can_sample_beyond_end = can_sample_beyond_end

    def __len__(self) -> int:
        return self.num_steps_per_epoch

    def __iter__(self) -> Generator[List[SegmentId], None, None]:
        for _ in range(self.num_steps_per_epoch):
            yield self.sample()

    def sample(self) -> List[SegmentId]:
        """Draw `batch_size` segment ids, each spanning `sequence_length` steps."""
        episode_ids = np.random.choice(np.arange(self.dataset.num_episodes), size=self.batch_size, replace=True, p=self.probabilities)
        # One random anchor timestep inside each chosen episode.
        timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])

        # padding allowed, both before start and after end
        if self.can_sample_beyond_end:
            starts = timesteps - np.random.randint(0, self.sequence_length, len(timesteps))
            stops = starts + self.sequence_length

        # padding allowed only before start
        else:
            # Clamp the stop to the episode end, then count back a full window
            # (so starts may be negative, i.e. left-padded).
            stops = np.minimum(self.dataset.lengths[episode_ids], timesteps + 1 + np.random.randint(0, self.sequence_length, len(timesteps)))
            starts = stops - self.sequence_length

        return list(map(lambda x: SegmentId(*x), zip(episode_ids, starts, stops)))
42
+
delta-iris/src/data/segment.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+
6
+
7
@dataclass
class SegmentId:
    """Addresses the time slice [start, stop) of one episode.

    start may be negative and stop may exceed the episode length; the
    out-of-range part is zero-padded when the segment is materialized.
    """
    episode_id: int
    start: int
    stop: int


@dataclass
class Segment:
    """A materialized (possibly padded) slice of an episode."""
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    # True on real steps, False on padded ones.
    mask_padding: torch.BoolTensor
    id: SegmentId

    @property
    def effective_size(self) -> int:
        # Number of non-padding steps.
        return self.mask_padding.sum().item()
delta-iris/src/data/utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Generator, List
3
+
4
+ import torch
5
+
6
+ from .batch import Batch
7
+ from .episode import Episode
8
+ from .segment import Segment, SegmentId
9
+
10
+
11
def collate_segments_to_batch(segments: List[Segment]) -> Batch:
    """Stack segments into a Batch, scaling byte observations into [0, 1]."""
    return Batch(
        torch.stack([s.observations for s in segments]).div(255),
        torch.stack([s.actions for s in segments]),
        torch.stack([s.rewards for s in segments]),
        torch.stack([s.ends for s in segments]),
        torch.stack([s.mask_padding for s in segments]),
        [s.id for s in segments]
    )
20
+
21
+
22
def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment:
    """Extract [start, stop) from an episode, zero-padding out-of-range parts.

    The requested window may begin before 0 or end past the episode; those
    regions are zero-filled and flagged False in mask_padding (only allowed
    when should_pad is True). The returned Segment's id records the clamped
    in-range [start, stop).
    """
    assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop
    padding_length_right = max(0, segment_id.stop - len(episode))
    padding_length_left = max(0, -segment_id.start)
    assert padding_length_right == padding_length_left == 0 or should_pad

    def pad(x):
        # F.pad's pad list runs from the LAST dimension backwards; the zero
        # entries skip the trailing dims so the padding lands on the time dim (dim 0).
        pad_right = torch.nn.functional.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [padding_length_right]) if padding_length_right > 0 else x
        return torch.nn.functional.pad(pad_right, [0 for _ in range(2 * x.ndim - 2)] + [padding_length_left, 0]) if padding_length_left > 0 else pad_right

    start = max(0, segment_id.start)
    stop = min(len(episode), segment_id.stop)

    return Segment(
        pad(episode.observations[start:stop]),
        pad(episode.actions[start:stop]),
        pad(episode.rewards[start:stop]),
        pad(episode.ends[start:stop]),
        mask_padding=torch.cat((torch.zeros(padding_length_left), torch.ones(stop - start), torch.zeros(padding_length_right))).bool(),
        id=SegmentId(segment_id.episode_id, start, stop)
    )
43
+
44
+
45
class DatasetTraverser:
    """Deterministically iterates a whole dataset as batches of fixed-size chunks.

    Each episode is cut into consecutive chunks of `chunk_size` steps (the
    last one right-padded); trailing chunks with fewer than 2 real steps are
    dropped. Chunks are collated into batches of `batch_num_samples`.
    """

    def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None:
        self.dataset = dataset
        self.batch_num_samples = batch_num_samples
        self.chunk_size = chunk_size
        # Precomputed batch count; the `% chunk_size == 1` term mirrors the
        # drop, in __iter__, of trailing chunks holding a single real step.
        self._num_batches = math.ceil(sum([math.ceil(dataset.lengths[episode_id] / chunk_size) - int(dataset.lengths[episode_id] % chunk_size == 1) for episode_id in range(dataset.num_episodes)]) / batch_num_samples)

    def __len__(self) -> int:
        return self._num_batches

    def __iter__(self) -> Generator[Batch, None, None]:
        chunks = []

        for episode_id in range(self.dataset.num_episodes):
            episode = self.dataset.load_episode(episode_id)
            chunks.extend(make_segment(episode, SegmentId(episode_id, start=i * self.chunk_size, stop=(i + 1) * self.chunk_size), should_pad=True) for i in range(math.ceil(len(episode) / self.chunk_size)))
            # Drop a trailing chunk that holds fewer than 2 real steps.
            if chunks[-1].effective_size < 2:
                chunks.pop()

            while len(chunks) >= self.batch_num_samples:
                yield collate_segments_to_batch(chunks[:self.batch_num_samples])
                chunks = chunks[self.batch_num_samples:]

        # Flush the incomplete final batch, if any.
        if len(chunks) > 0:
            yield collate_segments_to_batch(chunks)
delta-iris/src/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .tokenizer import Tokenizer
delta-iris/src/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (315 Bytes). View file
 
delta-iris/src/models/__pycache__/convnet.cpython-310.pyc ADDED
Binary file (4.33 kB). View file
 
delta-iris/src/models/__pycache__/kv_caching.cpython-310.pyc ADDED
Binary file (5.78 kB). View file
 
delta-iris/src/models/__pycache__/slicer.cpython-310.pyc ADDED
Binary file (3.3 kB). View file
 
delta-iris/src/models/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (6.47 kB). View file
 
delta-iris/src/models/__pycache__/world_model.cpython-310.pyc ADDED
Binary file (5.2 kB). View file
 
delta-iris/src/models/convnet.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+
4
+ from einops import rearrange
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
@dataclass
class FrameCnnConfig:
    """Shared configuration for FrameEncoder / FrameDecoder."""
    image_channels: int
    latent_dim: int
    # Base channel count; each stage uses mult[i] * num_channels channels.
    num_channels: int
    mult: List[int]
    # Per-stage flag: resample (down in the encoder, up in the decoder) after that stage.
    down: List[int]
17
+
18
+
19
class FrameEncoder(nn.Module):
    """Per-frame CNN encoder.

    Folds time into the batch dimension, so a (B, T, C, H, W) clip is
    encoded frame-by-frame into (B, T, latent_dim, h, w).
    """

    def __init__(self, config: FrameCnnConfig) -> None:
        super().__init__()

        assert len(config.mult) == len(config.down)
        encoder_layers = [nn.Conv2d(config.image_channels, config.num_channels, kernel_size=3, stride=1, padding=1)]
        input_channels = config.num_channels

        # One residual stage per `mult` entry; `down` flags a 2x downsample after the stage.
        for m, d in zip(config.mult, config.down):
            output_channels = m * config.num_channels
            encoder_layers.append(ResidualBlock(input_channels, output_channels))
            input_channels = output_channels
            if d:
                encoder_layers.append(Downsample(output_channels))
        encoder_layers.extend([
            nn.GroupNorm(num_groups=32, num_channels=input_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(input_channels, config.latent_dim, kernel_size=3, stride=1, padding=1)
        ])
        self.encoder = nn.Sequential(*encoder_layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        b, t, _, _, _ = x.size()
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.encoder(x)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=b, t=t)

        return x
47
+
48
+
49
class FrameDecoder(nn.Module):
    """Per-frame CNN decoder, mirroring FrameEncoder.

    Stages are built in encoder order and then reversed, so the decoder
    runs the channel/resolution schedule backwards: (B, T, latent_dim, h, w)
    back to (B, T, image_channels, H, W).
    """

    def __init__(self, config: FrameCnnConfig) -> None:
        super().__init__()

        assert len(config.mult) == len(config.down)
        decoder_layers = []
        output_channels = config.num_channels

        for m, d in zip(config.mult, config.down):
            input_channels = m * config.num_channels
            decoder_layers.append(ResidualBlock(input_channels, output_channels))
            output_channels = input_channels
            if d:
                decoder_layers.append(Upsample(input_channels))
        # Reverse into decoding order, then prepend the latent projection.
        decoder_layers.reverse()
        decoder_layers.insert(0, nn.Conv2d(config.latent_dim, input_channels, kernel_size=3, stride=1, padding=1))
        decoder_layers.extend([
            nn.GroupNorm(num_groups=32, num_channels=config.num_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(config.num_channels, config.image_channels, kernel_size=3, stride=1, padding=1)
        ])
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        b, t, _, _, _ = x.size()
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.decoder(x)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=b, t=t)
        return x
78
+
79
+
80
class ResidualBlock(nn.Module):
    """Pre-activation residual block: (GN -> SiLU -> 3x3 conv) twice, plus skip.

    When the channel counts differ, a 1x1 convolution projects the skip path.
    """

    def __init__(self, in_channels: int, out_channels: int, num_groups_norm: int = 32) -> None:
        super().__init__()

        layers = [
            nn.GroupNorm(num_groups_norm, in_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups_norm, out_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        ]
        self.f = nn.Sequential(*layers)
        if in_channels == out_channels:
            self.skip_projection = nn.Identity()
        else:
            self.skip_projection = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.f(x)
        return self.skip_projection(x) + residual
96
+
97
+
98
class Downsample(nn.Module):
    """Halve spatial resolution with a strided 2x2 convolution (channels unchanged)."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=2, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, C, H/2, W/2)
        return self.conv(x)
105
+
106
+
107
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor interpolation, then a 3x3 conv."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, C, 2H, 2W), smoothed by the conv.
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
delta-iris/src/models/kv_caching.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
class Cache:
    """Preallocated (B, T, E) buffer filled left-to-right, one chunk at a time.

    Writes go through AssignWithoutInplaceCheck so gradients flow to the
    written chunks without tripping autograd's in-place version check.
    """

    def __init__(self, num_samples: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._n, self._cache, self._size = num_samples, None, None
        # Factory kept as a lambda so reset() can rebuild at the current batch size.
        self._reset = lambda n: torch.empty(n, max_tokens, embed_dim, device=device)  # (B, T, E)
        self.reset()

    @property
    def shape(self) -> Tuple[int, int, int]:
        """(batch, filled length, embed dim) — the logical, not allocated, size."""
        n, _, embed_dim = self._cache.shape

        return n, self._size, embed_dim

    def reset(self) -> None:
        self._cache = self._reset(self._n)
        self._size = 0

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the batch rows selected by the boolean mask."""
        assert mask.ndim == 1 and mask.shape[0] == self.shape[0]
        self._cache = self._cache[mask]
        self._n = self._cache.shape[0]

    def get(self) -> torch.Tensor:
        # Only the filled prefix is meaningful; the tail is uninitialized memory.
        return self._cache[:, :self._size, :]

    def update(self, x: torch.Tensor) -> None:
        """Append x of shape (B, t, E) after the current content; must fit in max_tokens."""
        assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 2)])
        assert self._size + x.size(1) <= self._cache.shape[1]
        self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 1, self._size, self._size + x.size(1))
        self._size += x.size(1)
36
+
37
+
38
class KVCache:
    """Paired key/value caches that share an identical shape and lifecycle."""

    def __init__(self, n: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._k_cache = Cache(n, max_tokens, embed_dim, device)
        self._v_cache = Cache(n, max_tokens, embed_dim, device)

    @property
    def shape(self) -> Tuple[int, int, int]:
        # Keys and values always stay in sync, so either cache's shape works.
        return self._k_cache.shape

    def reset(self) -> None:
        for cache in (self._k_cache, self._v_cache):
            cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        for cache in (self._k_cache, self._v_cache):
            cache.prune(mask)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._k_cache.get(), self._v_cache.get()

    def update(self, k: torch.Tensor, v: torch.Tensor):
        self._k_cache.update(k)
        self._v_cache.update(v)
61
+
62
+
63
class KeysValues:
    """Per-layer KV caches for a transformer: one KVCache per layer."""

    def __init__(self, n: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
        self._keys_values = tuple([KVCache(n, max_tokens, embed_dim, device) for _ in range(num_layers)])

    def __getitem__(self, key: int) -> KVCache:
        return self._keys_values[key]

    def __len__(self):
        return len(self._keys_values)

    @property
    def size(self):
        # Number of tokens currently cached (identical across layers).
        return self._keys_values[0].shape[1]

    def reset(self) -> None:
        for kv_cache in self._keys_values:
            kv_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the batch rows selected by the mask, in every layer."""
        for kv_cache in self._keys_values:
            kv_cache.prune(mask)
84
+
85
+
86
class AssignWithoutInplaceCheck(torch.autograd.Function):
    """
    Writes `value` into a slice of `input` while bypassing autograd's in-place
    version check, yet still routing gradients to the written slice.

    Inspired from : https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
    Warning : do not use it to overwrite a slice twice.
    """

    @staticmethod
    def get_slice(dim: int, start: int, stop: int) -> Tuple[slice]:
        # Builds an indexer equivalent to [:, ..., start:stop] at dimension `dim`.
        return tuple([slice(None), ] * dim + [slice(start, stop)])

    @staticmethod
    def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
        ctx.dim = dim
        ctx.start = start
        ctx.stop = stop
        # Raw .data write: deliberately invisible to the version counter.
        input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value
        return input

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor]:
        # Grad w.r.t. `input` passes through unchanged; grad w.r.t. `value`
        # is the written slice; the three int args get no gradient.
        return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None
delta-iris/src/models/slicer.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
class Slicer(nn.Module):
    """Selects, in a sequence laid out as repeated fixed-size blocks, the
    positions whose intra-block offset is kept by `block_mask`."""

    def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
        super().__init__()
        self.block_size = block_mask.size(0)
        self.num_kept_tokens = block_mask.sum().long().item()
        # Absolute positions of every kept token across all max_blocks blocks.
        kept = torch.where(block_mask)[0].repeat(max_blocks)
        block_offsets = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
        self.register_buffer('indices', kept + self.block_size * block_offsets)

    def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
        """Kept positions within [prev_steps, prev_steps + num_steps), re-based to 0."""
        total_steps = num_steps + prev_steps
        num_blocks = math.ceil(total_steps / self.block_size)
        candidates = self.indices[:num_blocks * self.num_kept_tokens]
        in_window = (candidates >= prev_steps) & (candidates < total_steps)
        return candidates[in_window] - prev_steps

    def forward(self, *args, **kwargs):
        raise NotImplementedError
25
+
26
+
27
class Head(Slicer):
    """Applies `head_module` only at the timesteps its slicer keeps."""

    def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
        super().__init__(max_blocks, block_mask)
        assert isinstance(head_module, nn.Module)
        self.head_module = head_module

    def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
        # x is (B, T, E); select the kept timesteps before running the head.
        kept = self.compute_slice(num_steps, prev_steps)
        return self.head_module(x[:, kept])
36
+
37
+
38
class Embedder(nn.Module):
    """Embeds a token sequence in which different intra-block positions use
    different embedding tables, as selected by the block masks.

    The masks must partition each block: every position is covered by
    exactly one table. All tables must share one embedding dimension.
    """

    def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
        super().__init__()
        assert len(block_masks) == len(embedding_tables)
        assert (sum(block_masks) == 1).all()  # block masks are a partition of a block
        self.embedding_dim = embedding_tables[0].embedding_dim
        assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables])
        # nn.ModuleList (not plain lists) so the tables' parameters and the
        # slicers' index buffers are registered with this module: trained,
        # saved in state_dict, and moved by .to()/.cuda().
        self.embedding_tables = nn.ModuleList(embedding_tables)
        self.slicers = nn.ModuleList([Slicer(max_blocks, block_mask) for block_mask in block_masks])

    def forward(self, tokens: torch.LongTensor, num_steps: int, prev_steps: int) -> torch.FloatTensor:
        assert tokens.ndim == 2  # tokens is (B, T)
        output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
        # Each slicer picks the positions its table owns; together they cover all of T.
        for slicer, emb in zip(self.slicers, self.embedding_tables):
            s = slicer.compute_slice(num_steps, prev_steps)
            output[:, s] = emb(tokens[:, s])

        return output
delta-iris/src/models/tokenizer/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
# NOTE(review): the original used four leading dots, which climbs above the
# package root (src/models/tokenizer -> past src) and raises
# "attempted relative import beyond top-level package". Three dots resolve
# to the `src` package, where tokenizer.py lives — confirm against how the
# package is actually imported at runtime.
from ...tokenizer import Tokenizer, TokenizerConfig
delta-iris/src/models/tokenizer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (257 Bytes). View file
 
delta-iris/src/models/tokenizer/__pycache__/quantizer.cpython-310.pyc ADDED
Binary file (3.74 kB). View file
 
delta-iris/src/models/tokenizer/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (4.8 kB). View file
 
delta-iris/src/models/tokenizer/quantizer.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import math
3
+ from typing import Dict, Optional
4
+
5
+ from einops import rearrange
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
@dataclass
class QuantizerOutput:
    """Result of one Quantizer forward pass."""
    # Quantized embeddings, gradient-connected via straight-through.
    q: torch.FloatTensor
    # Codebook indices chosen for each input vector.
    tokens: torch.LongTensor
    loss: Dict[str, torch.FloatTensor]
    metrics: Dict[str, float]
17
+
18
+
19
class Quantizer(nn.Module):
    """Cosine-similarity vector quantizer with EMA codebook updates and
    dead-codeword revival.

    Inputs are projected to the codebook dimension and L2-normalized, so
    the nearest codeword is found by cosine similarity; gradients pass
    through quantization via the straight-through estimator.
    """

    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
        super().__init__()
        assert math.log2(codebook_size).is_integer()
        # Revival kicks in while codebook entropy is below log2(size) - 2 bits.
        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival
        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
        codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.register_buffer('num_codebook_updates', torch.tensor(0))
        self.register_buffer('codebook', codebook)
        # EMA estimate of per-codeword usage frequencies, initialized uniform.
        self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))

    def forward(self, z: torch.Tensor) -> QuantizerOutput:
        """Quantize z of shape (b, t, k, input_dim); returns same-shaped codes plus tokens/losses/metrics."""
        z = self.pre_quant_proj(z)
        z = F.normalize(z, dim=-1)
        b, k = z.size(0), z.size(2)
        z = rearrange(z, 'b t k e -> (b t k) e')

        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
        tokens = cosine_similarity.argmax(dim=-1)
        q = self.codebook[tokens]

        # Commitment term pulls encoder outputs toward their chosen codes (weight 0.02).
        losses = {'commitment_loss': 0.02 * (z - q.detach()).pow(2).mean()}

        if self.training:
            metrics = {**self.update_codebook(z, tokens), 'codebook_entropy': self.compute_codebook_entropy()}
        else:
            metrics = {}

        # Straight-through estimator: forward uses q, backward flows through z.
        q = z + (q - z).detach()
        q = self.post_quant_proj(q)

        q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
        tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)

        return QuantizerOutput(q, tokens, losses, metrics)

    @torch.no_grad()
    def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> Dict[str, int]:
        """EMA-update the codebook toward assigned inputs, track usage
        frequencies, and occasionally revive rarely-used codewords.
        """
        tokens_one_hot = F.one_hot(tokens, self.codebook.size(0)).float()  # (N, C)

        # Update codebook: per-codeword mean of assigned vectors, EMA decay 0.99.
        counts = tokens_one_hot.sum(dim=0)
        codebook_update = torch.einsum('n e, n c -> c e', z, tokens_one_hot) / torch.clamp(counts.unsqueeze(-1), min=1)
        codebook_update = F.normalize(codebook_update, dim=-1)
        self.codebook.lerp_(codebook_update, 1 - 0.99)

        # Update counts and revive dead codewords
        freqs = counts / tokens_one_hot.size(0)
        self.codewords_freqs.lerp_(freqs, 1 - 0.98)

        can_revive = (self.compute_codebook_entropy() < 1) or (self.max_codebook_updates_with_revival is None) or (self.num_codebook_updates.item() < self.max_codebook_updates_with_revival)
        if can_revive and (self.compute_codebook_entropy() < self.revival_entropy_threshold):
            # Codewords used at under a tenth of the uniform rate are replaced
            # by randomly chosen current inputs.
            expired = torch.where(self.codewords_freqs < 1 / (10 * self.codewords_freqs.size(0)))[0]
            num_expired = expired.size(0)
            expired = expired[torch.randperm(num_expired)[:z.size(0)]]
            idx_revived = torch.randperm(z.size(0), device=z.device)[:expired.size(0)]
            self.codebook[expired] = z[idx_revived]
            self.codewords_freqs[expired] = 1 / self.codewords_freqs.size(0)
        else:
            num_expired = 0

        # Re-normalize; this rebinds the registered 'codebook' buffer.
        self.codebook = F.normalize(self.codebook, dim=-1)

        self.num_codebook_updates += 1
        metrics = {'codewords_revived': num_expired}

        return metrics

    def compute_codebook_entropy(self) -> float:
        """Shannon entropy (bits) of the EMA codeword usage distribution."""
        probs = self.codewords_freqs[self.codewords_freqs != 0]
        return -(torch.log2(probs) * probs).sum().item()

    @torch.no_grad()
    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Look up codewords for tokens and project back to input_dim."""
        return self.post_quant_proj(self.codebook[tokens])
delta-iris/src/models/transformer.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inspired from https://github.com/karpathy/minGPT
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ from einops import rearrange
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .kv_caching import KeysValues, KVCache
13
+
14
+
15
@dataclass
class TransformerConfig:
    """Hyperparameters of the transformer encoder (minGPT-style)."""

    tokens_per_block: int  # tokens per block — presumably (frame, action, latents); confirm against the world model
    max_blocks: int        # maximum number of blocks per sequence

    num_layers: int
    num_heads: int
    embed_dim: int

    attention: str  # 'causal' or 'block_causal' (asserted in TransformerEncoder)

    embed_pdrop: float
    resid_pdrop: float
    attn_pdrop: float

    @property
    def max_tokens(self):
        """Maximum sequence length in tokens."""
        return self.tokens_per_block * self.max_blocks
34
+
35
+
36
class TransformerEncoder(nn.Module):
    """Stack of pre-LN transformer layers with learned positional embeddings
    and optional KV caching for incremental decoding.

    Supports 'causal' attention, or 'block_causal': causal across blocks but
    fully bidirectional within each block of `tokens_per_block` tokens.
    """

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.config = config

        self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim)
        self.emb_drop = nn.Dropout(config.embed_pdrop)
        self.ln = nn.LayerNorm(config.embed_dim)

        assert config.attention in ('causal', 'block_causal')
        k, m = config.tokens_per_block, config.max_blocks
        # Lower-triangular causal mask over the full token sequence.
        mask_sa = torch.tril(torch.ones(k * m, k * m))
        if config.attention == 'block_causal':
            # Additionally allow full attention inside each block (block-diagonal ones).
            mask_sa = torch.max(mask_sa, torch.block_diag(*[torch.ones(k, k) for _ in range(m)]))
        mask_sa = mask_sa.bool()

        self.blocks = nn.ModuleList([EncoderLayer(config, mask_sa) for _ in range(config.num_layers)])
        self.keys_values = None  # KV cache; allocated by reset_kv_cache

    @property
    def num_blocks_left_in_kv_cache(self) -> float:
        """Remaining KV-cache capacity expressed in blocks (may be fractional)."""
        assert self.keys_values is not None
        return (self.config.max_tokens - self.keys_values.size) / self.config.tokens_per_block

    def reset_kv_cache(self, n: int) -> None:
        """Allocate a fresh KV cache for batch size `n` on the module's device."""
        device = self.ln.weight.device
        self.keys_values = KeysValues(n, self.config.max_tokens, self.config.embed_dim, self.config.num_layers, device)

    def forward(self, x: torch.FloatTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
        """Encode a pre-embedded token sequence.

        x: (B, T, embed_dim). With use_kv_cache, positional embeddings are
        offset by the number of tokens already stored in the cache.
        """
        assert x.ndim == 3 and x.size(2) == self.config.embed_dim  # (B, TK, E)

        prev_steps = self.keys_values.size if use_kv_cache else 0
        inputs = x + self.pos_emb(prev_steps + torch.arange(x.size(1), device=x.device))

        y = self.emb_drop(inputs)
        for i, block in enumerate(self.blocks):
            y = block(y, self.keys_values[i] if use_kv_cache else None)
        y = self.ln(y)

        return y
76
+
77
+
78
class EncoderLayer(nn.Module):
    """One transformer layer: masked self-attention followed by a feed-forward MLP."""

    def __init__(self, config: TransformerConfig, mask_sa: torch.LongTensor) -> None:
        super().__init__()
        self.sa = SelfAttentionLayer(config, mask=mask_sa)
        self.mlp = MLPLayer(config)

    def forward(self, x: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
        attended = self.sa(x, kv_cache)
        return self.mlp(attended)
86
+
87
+
88
class MLPLayer(nn.Module):
    """Pre-LN position-wise feed-forward block with a residual connection."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config.embed_dim)
        hidden_dim = 4 * config.embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(config.embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, config.embed_dim),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
        residual = self.mlp(self.ln(inputs))
        return inputs + residual
101
+
102
+
103
class SelfAttentionLayer(nn.Module):
    """Pre-LN masked self-attention sub-layer with a residual connection.

    Projects inputs to q/k/v, optionally appends the new k/v to an external
    KV cache, and slices the attention mask according to the cached prefix.
    """

    def __init__(self, config: TransformerConfig, mask: torch.BoolTensor) -> None:
        super().__init__()
        # Buffer (not a parameter) so the mask follows the module across devices.
        self.register_buffer('mask', mask)
        self.ln = nn.LayerNorm(config.embed_dim)
        self.query = nn.Linear(config.embed_dim, config.embed_dim)
        self.key = nn.Linear(config.embed_dim, config.embed_dim)
        self.value = nn.Linear(config.embed_dim, config.embed_dim)
        self.attention = Attention(config)

    def forward(self, inputs: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
        B, T, C = inputs.size()
        if kv_cache is not None:
            # L = number of tokens already cached; the new k/v get appended after them.
            b, L, c = kv_cache.shape
            assert b == B and c == C
        else:
            L = 0

        x = self.ln(inputs)

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        if kv_cache is not None:
            kv_cache.update(k, v)
            k, v = kv_cache.get()

        # The T new queries attend over the cached prefix plus themselves: rows
        # [L, L+T) and columns [0, L+T) of the precomputed mask.
        y = inputs + self.attention(q, k, v, self.mask[L:L + T, :L + T])

        return y
134
+
135
+
136
class Attention(nn.Module):
    """Multi-head scaled dot-product attention over pre-projected q/k/v.

    Fix: attention dropout is now applied only in training mode. The previous
    code passed `dropout_p=self.attn_pdrop` unconditionally, and the functional
    `scaled_dot_product_attention` has no notion of the module's train/eval
    state, so attention weights kept being dropped at evaluation time.
    """

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        assert config.embed_dim % config.num_heads == 0
        self.num_heads = config.num_heads
        self.attn_pdrop = config.attn_pdrop
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        self.proj = nn.Linear(config.embed_dim, config.embed_dim)

    def forward(self, q: torch.FloatTensor, k: torch.FloatTensor, v: torch.FloatTensor, mask: torch.BoolTensor) -> torch.FloatTensor:
        """q: (B, Tq, E), k/v: (B, Tk, E), mask: (Tq, Tk). Returns (B, Tq, E)."""
        assert mask.size(0) == q.size(1) and mask.size(1) == k.size(1)

        q = rearrange(q, 'b q (h e) -> b h q e', h=self.num_heads)
        k = rearrange(k, 'b k (h e) -> b h k e', h=self.num_heads)
        v = rearrange(v, 'b k (h d) -> b h k d', h=self.num_heads)

        # Dropout on attention weights only while training.
        dropout_p = self.attn_pdrop if self.training else 0.0

        if q.size(2) != 0:
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=False)
        else:
            # Zero-length query: skip the kernel and return an empty tensor of the right shape.
            y = q.new_empty(*q.shape[:-1], v.size(-1))

        y = rearrange(y, 'b h q d -> b q (h d)')
        y = self.resid_drop(self.proj(y))

        return y
delta-iris/src/models/utils.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import cv2
3
+ from pathlib import Path
4
+ import random
5
+ import shutil
6
+ from typing import Callable, Dict
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from PIL import Image
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from torch.optim import AdamW
15
+
16
+ from data import Episode
17
+
18
+
19
def configure_optimizer(model: nn.Module, learning_rate: float, weight_decay: float, *blacklist_module_names) -> AdamW:
    """Credits to https://github.com/karpathy/minGPT

    Build an AdamW optimizer with two parameter groups: weights of Linear/Conv1d
    modules receive `weight_decay`, while biases, weights of LayerNorm /
    Embedding / Conv2d / GroupNorm modules, and every parameter whose full name
    starts with one of `blacklist_module_names` receive no decay.
    Raises AssertionError if any parameter is unclassified or doubly classified.
    """
    # separate out all parameters to those that will and won't experience regularizing weight decay
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, nn.Conv2d, nn.GroupNorm)
    for mn, m in model.named_modules():
        for pn, p in m.named_parameters():
            fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
            if any([fpn.startswith(module_name) for module_name in blacklist_module_names]):
                # explicitly blacklisted sub-module: never decay
                no_decay.add(fpn)
            elif 'bias' in pn:
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)

    # validate that we considered every parameter
    param_dict = {pn: p for pn, p in model.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
    assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"

    # create the pytorch optimizer object
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay},
        {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optim_groups, lr=learning_rate)

    return optimizer
56
+
57
+
58
def init_weights(module: nn.Module) -> None:
    """GPT-style weight init: N(0, 0.02) for Linear/Embedding weights, zero
    biases, and identity LayerNorm (unit weight, zero bias)."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
66
+
67
+
68
def extract_state_dict(state_dict: Dict, module_name: str) -> OrderedDict:
    """Return the sub-state-dict belonging to `module_name`, with the module
    prefix (up to the first '.') stripped from every key."""
    extracted = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(module_name):
            extracted[key.split('.', 1)[1]] = value
    return extracted
70
+
71
+
72
def set_seed(seed: int) -> None:
    """Seed every RNG used by the project: python `random`, numpy, and torch
    (CPU and CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
77
+
78
+
79
@torch.no_grad()
def compute_discounted_returns(rewards: torch.FloatTensor, gamma: float) -> torch.FloatTensor:
    """Discounted returns G_t = sum_{i >= t} gamma^(i-t) * r_i for each row.

    rewards: (B, T). Computed in O(T) via
    G_t = (r_t * gamma^t + total - cumsum_t) / gamma^t.

    Fix: the discount vector is created on `rewards`' device (it was previously
    always allocated on CPU, which crashed for CUDA inputs).
    """
    assert 0 < gamma <= 1 and rewards.ndim == 2  # (B, T)
    gammas = gamma ** torch.arange(rewards.size(1), device=rewards.device)
    r = rewards * gammas

    return (r + r.sum(dim=1, keepdim=True) - r.cumsum(dim=1)) / gammas
86
+
87
+
88
class LossWithIntermediateLosses:
    """Bundle a total loss (sum of all terms) with per-term python floats for logging."""

    def __init__(self, **kwargs) -> None:
        self.loss_total = sum(kwargs.values())
        self.intermediate_losses = {name: value.item() for name, value in kwargs.items()}
92
+
93
+
94
class EpisodeDirManager:
    """Persists episodes to `episode_dir`, keeping at most `max_num_episodes`
    recent 'episode_*' files plus a single copy of the best-return episode."""

    def __init__(self, episode_dir: Path, max_num_episodes: int) -> None:
        self.episode_dir = episode_dir
        self.episode_dir.mkdir(parents=False, exist_ok=True)
        self.max_num_episodes = max_num_episodes
        # Best episode return seen so far; used to decide when to refresh 'best_*'.
        self.best_return = float('-inf')

    def save(self, episode: Episode, episode_id: int, epoch: int) -> None:
        """Save the episode, unless saving is disabled (max_num_episodes None or 0)."""
        if self.max_num_episodes is not None and self.max_num_episodes > 0:
            self._save(episode, episode_id, epoch)

    def _save(self, episode: Episode, episode_id: int, epoch: int) -> None:
        ep_paths = [p for p in self.episode_dir.iterdir() if p.stem.startswith('episode_')]
        assert len(ep_paths) <= self.max_num_episodes
        if len(ep_paths) == self.max_num_episodes:
            # Evict the file with the smallest episode id (oldest episode).
            to_remove = min(ep_paths, key=lambda ep_path: int(ep_path.stem.split('_')[1]))
            to_remove.unlink()
        torch.save(episode.__dict__, self.episode_dir / f'episode_{episode_id}_epoch_{epoch}.pt')

        ep_return = episode.compute_metrics().episode_return
        if ep_return > self.best_return:
            self.best_return = ep_return
            # Keep at most one 'best_*' file: delete the previous best before saving.
            path_best_ep = [p for p in self.episode_dir.iterdir() if p.stem.startswith('best_')]
            assert len(path_best_ep) in (0, 1)
            if len(path_best_ep) == 1:
                path_best_ep[0].unlink()
            torch.save(episode.__dict__, self.episode_dir / f'best_episode_{episode_id}_epoch_{epoch}.pt')
121
+
122
+
123
class RandomHeuristic:
    """Baseline policy that picks uniformly random actions, ignoring the observations."""

    def __init__(self, num_actions):
        self.num_actions = num_actions

    def act(self, obs):
        assert obs.ndim == 4  # (N, H, W, C)
        batch_size = obs.size(0)
        return torch.randint(low=0, high=self.num_actions, size=(batch_size,))
132
+
133
+
134
def make_video(fname, fps, frames):
    """Write `frames` to an mp4 file at `fname`.

    frames: (T, H, W, C) array of RGB frames — presumably uint8, as required by
    cv2.VideoWriter; confirm at call sites. Channels are reversed to BGR, the
    order OpenCV expects.
    """
    assert frames.ndim == 4  # (T, H, W, C)
    _, h, w, c = frames.shape
    assert c == 3

    video = cv2.VideoWriter(str(fname), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for frame in frames:
        video.write(frame[:, :, ::-1])  # RGB -> BGR
    video.release()
143
+
144
+
145
def try_until_no_except(fn: Callable):
    """Call `fn` repeatedly until one invocation completes without raising.

    Fix: catch `Exception` instead of a bare `except:`, so KeyboardInterrupt
    and SystemExit still propagate and the retry loop can be interrupted.
    Note: a `fn` that always fails still loops forever, as before.
    """
    while True:
        try:
            fn()
        except Exception:
            continue
        else:
            break
153
+
154
+
155
def symlog(x: torch.Tensor) -> torch.Tensor:
    """Symmetric log transform: sign(x) * log(1 + |x|). Odd function, identity-like near 0."""
    return x.sign() * torch.log(x.abs() + 1)
157
+
158
+
159
+ def symexp(x: torch.Tensor) -> torch.Tensor:
160
+ return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
161
+
162
+
163
def two_hot(x: torch.FloatTensor, x_min: int = -20, x_max: int = 20, num_buckets: int = 255) -> torch.FloatTensor:
    """Two-hot encode values onto a uniform bucket grid over [x_min, x_max].

    Each value's probability mass is split between its two neighbouring buckets,
    proportionally to proximity. Returns shape (*x.shape, num_buckets); rows sum to 1.

    Fixes:
    - no longer clamps `x` in place (the previous `x.clamp_` silently mutated
      the caller's tensor);
    - the left bucket index is clamped to 0, so a value exactly equal to x_min
      no longer produces index -1 and crashes the scatter.
    """
    x = x.clamp(x_min, x_max - 1e-5)
    buckets = torch.linspace(x_min, x_max, num_buckets).to(x.device)
    # Index of the bucket immediately at or below each value.
    k = torch.clamp(torch.searchsorted(buckets, x) - 1, min=0)
    values = torch.stack((buckets[k + 1] - x, x - buckets[k]), dim=-1) / (buckets[k + 1] - buckets[k]).unsqueeze(-1)
    two_hots = torch.scatter(x.new_zeros(*x.size(), num_buckets), dim=-1, index=torch.stack((k, k + 1), dim=-1), src=values)

    return two_hots
171
+
172
+
173
def compute_softmax_over_buckets(logits: torch.FloatTensor, x_min: int = -20, x_max: int = 20, num_buckets: int = 255) -> torch.FloatTensor:
    """Expected bucket value under the softmax distribution defined by `logits`."""
    bucket_values = torch.linspace(x_min, x_max, num_buckets).to(logits.device)
    weights = F.softmax(logits, dim=-1)
    return weights @ bucket_values
178
+
179
+
180
def plot_counts(counts: np.ndarray) -> Image:
    """Render a line plot of `counts` and return it as a PIL image.

    Goes through a temporary 'priorities.png' in the working directory, which
    is deleted after loading. NOTE(review): PIL may lazy-load file contents —
    confirm the image remains usable after unlink on all platforms.
    """
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(counts)
    p = Path('priorities.png')
    fig.savefig(p)
    plt.close(fig)
    im = Image.open(p)
    p.unlink()

    return im
190
+
191
+
192
def compute_mask_after_first_done(ends: torch.LongTensor) -> torch.BoolTensor:
    """Boolean mask that is True up to and including the first 1 in each row of
    `ends` ((B, T) of 0/1 flags), and False afterwards. Rows with no 1 are all True."""
    assert ends.ndim == 2
    num_steps = ends.size(1)
    first_done = torch.argmax(ends, dim=1, keepdim=True)
    positions = torch.arange(num_steps, device=ends.device).unsqueeze(0)
    before_or_at_done = positions <= first_done
    no_done_at_all = ends.sum(dim=1, keepdim=True) == 0
    return before_or_at_done | no_done_at_all
delta-iris/src/tokenizer.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import math
3
+ from typing import Dict, Tuple
4
+
5
+ from einops import rearrange
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .models.convnet import FrameCnnConfig, FrameEncoder, FrameDecoder
10
+ from .data import Batch
11
+ from .models.tokenizer.quantizer import Quantizer, QuantizerOutput
12
+ from .models.utils import init_weights, LossWithIntermediateLosses
13
+
14
+
15
@dataclass
class TokenizerConfig:
    """Hyperparameters of the Tokenizer (frame encoder, vector quantizer, decoder)."""
    image_channels: int
    image_size: int
    num_actions: int
    num_tokens: int  # tokens per transition; assumed a perfect square (sqrt(num_tokens)² grid) — see Tokenizer.__init__
    decoder_act_channels: int  # channels of the action plane concatenated into the decoder input
    codebook_size: int
    codebook_dim: int
    max_codebook_updates_with_revival: int
    encoder_config: FrameCnnConfig
    decoder_config: FrameCnnConfig
    frame_cnn_config: FrameCnnConfig  # separate CNN that embeds x1 for the decoder
28
+
29
+
30
class Tokenizer(nn.Module):
    """Delta tokenizer: encodes a transition (x1, a, x2) into discrete tokens
    and decodes those tokens (plus x1 and the action) back into a
    reconstruction of x2.
    """

    def __init__(self, config: TokenizerConfig) -> None:
        super().__init__()
        self.config = config

        # Spatial resolution of the encoder latent, and how it tiles into tokens.
        self.latent_res = config.image_size // 2 ** sum(config.encoder_config.down)
        self.tokens_grid_res = int(math.sqrt(config.num_tokens))
        self.token_res = self.latent_res // self.tokens_grid_res

        # Action embeddings reshaped into a single image-sized plane (encoder)
        # and a multi-channel latent-sized plane (decoder).
        self.encoder_act_emb = nn.Embedding(config.num_actions, config.image_size ** 2)
        self.decoder_act_emb = nn.Embedding(config.num_actions, config.decoder_act_channels * self.latent_res ** 2)

        self.quantizer = Quantizer(
            config.codebook_size, config.codebook_dim,
            input_dim=config.encoder_config.latent_dim * self.token_res ** 2,
            max_codebook_updates_with_revival=config.max_codebook_updates_with_revival
        )

        self.encoder = FrameEncoder(config.encoder_config)
        self.decoder = FrameDecoder(config.decoder_config)
        self.frame_cnn = FrameEncoder(config.frame_cnn_config)

        self.apply(init_weights)

    def __repr__(self) -> str:
        return "tokenizer"

    def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> QuantizerOutput:
        """Encode the transition, flatten each spatial patch into a token vector, quantize."""
        z = self.encode(x1, a, x2)
        # Tile the (C, H, W) latent into a tokens_grid_res² grid of patches,
        # one flattened vector per token.
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)

        return self.quantizer(z)

    def compute_loss(self, batch: Batch, **kwargs) -> Tuple[LossWithIntermediateLosses, Dict]:
        """Quantization + reconstruction losses over consecutive frame pairs of the batch."""
        x1 = batch.observations[:, :-1]
        a = batch.actions[:, :-1]
        x2 = batch.observations[:, 1:]

        quantizer_outputs = self(x1, a, x2)

        r = self.decode(x1, a, rearrange(quantizer_outputs.q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res))
        delta = (x2 - r)
        # Keep only timesteps where both frames of the pair are valid (non-padding).
        delta = delta[torch.logical_and(batch.mask_padding[:, 1:], batch.mask_padding[:, :-1])]

        losses = {
            **quantizer_outputs.loss,
            'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
            'reconstruction_loss_l2': delta.pow(2).mean(),
            # Penalize the single worst pixel error per frame to discourage localized artifacts.
            'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
        }

        return LossWithIntermediateLosses(**losses), quantizer_outputs.metrics

    def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
        """Encode both frames plus a single action plane, stacked along channels."""
        a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
        encoder_input = torch.cat((x1, a_emb, x2), dim=2)
        z = self.encoder(encoder_input)

        return z

    def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
        """Reconstruct x2 from x1 features, the action plane, and quantized latents q2.

        When should_clamp is True, the output is clamped to [0, 1] and snapped
        to the 8-bit grid (multiples of 1/255).
        """
        x1_emb = self.frame_cnn(x1)
        a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config.decoder_act_channels, h=x1_emb.size(3))

        decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)

        r = self.decoder(decoder_input)
        r = torch.clamp(r, 0, 1).mul(255).round().div(255) if should_clamp else r

        return r

    @torch.no_grad()
    def encode_decode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.Tensor:
        """Full round trip (encode, quantize, decode) with clamped 8-bit output."""
        z = self.encode(x1, a, x2)
        z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', k=self.token_res, l=self.token_res)
        q = rearrange(self.quantizer(z).q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
        r = self.decode(x1, a, q, should_clamp=True)

        return r

    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
        """Tokenize every consecutive transition of a trajectory (obs has one more step than act)."""
        assert obs.size(1) == act.size(1) + 1
        quantizer_output = self(obs[:, :-1], act, obs[:, 1:])

        return quantizer_output.tokens
delta-iris/src/world_model.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from einops import rearrange, repeat
4
+ from einops.layers.torch import Rearrange
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .models.convnet import FrameCnnConfig, FrameEncoder
10
+ from .data import Batch
11
+ from .models.slicer import Head
12
+ from .tokenizer import Tokenizer
13
+ from .models.transformer import TransformerEncoder, TransformerConfig
14
+ from .models.utils import init_weights, LossWithIntermediateLosses, symlog, two_hot
15
+
16
+
17
@dataclass
class WorldModelOutput:
    """Transformer features plus prediction-head logits for one forward pass."""
    output_sequence: torch.FloatTensor  # transformer hidden states, (B, num_tokens, embed_dim)
    logits_latents: torch.FloatTensor   # next-latent-token logits
    logits_rewards: torch.FloatTensor   # reward logits (two-hot buckets or 3-way sign classes)
    logits_ends: torch.FloatTensor      # binary episode-termination logits
23
+
24
+
25
@dataclass
class WorldModelConfig:
    """Hyperparameters of the world model and its prediction heads."""
    latent_vocab_size: int  # size of the tokenizer codebook (latent token vocabulary)
    num_actions: int
    image_channels: int
    image_size: int
    latents_weight: float   # loss weight for latent-token prediction
    rewards_weight: float   # loss weight for reward prediction
    ends_weight: float      # loss weight for termination prediction
    two_hot_rews: bool      # 255-bucket two-hot reward targets vs 3-way sign classes
    transformer_config: TransformerConfig
    frame_cnn_config: FrameCnnConfig
37
+
38
+
39
class WorldModel(nn.Module):
    """Transformer world model over interleaved (frame, action, latent tokens)
    blocks. Predicts the next latent tokens, the reward, and episode termination.
    """

    def __init__(self, config: WorldModelConfig) -> None:
        super().__init__()
        self.config = config
        self.transformer = TransformerEncoder(config.transformer_config)

        # The flattened frame feature map must exactly fill one transformer token.
        assert ((config.image_size // 2 ** sum(config.frame_cnn_config.down)) ** 2) * config.frame_cnn_config.latent_dim == config.transformer_config.embed_dim
        self.frame_cnn = nn.Sequential(FrameEncoder(config.frame_cnn_config), Rearrange('b t c h w -> b t 1 (h w c)'), nn.LayerNorm(config.transformer_config.embed_dim))

        self.act_emb = nn.Embedding(config.num_actions, config.transformer_config.embed_dim)
        self.latents_emb = nn.Embedding(config.latent_vocab_size, config.transformer_config.embed_dim)

        # Block layout is [frame, action, latent_1 .. latent_{K}]: position 1 is
        # the action token, positions 1..-1 are action plus all latents but the last.
        act_pattern = torch.zeros(config.transformer_config.tokens_per_block)
        act_pattern[1] = 1
        act_and_latents_but_last_pattern = torch.zeros(config.transformer_config.tokens_per_block)
        act_and_latents_but_last_pattern[1:-1] = 1

        # Latent head: predicts each latent token from the action and the preceding latents.
        self.head_latents = Head(
            max_blocks=config.transformer_config.max_blocks,
            block_mask=act_and_latents_but_last_pattern,
            head_module=nn.Sequential(
                nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
                nn.Linear(config.transformer_config.embed_dim, config.latent_vocab_size)
            )
        )

        # Reward head: 255 two-hot buckets, or 3 classes for reward sign {-1, 0, +1}.
        self.head_rewards = Head(
            max_blocks=config.transformer_config.max_blocks,
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
                nn.Linear(config.transformer_config.embed_dim, 255 if config.two_hot_rews else 3)
            )
        )

        # Termination head: binary done / not-done classification.
        self.head_ends = Head(
            max_blocks=config.transformer_config.max_blocks,
            block_mask=act_pattern,
            head_module=nn.Sequential(
                nn.Linear(config.transformer_config.embed_dim, config.transformer_config.embed_dim), nn.ReLU(),
                nn.Linear(config.transformer_config.embed_dim, 2)
            )
        )

        self.apply(init_weights)

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> WorldModelOutput:
        """Run the transformer and every prediction head on an embedded token sequence."""
        prev_steps = self.transformer.keys_values.size if use_kv_cache else 0
        num_steps = sequence.size(1)

        outputs = self.transformer(sequence, use_kv_cache=use_kv_cache)

        logits_latents = self.head_latents(outputs, num_steps, prev_steps)
        logits_rewards = self.head_rewards(outputs, num_steps, prev_steps)
        logits_ends = self.head_ends(outputs, num_steps, prev_steps)

        return WorldModelOutput(outputs, logits_latents, logits_rewards, logits_ends)

    def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs) -> tuple:
        """Cross-entropy losses for latents, rewards, and ends over a padded batch.

        Latent targets come from the (gradient-free) tokenizer applied to
        consecutive frame pairs. Returns (LossWithIntermediateLosses, metrics dict);
        annotation fixed — the original claimed a single LossWithIntermediateLosses.
        """
        assert torch.all(batch.ends.sum(dim=1) <= 1)  # at most one terminal step per row

        with torch.no_grad():
            latent_tokens = tokenizer(batch.observations[:, :-1], batch.actions[:, :-1], batch.observations[:, 1:]).tokens

        b, _, k = latent_tokens.size()

        frames_emb = self.frame_cnn(batch.observations)
        act_tokens_emb = self.act_emb(rearrange(batch.actions, 'b t -> b t 1'))
        # The last step has no next frame: pad its latent slots with token 0.
        latent_tokens_emb = self.latents_emb(torch.cat((latent_tokens, latent_tokens.new_zeros(b, 1, k)), dim=1))
        sequence = rearrange(torch.cat((frames_emb, act_tokens_emb, latent_tokens_emb), dim=2), 'b t p1k e -> b (t p1k) e')

        outputs = self(sequence)

        mask = batch.mask_padding

        labels_latents = latent_tokens[mask[:, :-1]].flatten()
        # Drop the last block's latent predictions (no ground-truth tokens for them).
        logits_latents = outputs.logits_latents[:, :-k][repeat(mask[:, :-1], 'b t -> b (t k)', k=k)]
        latent_acc = (logits_latents.max(dim=-1)[1] == labels_latents).float().mean()
        labels_rewards = two_hot(symlog(batch.rewards)) if self.config.two_hot_rews else (batch.rewards.sign() + 1).long()

        loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config.latents_weight
        loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config.rewards_weight
        loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config.ends_weight

        return LossWithIntermediateLosses(loss_latents=loss_latents, loss_rewards=loss_rewards, loss_ends=loss_ends), {'latent_accuracy': latent_acc}

    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor, latent_tokens: torch.LongTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
        """Feed an observed trajectory through the model (filling the KV cache
        when requested) and return the raw transformer output sequence.

        obs has one more step than act and latent_tokens; the final frame is
        appended after the last (frame, action, latents) block.
        """
        assert obs.size(1) == act.size(1) + 1 == latent_tokens.size(1) + 1

        x_emb = self.frame_cnn(obs)
        act_emb = rearrange(self.act_emb(act), 'b t e -> b t 1 e')
        q_emb = self.latents_emb(latent_tokens)
        x_a_q = rearrange(torch.cat((x_emb[:, :-1], act_emb, q_emb), dim=2), 'b t k2 e -> b (t k2) e')
        wm_input_sequence = torch.cat((x_a_q, x_emb[:, -1]), dim=1)
        wm_output_sequence = self(wm_input_sequence, use_kv_cache=use_kv_cache).output_sequence

        return wm_output_sequence