Upload 35 files
Browse files- delta-iris/src/data/__init__.py +7 -0
- delta-iris/src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/batch.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/episode.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/episode_count.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/sampler.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/segment.cpython-310.pyc +0 -0
- delta-iris/src/data/__pycache__/utils.cpython-310.pyc +0 -0
- delta-iris/src/data/batch.py +24 -0
- delta-iris/src/data/dataset.py +104 -0
- delta-iris/src/data/episode.py +41 -0
- delta-iris/src/data/episode_count.py +41 -0
- delta-iris/src/data/sampler.py +42 -0
- delta-iris/src/data/segment.py +25 -0
- delta-iris/src/data/utils.py +69 -0
- delta-iris/src/models/__init__.py +1 -0
- delta-iris/src/models/__pycache__/__init__.cpython-310.pyc +0 -0
- delta-iris/src/models/__pycache__/convnet.cpython-310.pyc +0 -0
- delta-iris/src/models/__pycache__/kv_caching.cpython-310.pyc +0 -0
- delta-iris/src/models/__pycache__/slicer.cpython-310.pyc +0 -0
- delta-iris/src/models/__pycache__/transformer.cpython-310.pyc +0 -0
- delta-iris/src/models/__pycache__/world_model.cpython-310.pyc +0 -0
- delta-iris/src/models/convnet.py +114 -0
- delta-iris/src/models/kv_caching.py +106 -0
- delta-iris/src/models/slicer.py +55 -0
- delta-iris/src/models/tokenizer/__init__.py +1 -0
- delta-iris/src/models/tokenizer/__pycache__/__init__.cpython-310.pyc +0 -0
- delta-iris/src/models/tokenizer/__pycache__/quantizer.cpython-310.pyc +0 -0
- delta-iris/src/models/tokenizer/__pycache__/tokenizer.cpython-310.pyc +0 -0
- delta-iris/src/models/tokenizer/quantizer.py +95 -0
- delta-iris/src/models/transformer.py +157 -0
- delta-iris/src/models/utils.py +198 -0
- delta-iris/src/tokenizer.py +115 -0
- delta-iris/src/world_model.py +139 -0
delta-iris/src/data/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .batch import Batch
|
| 2 |
+
from .dataset import EpisodeDataset
|
| 3 |
+
from .episode import Episode
|
| 4 |
+
from .episode_count import EpisodeCountManager
|
| 5 |
+
from .sampler import BatchSampler
|
| 6 |
+
from .segment import SegmentId
|
| 7 |
+
from .utils import collate_segments_to_batch, DatasetTraverser, make_segment
|
delta-iris/src/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (543 Bytes). View file
|
|
|
delta-iris/src/data/__pycache__/batch.cpython-310.pyc
ADDED
|
Binary file (1.46 kB). View file
|
|
|
delta-iris/src/data/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (4.9 kB). View file
|
|
|
delta-iris/src/data/__pycache__/episode.cpython-310.pyc
ADDED
|
Binary file (1.8 kB). View file
|
|
|
delta-iris/src/data/__pycache__/episode_count.cpython-310.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
delta-iris/src/data/__pycache__/sampler.cpython-310.pyc
ADDED
|
Binary file (1.96 kB). View file
|
|
|
delta-iris/src/data/__pycache__/segment.cpython-310.pyc
ADDED
|
Binary file (1.06 kB). View file
|
|
|
delta-iris/src/data/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (3.98 kB). View file
|
|
|
delta-iris/src/data/batch.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .segment import SegmentId
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class Batch:
    """Collated training batch: stacked tensors for a list of segments, plus their ids."""
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    mask_padding: torch.BoolTensor
    segment_ids: List[SegmentId]

    def _apply_to_tensors(self, fn) -> Batch:
        # Apply fn to every tensor field; segment_ids is plain metadata and passes through.
        new_fields = {name: (value if name == 'segment_ids' else fn(value)) for name, value in self.__dict__.items()}
        return Batch(**new_fields)

    def pin_memory(self) -> Batch:
        """Return a batch whose tensors live in page-locked memory (fast host-to-GPU copies)."""
        return self._apply_to_tensors(lambda tensor: tensor.pin_memory())

    def to(self, device: torch.device) -> Batch:
        """Return a copy of the batch with every tensor field moved to `device`."""
        return self._apply_to_tensors(lambda tensor: tensor.to(device))
|
| 24 |
+
|
delta-iris/src/data/dataset.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import shutil
|
| 3 |
+
from typing import Dict, Optional, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from .episode import Episode
|
| 9 |
+
from .segment import Segment, SegmentId
|
| 10 |
+
from .utils import make_segment
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class EpisodeDataset(torch.utils.data.Dataset):
    """On-disk episode store, indexed by SegmentId.

    Each episode is saved as its own .pt file under a 3-level directory hierarchy.
    Bookkeeping arrays (`start_idx`, `lengths`) map between global step indices and
    per-episode timesteps, and are persisted to `info.pt` via `save_info`.
    """

    def __init__(self, directory: Path, name: str) -> None:
        super().__init__()
        self.name = name
        self.directory = Path(directory)
        self.num_episodes, self.num_steps, self.start_idx, self.lengths = None, None, None, None

        # Fresh directory -> start empty; otherwise restore bookkeeping from info.pt.
        if not self.directory.is_dir():
            self._init_empty()
        else:
            self._load_info()
        print(f'({name}) {self.num_episodes} episodes, {self.num_steps} steps.')

    @property
    def info_path(self) -> Path:
        """Path of the persisted bookkeeping file."""
        return self.directory / 'info.pt'

    @property
    def info(self) -> Dict[str, Union[int, np.ndarray]]:
        """Serializable snapshot of the dataset bookkeeping."""
        return {'num_episodes': self.num_episodes, 'num_steps': self.num_steps, 'start_idx': self.start_idx, 'lengths': self.lengths}

    def __len__(self) -> int:
        # Dataset length is counted in steps, not episodes.
        return self.num_steps

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        return self._load_segment(segment_id)

    def _init_empty(self) -> None:
        """Create the directory and write empty bookkeeping."""
        self.directory.mkdir(parents=True, exist_ok=False)
        self.num_episodes = 0
        self.num_steps = 0
        self.start_idx = np.array([], dtype=np.int64)
        self.lengths = np.array([], dtype=np.int64)
        self.save_info()

    def _load_info(self) -> None:
        """Restore bookkeeping from info.pt."""
        # NOTE(review): newer torch versions default torch.load(weights_only=True),
        # which may reject the numpy arrays stored here — confirm the torch version in use.
        info = torch.load(self.info_path)
        self.num_steps = info['num_steps']
        self.num_episodes = info['num_episodes']
        self.start_idx = info['start_idx']
        self.lengths = info['lengths']

    def save_info(self) -> None:
        """Persist the current bookkeeping to info.pt."""
        torch.save(self.info, self.info_path)

    def clear(self) -> None:
        """Delete every stored episode and reset to an empty dataset."""
        shutil.rmtree(self.directory)
        self._init_empty()

    def _get_episode_path(self, episode_id: int) -> Path:
        """Path of an episode file, sharded by decimal digits to keep directories small.

        E.g. episode 1234 maps to <directory>/200/30/4/1234.pt: each level holds the id
        rounded down to hundreds / tens / units, zero-padded to a fixed width.
        """
        n = 3  # number of hierarchies
        powers = np.arange(n)
        subfolders = list(map(int, np.floor((episode_id % 10 ** (1 + powers)) / 10 ** powers) * 10 ** powers))[::-1]
        return self.directory / '/'.join(list(map(lambda x: f'{x[1]:0{n - x[0]}d}', enumerate(subfolders)))) / f'{episode_id}.pt'

    def _load_segment(self, segment_id: SegmentId, should_pad: bool = True) -> Segment:
        """Load the episode referenced by the id and cut/pad the requested window."""
        episode = self.load_episode(segment_id.episode_id)
        return make_segment(episode, segment_id, should_pad)

    def load_episode(self, episode_id: int) -> Episode:
        """Read one episode file and rebuild the Episode dataclass from its fields."""
        return Episode(**torch.load(self._get_episode_path(episode_id)))

    def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
        """Append a new episode, or merge new steps into an existing one, then persist it.

        Returns the id of the stored episode. The write goes through a .tmp file renamed
        into place, so a crash cannot leave a half-written episode file.
        """
        if episode_id is None:
            # Brand-new episode: extend the bookkeeping arrays.
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
            self.num_steps += len(episode)
            self.num_episodes += 1

        else:
            # Existing episode: append the new steps and shift the start of later episodes.
            assert episode_id < self.num_episodes
            old_episode = self.load_episode(episode_id)
            episode = old_episode.merge(episode)
            incr_num_steps = len(episode) - len(old_episode)
            self.lengths[episode_id] = len(episode)
            self.start_idx[episode_id + 1:] += incr_num_steps
            self.num_steps += incr_num_steps

        episode_path = self._get_episode_path(episode_id)
        episode_path.parent.mkdir(parents=True, exist_ok=True)
        # Atomic-ish write: save to .tmp, then rename over the final .pt path.
        torch.save(episode.__dict__, episode_path.with_suffix('.tmp'))
        episode_path.with_suffix('.tmp').rename(episode_path)

        return episode_id

    def get_episode_id_from_global_idx(self, global_idx: np.ndarray) -> np.ndarray:
        # First episode whose start lies strictly beyond the index, minus one. argmax
        # returns 0 when no start exceeds the index (i.e. the last episode), which the
        # modulo maps to num_episodes - 1.
        return (np.argmax(self.start_idx.reshape(-1, 1) > global_idx, axis=0) - 1) % self.num_episodes

    def get_global_idx_from_segment_id(self, segment_id: SegmentId) -> np.ndarray:
        """Global step indices covered by the segment (episode offset + [start, stop))."""
        start_idx = self.start_idx[segment_id.episode_id]
        return np.arange(start_idx + segment_id.start, start_idx + segment_id.stop)
|
delta-iris/src/data/episode.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class EpisodeMetrics:
    """Summary statistics of one episode."""
    # episode_length: number of steps in the episode.
    # episode_return: sum of rewards over the episode.
    episode_length: int
    episode_return: float
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
class Episode:
    """A stored trajectory; on construction, everything after the first end flag is dropped."""
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor

    def __post_init__(self):
        # All time-indexed fields must agree on the number of steps.
        assert len(self.observations) == len(self.actions) == len(self.rewards) == len(self.ends)
        if self.ends.sum() > 0:
            # Keep steps up to and including the first end marker.
            cutoff = int(torch.argmax(self.ends)) + 1
            for name in ('observations', 'actions', 'rewards', 'ends'):
                setattr(self, name, getattr(self, name)[:cutoff])

    def __len__(self) -> int:
        return self.observations.size(0)

    def merge(self, other: Episode) -> Episode:
        """Concatenate `other` after this episode (the result is re-truncated by __post_init__)."""
        merged_fields = (
            torch.cat((getattr(self, name), getattr(other, name)), dim=0)
            for name in ('observations', 'actions', 'rewards', 'ends')
        )
        return Episode(*merged_fields)

    def compute_metrics(self) -> EpisodeMetrics:
        """Length and undiscounted return of the episode."""
        return EpisodeMetrics(episode_length=len(self), episode_return=self.rewards.sum())
|
delta-iris/src/data/episode_count.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .dataset import EpisodeDataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class EpisodeCountManager:
    """Tracks, per registered key, how many times each episode of a dataset has been sampled."""

    def __init__(self, dataset: EpisodeDataset) -> None:
        self.dataset = dataset
        self.all_counts = {}

    def load(self, path_to_checkpoint: Path) -> None:
        """Restore all counters from a checkpoint; they must match the dataset size."""
        self.all_counts = torch.load(path_to_checkpoint)
        assert all(counts.shape[0] == self.dataset.num_episodes for counts in self.all_counts.values())

    def save(self, path_to_checkpoint: Path) -> None:
        """Persist all counters."""
        torch.save(self.all_counts, path_to_checkpoint)

    def register(self, *keys: Tuple[str]) -> None:
        """Create a fresh zero counter for each (previously unseen) key."""
        for key in keys:
            assert key not in self.all_counts
        for key in keys:
            self.all_counts[key] = np.zeros(self.dataset.num_episodes, dtype=np.int64)

    def add_episode(self, episode_id: int) -> None:
        """Grow every counter by one zeroed slot when a brand-new episode is appended."""
        for key in list(self.all_counts):
            counts = self.all_counts[key]
            assert episode_id <= counts.shape[0]
            if episode_id == counts.shape[0]:
                # New episode id: append a zero count for it.
                self.all_counts[key] = np.concatenate((counts, np.zeros(1, dtype=np.int64)))
            assert self.all_counts[key].shape[0] == self.dataset.num_episodes

    def increment_episode_count(self, key: str, episode_id: int) -> None:
        """Record one more draw of `episode_id` under `key`."""
        assert key in self.all_counts
        self.all_counts[key][episode_id] += 1

    def compute_probabilities(self, key: str, alpha: float) -> np.ndarray:
        """Sampling distribution favoring rarely-seen episodes: p_i ∝ (1 / (1 + count_i))^alpha."""
        assert key in self.all_counts
        inverse_counts = 1 / (1 + self.all_counts[key])
        weights = inverse_counts ** alpha
        return weights / weights.sum()
|
delta-iris/src/data/sampler.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Generator, List
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .dataset import EpisodeDataset
|
| 7 |
+
from .segment import SegmentId
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BatchSampler(torch.utils.data.Sampler):
    """Yields batches of SegmentIds, optionally weighted over episodes via `probabilities`."""

    def __init__(self, dataset: EpisodeDataset, num_steps_per_epoch: int, batch_size: int, sequence_length: int, can_sample_beyond_end: bool) -> None:
        super().__init__(dataset)
        self.dataset = dataset
        self.probabilities = None  # None -> uniform; callers may set an array over episodes
        self.num_steps_per_epoch = num_steps_per_epoch
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.can_sample_beyond_end = can_sample_beyond_end

    def __len__(self) -> int:
        return self.num_steps_per_epoch

    def __iter__(self) -> Generator[List[SegmentId], None, None]:
        yield from (self.sample() for _ in range(self.num_steps_per_epoch))

    def sample(self) -> List[SegmentId]:
        """Draw one batch: pick episodes, then a random fixed-length window around a random timestep."""
        episode_ids = np.random.choice(np.arange(self.dataset.num_episodes), size=self.batch_size, replace=True, p=self.probabilities)
        timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])

        if self.can_sample_beyond_end:
            # Window may spill over both edges of the episode; padding happens downstream.
            starts = timesteps - np.random.randint(0, self.sequence_length, len(timesteps))
            stops = starts + self.sequence_length
        else:
            # Clamp the right edge to the episode length, so padding can only occur on the left.
            stops = np.minimum(self.dataset.lengths[episode_ids], timesteps + 1 + np.random.randint(0, self.sequence_length, len(timesteps)))
            starts = stops - self.sequence_length

        return [SegmentId(*bounds) for bounds in zip(episode_ids, starts, stops)]
|
| 42 |
+
|
delta-iris/src/data/segment.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class SegmentId:
    """Addresses a time window [start, stop) inside one stored episode.

    start may be negative and stop may exceed the episode length; the out-of-range
    steps are then produced by padding (see make_segment in data/utils.py).
    """
    episode_id: int
    start: int
    stop: int
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class Segment:
    """A fixed-length slice of an episode, with a padding mask and its originating id."""
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    mask_padding: torch.BoolTensor  # True on real steps, False on padded ones
    id: SegmentId

    @property
    def effective_size(self) -> int:
        """Number of non-padded timesteps in the segment."""
        return int(self.mask_padding.sum())
|
delta-iris/src/data/utils.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Generator, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .batch import Batch
|
| 7 |
+
from .episode import Episode
|
| 8 |
+
from .segment import Segment, SegmentId
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def collate_segments_to_batch(segments: List[Segment]) -> Batch:
    """Stack a list of segments into a Batch; byte observations are rescaled to [0, 1] floats."""
    stacked = {
        'observations': torch.stack([s.observations for s in segments]).div(255),
        'actions': torch.stack([s.actions for s in segments]),
        'rewards': torch.stack([s.rewards for s in segments]),
        'ends': torch.stack([s.ends for s in segments]),
        'mask_padding': torch.stack([s.mask_padding for s in segments]),
        'segment_ids': [s.id for s in segments],
    }
    return Batch(**stacked)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment:
    """Extract [start, stop) from an episode, zero-padding out-of-range steps on either side.

    The returned mask_padding is True on real steps and False on padded ones. The stored
    id carries the CLAMPED bounds (after cutting to the episode), not the requested ones.
    """
    # The requested window must overlap the episode: start before its end, stop after 0.
    assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop
    padding_length_right = max(0, segment_id.stop - len(episode))
    padding_length_left = max(0, -segment_id.start)
    assert padding_length_right == padding_length_left == 0 or should_pad

    def pad(x):
        # F.pad takes (left, right) pairs starting from the LAST dimension; only the first
        # (time) dimension is padded here, hence the run of zeros for all trailing dims.
        pad_right = torch.nn.functional.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [padding_length_right]) if padding_length_right > 0 else x
        return torch.nn.functional.pad(pad_right, [0 for _ in range(2 * x.ndim - 2)] + [padding_length_left, 0]) if padding_length_left > 0 else pad_right

    # Clamp the window to the episode; the clipped part is what gets padded back.
    start = max(0, segment_id.start)
    stop = min(len(episode), segment_id.stop)

    return Segment(
        pad(episode.observations[start:stop]),
        pad(episode.actions[start:stop]),
        pad(episode.rewards[start:stop]),
        pad(episode.ends[start:stop]),
        mask_padding=torch.cat((torch.zeros(padding_length_left), torch.ones(stop - start), torch.zeros(padding_length_right))).bool(),
        id=SegmentId(segment_id.episode_id, start, stop)
    )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DatasetTraverser:
    """Deterministically iterates a whole dataset as batches of fixed-size chunks (for evaluation)."""

    def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None:
        self.dataset = dataset
        self.batch_num_samples = batch_num_samples
        self.chunk_size = chunk_size
        # Total batches: chunks per episode, minus the trailing chunk when the episode
        # length leaves exactly one step in it (such chunks are dropped in __iter__),
        # divided by the batch size.
        self._num_batches = math.ceil(sum([math.ceil(dataset.lengths[episode_id] / chunk_size) - int(dataset.lengths[episode_id] % chunk_size == 1) for episode_id in range(dataset.num_episodes)]) / batch_num_samples)

    def __len__(self) -> int:
        return self._num_batches

    def __iter__(self) -> Generator[Batch, None, None]:
        chunks = []

        for episode_id in range(self.dataset.num_episodes):
            episode = self.dataset.load_episode(episode_id)
            # Split the episode into consecutive chunk_size windows; the last one is right-padded.
            chunks.extend(make_segment(episode, SegmentId(episode_id, start=i * self.chunk_size, stop=(i + 1) * self.chunk_size), should_pad=True) for i in range(math.ceil(len(episode) / self.chunk_size)))
            # Drop a trailing chunk with fewer than two real steps (not enough for a transition).
            if chunks[-1].effective_size < 2:
                chunks.pop()

        while len(chunks) >= self.batch_num_samples:
            yield collate_segments_to_batch(chunks[:self.batch_num_samples])
            chunks = chunks[self.batch_num_samples:]

        # Final partial batch, if any chunks remain.
        if len(chunks) > 0:
            yield collate_segments_to_batch(chunks)
|
delta-iris/src/models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .tokenizer import Tokenizer
|
delta-iris/src/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (315 Bytes). View file
|
|
|
delta-iris/src/models/__pycache__/convnet.cpython-310.pyc
ADDED
|
Binary file (4.33 kB). View file
|
|
|
delta-iris/src/models/__pycache__/kv_caching.cpython-310.pyc
ADDED
|
Binary file (5.78 kB). View file
|
|
|
delta-iris/src/models/__pycache__/slicer.cpython-310.pyc
ADDED
|
Binary file (3.3 kB). View file
|
|
|
delta-iris/src/models/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (6.47 kB). View file
|
|
|
delta-iris/src/models/__pycache__/world_model.cpython-310.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
delta-iris/src/models/convnet.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class FrameCnnConfig:
    """Hyper-parameters shared by FrameEncoder and FrameDecoder."""
    # image_channels: channels of input/output frames (e.g. 3 for RGB).
    # latent_dim: channel dimension of the latent produced/consumed.
    # num_channels: base channel width; stage i uses mult[i] * num_channels channels.
    # down: per-stage flags — truthy means the encoder downsamples (decoder upsamples) there.
    image_channels: int
    latent_dim: int
    num_channels: int
    mult: List[int]
    down: List[int]
|
| 18 |
+
|
| 19 |
+
class FrameEncoder(nn.Module):
    """Convolutional encoder: (B, T, C, H, W) frames -> (B, T, latent_dim, H', W') latents.

    Spatial resolution is halved once for every truthy entry of config.down.
    """

    def __init__(self, config: FrameCnnConfig) -> None:
        super().__init__()

        assert len(config.mult) == len(config.down)
        encoder_layers = [nn.Conv2d(config.image_channels, config.num_channels, kernel_size=3, stride=1, padding=1)]
        input_channels = config.num_channels

        # One residual block per stage, widening channels by config.mult, with optional downsampling.
        for m, d in zip(config.mult, config.down):
            output_channels = m * config.num_channels
            encoder_layers.append(ResidualBlock(input_channels, output_channels))
            input_channels = output_channels
            if d:
                encoder_layers.append(Downsample(output_channels))
        # Final norm / activation / projection to the latent dimension.
        encoder_layers.extend([
            nn.GroupNorm(num_groups=32, num_channels=input_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(input_channels, config.latent_dim, kernel_size=3, stride=1, padding=1)
        ])
        self.encoder = nn.Sequential(*encoder_layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # Fold time into the batch so the 2D convolutions see (B*T, C, H, W), then unfold.
        b, t, _, _, _ = x.size()
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.encoder(x)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=b, t=t)

        return x
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class FrameDecoder(nn.Module):
    """Convolutional decoder mirroring FrameEncoder: latents -> reconstructed frames."""

    def __init__(self, config: FrameCnnConfig) -> None:
        super().__init__()

        assert len(config.mult) == len(config.down)
        decoder_layers = []
        output_channels = config.num_channels

        # Stages are built in ENCODER order (narrow -> wide) and then reversed, so the
        # decoder runs wide -> narrow with an Upsample wherever the encoder downsampled.
        for m, d in zip(config.mult, config.down):
            input_channels = m * config.num_channels
            decoder_layers.append(ResidualBlock(input_channels, output_channels))
            output_channels = input_channels
            if d:
                decoder_layers.append(Upsample(input_channels))
        decoder_layers.reverse()
        # After the loop, `input_channels` is the widest stage width — the decoder's entry point,
        # fed directly from the latent.
        decoder_layers.insert(0, nn.Conv2d(config.latent_dim, input_channels, kernel_size=3, stride=1, padding=1))
        decoder_layers.extend([
            nn.GroupNorm(num_groups=32, num_channels=config.num_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(config.num_channels, config.image_channels, kernel_size=3, stride=1, padding=1)
        ])
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # Fold time into batch for the 2D conv stack, then restore the (B, T, ...) layout.
        b, t, _, _, _ = x.size()
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        x = self.decoder(x)
        x = rearrange(x, '(b t) c h w -> b t c h w', b=b, t=t)
        return x
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ResidualBlock(nn.Module):
    """Pre-activation residual block: y = skip(x) + f(x), with a 1x1 projection when channels change."""

    def __init__(self, in_channels: int, out_channels: int, num_groups_norm: int = 32) -> None:
        super().__init__()

        self.f = nn.Sequential(
            nn.GroupNorm(num_groups_norm, in_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_groups_norm, out_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
        )
        # Identity skip when the channel count is unchanged, otherwise a 1x1 conv projection.
        if in_channels == out_channels:
            self.skip_projection = nn.Identity()
        else:
            self.skip_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.f(x)
        return self.skip_projection(x) + residual
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Downsample(nn.Module):
    """Halves spatial resolution with a kernel-2, stride-2 convolution (no padding)."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        # Channel count is preserved; only H and W shrink by a factor of two.
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=2, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        downsampled = self.conv(x)
        return downsampled
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class Upsample(nn.Module):
    """Doubles spatial resolution (nearest-neighbor), then smooths with a 3x3 convolution."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
|
delta-iris/src/models/kv_caching.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class Cache:
    """Grow-only (B, T, E) buffer: tokens are appended along T up to max_tokens."""

    def __init__(self, num_samples: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._n = num_samples
        self._cache = None
        self._size = None
        # Factory allocating an uninitialized (n, max_tokens, embed_dim) buffer on the target device.
        self._reset = lambda n: torch.empty(n, max_tokens, embed_dim, device=device)  # (B, T, E)
        self.reset()

    @property
    def shape(self) -> Tuple[int, int, int]:
        """Logical shape (batch, filled_tokens, embed_dim) — T reflects only what was written."""
        num_samples, _, embed_dim = self._cache.shape
        return num_samples, self._size, embed_dim

    def reset(self) -> None:
        """Reallocate the buffer and mark it empty."""
        self._cache = self._reset(self._n)
        self._size = 0

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the batch rows selected by the 1-D boolean mask."""
        assert mask.ndim == 1 and mask.shape[0] == self.shape[0]
        self._cache = self._cache[mask]
        self._n = self._cache.shape[0]

    def get(self) -> torch.Tensor:
        """View of the filled prefix, shape (B, size, E)."""
        return self._cache[:, :self._size, :]

    def update(self, x: torch.Tensor) -> None:
        """Append x of shape (B, t, E) after the filled prefix, keeping autograd history intact."""
        assert (x.ndim == self._cache.ndim) and all(x.size(i) == self._cache.size(i) for i in (0, 2))
        assert self._size + x.size(1) <= self._cache.shape[1]
        self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 1, self._size, self._size + x.size(1))
        self._size += x.size(1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class KVCache:
    """Pair of Caches holding the keys and the values of one attention layer."""

    def __init__(self, n: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._k_cache = Cache(n, max_tokens, embed_dim, device)
        self._v_cache = Cache(n, max_tokens, embed_dim, device)

    @property
    def shape(self) -> Tuple[int, int, int]:
        # Both caches are updated in lockstep, so reporting the key cache's shape suffices.
        return self._k_cache.shape

    def reset(self) -> None:
        """Empty both caches."""
        for cache in (self._k_cache, self._v_cache):
            cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the selected batch rows in both caches."""
        for cache in (self._k_cache, self._v_cache):
            cache.prune(mask)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (keys, values) views of the filled prefixes."""
        return self._k_cache.get(), self._v_cache.get()

    def update(self, k: torch.Tensor, v: torch.Tensor):
        """Append new key and value tokens."""
        self._k_cache.update(k)
        self._v_cache.update(v)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class KeysValues:
    """Per-layer collection of KVCaches for a whole transformer."""

    def __init__(self, n: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
        self._keys_values = tuple(KVCache(n, max_tokens, embed_dim, device) for _ in range(num_layers))

    def __getitem__(self, key: int) -> KVCache:
        return self._keys_values[key]

    def __len__(self):
        return len(self._keys_values)

    @property
    def size(self):
        """Number of tokens currently stored (layers are filled in lockstep, so layer 0 suffices)."""
        return self._keys_values[0].shape[1]

    def reset(self) -> None:
        """Empty every layer's cache."""
        for layer_cache in self._keys_values:
            layer_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        """Keep only the selected batch rows in every layer's cache."""
        for layer_cache in self._keys_values:
            layer_cache.prune(mask)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class AssignWithoutInplaceCheck(torch.autograd.Function):
    """
    Inspired from : https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
    Warning : do not use it to overwrite a slice twice.

    Writes `value` into `input[..., start:stop, ...]` along dimension `dim` through
    `.data`, bypassing autograd's in-place version counter, while still routing
    gradients: `input` gets the full upstream gradient, `value` gets its slice.
    """

    @staticmethod
    def get_slice(dim: int, start: int, stop: int) -> Tuple[slice]:
        # Build an indexer equivalent to [:, :, ..., start:stop] with the range on dimension `dim`.
        return tuple([slice(None), ] * dim + [slice(start, stop)])

    @staticmethod
    def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
        ctx.dim = dim
        ctx.start = start
        ctx.stop = stop
        # Assigning through .data skips the version-counter bump that would otherwise make
        # autograd reject this in-place write on a tensor used elsewhere in the graph.
        input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value
        return input

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor]:
        # Full gradient flows to `input`; `value` receives the gradient of the slice it was
        # written to. The three int arguments (dim/start/stop) get no gradient.
        return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None
|
delta-iris/src/models/slicer.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Slicer(nn.Module):
    """Maps a window of the token sequence to the positions kept by a per-block mask."""

    def __init__(self, max_blocks: int, block_mask: torch.Tensor) -> None:
        super().__init__()
        self.block_size = block_mask.size(0)
        self.num_kept_tokens = block_mask.sum().long().item()
        # Kept positions within one block, tiled over all blocks with per-block offsets.
        per_block = torch.where(block_mask)[0].repeat(max_blocks)
        shifts = torch.arange(max_blocks).repeat_interleave(self.num_kept_tokens)
        self.register_buffer('indices', per_block + self.block_size * shifts)

    def compute_slice(self, num_steps: int, prev_steps: int = 0) -> torch.Tensor:
        """Kept positions inside the window [prev_steps, prev_steps + num_steps), window-relative."""
        total_steps = prev_steps + num_steps
        needed_blocks = math.ceil(total_steps / self.block_size)
        candidates = self.indices[:needed_blocks * self.num_kept_tokens]
        in_window = (candidates >= prev_steps) & (candidates < total_steps)
        return candidates[in_window] - prev_steps

    def forward(self, *args, **kwargs):
        raise NotImplementedError
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class Head(Slicer):
    """Applies `head_module` only at the block positions selected by the mask."""

    def __init__(self, max_blocks: int, block_mask: torch.Tensor, head_module: nn.Module) -> None:
        super().__init__(max_blocks, block_mask)
        assert isinstance(head_module, nn.Module)
        self.head_module = head_module

    def forward(self, x: torch.Tensor, num_steps: int, prev_steps: int) -> torch.Tensor:
        # x is (B, T, E); keep only the time steps this head predicts at.
        kept = self.compute_slice(num_steps, prev_steps)
        return self.head_module(x[:, kept])
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Embedder(nn.Module):
    """Embeds a token sequence in which different positions within a block use different vocabularies.

    Each block mask selects the positions handled by the corresponding embedding
    table; the masks must partition the block (every position covered exactly once).
    """

    def __init__(self, max_blocks: int, block_masks: List[torch.Tensor], embedding_tables: List[nn.Embedding]) -> None:
        super().__init__()
        assert len(block_masks) == len(embedding_tables)
        assert (sum(block_masks) == 1).all()  # block mask are a partition of a block
        self.embedding_dim = embedding_tables[0].embedding_dim
        assert all([e.embedding_dim == self.embedding_dim for e in embedding_tables])
        # NOTE(review): plain Python lists — neither the tables nor the slicers are
        # registered as submodules here, so their parameters/buffers do not appear in
        # this module's .parameters()/state_dict(). Presumably the tables are owned and
        # registered by the caller — confirm, otherwise they escape the optimizer.
        self.embedding_tables = embedding_tables
        self.slicers = [Slicer(max_blocks, block_mask) for block_mask in block_masks]

    def forward(self, tokens: torch.LongTensor, num_steps: int, prev_steps: int) -> torch.FloatTensor:
        """Embed `tokens` (B, T), filling each position from the table whose mask covers it."""
        assert tokens.ndim == 2  # x is (B, T)
        output = torch.zeros(*tokens.size(), self.embedding_dim, device=tokens.device)
        for slicer, emb in zip(self.slicers, self.embedding_tables):
            s = slicer.compute_slice(num_steps, prev_steps)
            output[:, s] = emb(tokens[:, s])

        return output
|
delta-iris/src/models/tokenizer/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from ....tokenizer import Tokenizer, TokenizerConfig
|
delta-iris/src/models/tokenizer/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (257 Bytes). View file
|
|
|
delta-iris/src/models/tokenizer/__pycache__/quantizer.cpython-310.pyc
ADDED
|
Binary file (3.74 kB). View file
|
|
|
delta-iris/src/models/tokenizer/__pycache__/tokenizer.cpython-310.pyc
ADDED
|
Binary file (4.8 kB). View file
|
|
|
delta-iris/src/models/tokenizer/quantizer.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import math
|
| 3 |
+
from typing import Dict, Optional
|
| 4 |
+
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class QuantizerOutput:
    """Result of one quantization pass (see Quantizer.forward)."""
    q: torch.FloatTensor                 # quantized embeddings (straight-through), (B, T, K, input_dim)
    tokens: torch.LongTensor             # selected codebook indices, (B, T, K)
    loss: Dict[str, torch.FloatTensor]   # named loss terms (commitment loss)
    metrics: Dict[str, float]            # monitoring values; populated only in training mode
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Quantizer(nn.Module):
    """Cosine-similarity vector quantizer with EMA codebook updates and dead-codeword revival."""

    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
        super().__init__()
        assert math.log2(codebook_size).is_integer()
        # Revival triggers while codebook usage entropy is below log2(size) - 2 bits.
        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival
        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
        codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.register_buffer('num_codebook_updates', torch.tensor(0))
        self.register_buffer('codebook', codebook)
        # EMA estimate of how often each codeword is selected; starts uniform.
        self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))

    def forward(self, z: torch.Tensor) -> QuantizerOutput:
        """Quantize z (B, T, K, input_dim) to its nearest codewords by cosine similarity."""
        z = self.pre_quant_proj(z)
        z = F.normalize(z, dim=-1)
        b, k = z.size(0), z.size(2)
        z = rearrange(z, 'b t k e -> (b t k) e')

        # Nearest codeword by cosine similarity (codebook rows are kept unit-norm).
        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
        tokens = cosine_similarity.argmax(dim=-1)
        q = self.codebook[tokens]

        losses = {'commitment_loss': 0.02 * (z - q.detach()).pow(2).mean()}

        if self.training:
            metrics = {**self.update_codebook(z, tokens), 'codebook_entropy': self.compute_codebook_entropy()}
        else:
            metrics = {}

        # Straight-through estimator: forward uses q, gradient flows through z.
        q = z + (q - z).detach()
        q = self.post_quant_proj(q)

        q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
        tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)

        return QuantizerOutput(q, tokens, losses, metrics)

    @torch.no_grad()
    def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> Dict[str, int]:
        """EMA-update the codebook from the batch and revive rarely used codewords."""
        tokens_one_hot = F.one_hot(tokens, self.codebook.size(0)).float()  # (N, C)

        # Update codebook
        counts = tokens_one_hot.sum(dim=0)
        # Mean of the vectors assigned to each codeword (clamp avoids div-by-zero for unused codes).
        codebook_update = torch.einsum('n e, n c -> c e', z, tokens_one_hot) / torch.clamp(counts.unsqueeze(-1), min=1)
        codebook_update = F.normalize(codebook_update, dim=-1)
        self.codebook.lerp_(codebook_update, 1 - 0.99)  # EMA with decay 0.99

        # Update counts and revive dead codewords
        freqs = counts / tokens_one_hot.size(0)
        self.codewords_freqs.lerp_(freqs, 1 - 0.98)  # EMA with decay 0.98

        # Revival allowed while entropy is very low (< 1 bit), or within the first
        # `max_codebook_updates_with_revival` updates (None = always allowed).
        can_revive = (self.compute_codebook_entropy() < 1) or (self.max_codebook_updates_with_revival is None) or (self.num_codebook_updates.item() < self.max_codebook_updates_with_revival)
        if can_revive and (self.compute_codebook_entropy() < self.revival_entropy_threshold):
            # Codewords used at less than 10% of the uniform rate are considered expired.
            expired = torch.where(self.codewords_freqs < 1 / (10 * self.codewords_freqs.size(0)))[0]
            # NOTE(review): the reported metric is the number of expired codewords, which may
            # exceed the number actually revived (capped at z.size(0) below) — confirm intent.
            num_expired = expired.size(0)
            expired = expired[torch.randperm(num_expired)[:z.size(0)]]
            # Re-seed expired codewords with random encoder outputs from the batch.
            idx_revived = torch.randperm(z.size(0), device=z.device)[:expired.size(0)]
            self.codebook[expired] = z[idx_revived]
            self.codewords_freqs[expired] = 1 / self.codewords_freqs.size(0)
        else:
            num_expired = 0

        self.codebook = F.normalize(self.codebook, dim=-1)

        self.num_codebook_updates += 1
        metrics = {'codewords_revived': num_expired}

        return metrics

    def compute_codebook_entropy(self) -> float:
        """Shannon entropy (bits) of the EMA codeword usage distribution, skipping zero entries."""
        probs = self.codewords_freqs[self.codewords_freqs != 0]
        return -(torch.log2(probs) * probs).sum().item()

    @torch.no_grad()
    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        """Look up codewords for `tokens` and project back to the input dimension."""
        return self.post_quant_proj(self.codebook[tokens])
|
delta-iris/src/models/transformer.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inspired from https://github.com/karpathy/minGPT
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from .kv_caching import KeysValues, KVCache
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class TransformerConfig:
    """Hyperparameters of the transformer encoder."""

    tokens_per_block: int   # number of tokens forming one block of the sequence
    max_blocks: int         # maximum number of blocks in a sequence

    num_layers: int
    num_heads: int
    embed_dim: int

    attention: str          # 'causal' or 'block_causal' (see TransformerEncoder.__init__)

    embed_pdrop: float      # dropout applied to the input embeddings
    resid_pdrop: float      # dropout on residual branches (attention projection and MLP)
    attn_pdrop: float       # dropout on the attention weights

    @property
    def max_tokens(self):
        # Maximum sequence length in tokens.
        return self.tokens_per_block * self.max_blocks
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class TransformerEncoder(nn.Module):
    """Stack of pre-norm transformer layers with optional KV caching for incremental decoding."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.config = config

        self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim)
        self.emb_drop = nn.Dropout(config.embed_pdrop)
        self.ln = nn.LayerNorm(config.embed_dim)

        assert config.attention in ('causal', 'block_causal')
        k, m = config.tokens_per_block, config.max_blocks
        # 'causal': strict lower-triangular mask. 'block_causal': additionally lets every
        # token attend to all tokens of its own block (including later ones in the block).
        mask_sa = torch.tril(torch.ones(k * m, k * m))
        if config.attention == 'block_causal':
            mask_sa = torch.max(mask_sa, torch.block_diag(*[torch.ones(k, k) for _ in range(m)]))
        mask_sa = mask_sa.bool()

        self.blocks = nn.ModuleList([EncoderLayer(config, mask_sa) for _ in range(config.num_layers)])
        self.keys_values = None  # allocated on demand by reset_kv_cache()

    @property
    def num_blocks_left_in_kv_cache(self) -> float:
        # Remaining cache capacity expressed in blocks (may be fractional).
        assert self.keys_values is not None
        return (self.config.max_tokens - self.keys_values.size) / self.config.tokens_per_block

    def reset_kv_cache(self, n: int) -> None:
        """Allocate a fresh, empty KV cache for batch size `n` on the model's device."""
        device = self.ln.weight.device
        self.keys_values = KeysValues(n, self.config.max_tokens, self.config.embed_dim, self.config.num_layers, device)

    def forward(self, x: torch.FloatTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
        """Encode x (B, T, E); with `use_kv_cache`, positions continue from the cache size."""
        assert x.ndim == 3 and x.size(2) == self.config.embed_dim  # (B, TK, E)

        prev_steps = self.keys_values.size if use_kv_cache else 0
        inputs = x + self.pos_emb(prev_steps + torch.arange(x.size(1), device=x.device))

        y = self.emb_drop(inputs)
        for i, block in enumerate(self.blocks):
            y = block(y, self.keys_values[i] if use_kv_cache else None)
        y = self.ln(y)

        return y
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class EncoderLayer(nn.Module):
    """One transformer layer: self-attention followed by an MLP, both residual."""

    def __init__(self, config: TransformerConfig, mask_sa: torch.LongTensor) -> None:
        super().__init__()
        self.sa = SelfAttentionLayer(config, mask=mask_sa)
        self.mlp = MLPLayer(config)

    def forward(self, x: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
        attended = self.sa(x, kv_cache)
        return self.mlp(attended)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class MLPLayer(nn.Module):
    """Pre-norm residual MLP block: x + Dropout(Linear(GELU(Linear(LN(x)))))."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        hidden_dim = 4 * config.embed_dim
        self.ln = nn.LayerNorm(config.embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(config.embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, config.embed_dim),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
        residual = self.mlp(self.ln(inputs))
        return inputs + residual
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class SelfAttentionLayer(nn.Module):
    """Pre-norm residual self-attention with an optional KV cache for incremental decoding."""

    def __init__(self, config: TransformerConfig, mask: torch.BoolTensor) -> None:
        super().__init__()
        # Full (max_tokens, max_tokens) attention mask; sliced per forward call.
        self.register_buffer('mask', mask)
        self.ln = nn.LayerNorm(config.embed_dim)
        self.query = nn.Linear(config.embed_dim, config.embed_dim)
        self.key = nn.Linear(config.embed_dim, config.embed_dim)
        self.value = nn.Linear(config.embed_dim, config.embed_dim)
        self.attention = Attention(config)

    def forward(self, inputs: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
        B, T, C = inputs.size()
        if kv_cache is not None:
            # L = number of steps already cached; the new keys/values are appended after them.
            b, L, c = kv_cache.shape
            assert b == B and c == C
        else:
            L = 0

        x = self.ln(inputs)

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        if kv_cache is not None:
            kv_cache.update(k, v)
            k, v = kv_cache.get()

        # Queries correspond to rows [L, L+T) of the full mask; keys to columns [0, L+T).
        y = inputs + self.attention(q, k, v, self.mask[L:L + T, :L + T])

        return y
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Attention(nn.Module):
|
| 137 |
+
def __init__(self, config: TransformerConfig) -> None:
|
| 138 |
+
super().__init__()
|
| 139 |
+
assert config.embed_dim % config.num_heads == 0
|
| 140 |
+
self.num_heads = config.num_heads
|
| 141 |
+
self.attn_pdrop = config.attn_pdrop
|
| 142 |
+
self.resid_drop = nn.Dropout(config.resid_pdrop)
|
| 143 |
+
self.proj = nn.Linear(config.embed_dim, config.embed_dim)
|
| 144 |
+
|
| 145 |
+
def forward(self, q: torch.FloatTensor, k: torch.FloatTensor, v: torch.FloatTensor, mask: torch.BoolTensor) -> torch.FloatTensor:
|
| 146 |
+
assert mask.size(0) == q.size(1) and mask.size(1) == k.size(1)
|
| 147 |
+
|
| 148 |
+
q = rearrange(q, 'b q (h e) -> b h q e', h=self.num_heads)
|
| 149 |
+
k = rearrange(k, 'b k (h e) -> b h k e', h=self.num_heads)
|
| 150 |
+
v = rearrange(v, 'b k (h d) -> b h k d', h=self.num_heads)
|
| 151 |
+
|
| 152 |
+
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_pdrop, is_causal=False) if q.size(2) != 0 else q.new_empty(*q.shape[:-1], v.size(-1))
|
| 153 |
+
|
| 154 |
+
y = rearrange(y, 'b h q d -> b q (h d)')
|
| 155 |
+
y = self.resid_drop(self.proj(y))
|
| 156 |
+
|
| 157 |
+
return y
|
delta-iris/src/models/utils.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import cv2
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import random
|
| 5 |
+
import shutil
|
| 6 |
+
from typing import Callable, Dict
|
| 7 |
+
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch.optim import AdamW
|
| 15 |
+
|
| 16 |
+
from data import Episode
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def configure_optimizer(model: nn.Module, learning_rate: float, weight_decay: float, *blacklist_module_names) -> AdamW:
    """Credits to https://github.com/karpathy/minGPT

    Builds an AdamW optimizer with two parameter groups: weights of whitelisted module
    types (Linear, Conv1d) receive weight decay; biases, weights of normalization /
    embedding / Conv2d modules, and anything under a blacklisted module name do not.
    """
    whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding, nn.Conv2d, nn.GroupNorm)

    decay, no_decay = set(), set()
    for module_name, module in model.named_modules():
        for param_name, _ in module.named_parameters():
            full_name = f'{module_name}.{param_name}' if module_name else param_name
            if any(full_name.startswith(prefix) for prefix in blacklist_module_names):
                no_decay.add(full_name)
            elif 'bias' in param_name:
                # all biases will not be decayed
                no_decay.add(full_name)
            elif param_name.endswith('weight') and isinstance(module, whitelist_weight_modules):
                # weights of whitelist modules will be weight decayed
                decay.add(full_name)
            elif param_name.endswith('weight') and isinstance(module, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(full_name)

    # Sanity checks: the two sets are disjoint and together cover every parameter.
    param_dict = {pn: p for pn, p in model.named_parameters()}
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
    assert len(param_dict.keys() - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"

    optim_groups = [
        {"params": [param_dict[name] for name in sorted(decay)], "weight_decay": weight_decay},
        {"params": [param_dict[name] for name in sorted(no_decay)], "weight_decay": 0.0},
    ]
    return AdamW(optim_groups, lr=learning_rate)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def init_weights(module: nn.Module) -> None:
    """GPT-style init: N(0, 0.02) linear/embedding weights, zero biases, unit LayerNorm scale."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def extract_state_dict(state_dict: Dict, module_name: str) -> OrderedDict:
    """Return the sub-state-dict of `module_name`, with the module prefix stripped.

    Matches `module_name` as a full module path (i.e. keys starting with
    "module_name."), so a sibling module whose name merely shares the prefix
    (e.g. "encoder2" when extracting "encoder") is no longer mistakenly included,
    and a key exactly equal to `module_name` no longer crashes the split.
    """
    prefix = module_name + '.'
    return OrderedDict({k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)})
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def set_seed(seed: int) -> None:
    """Seed every RNG in use (stdlib random, numpy, torch CPU and CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@torch.no_grad()
def compute_discounted_returns(rewards: torch.FloatTensor, gamma: float) -> torch.FloatTensor:
    """Discounted returns G_t = sum_{i>=t} gamma^(i-t) * r_i for each (batch, step).

    Vectorized via the identity G_t = (S - C_{t-1}) / gamma^t, where S is the total
    discounted sum and C the running cumulative discounted sum — no backward loop.
    """
    assert 0 < gamma <= 1 and rewards.ndim == 2  # (B, T)
    discounts = gamma ** torch.arange(rewards.size(1))
    weighted = rewards * discounts
    tail_sums = weighted + weighted.sum(dim=1, keepdim=True) - weighted.cumsum(dim=1)
    return tail_sums / discounts
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class LossWithIntermediateLosses:
    """Bundle of named loss terms: `loss_total` is their sum (a tensor, for backward);
    `intermediate_losses` holds each term's scalar value for logging."""

    def __init__(self, **kwargs) -> None:
        self.intermediate_losses = {name: term.item() for name, term in kwargs.items()}
        self.loss_total = sum(kwargs.values())
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class EpisodeDirManager:
    """Rolling on-disk store of played episodes, plus a single copy of the best one.

    Keeps at most `max_num_episodes` files named 'episode_{id}_epoch_{epoch}.pt';
    when full, the episode with the smallest id is evicted before saving a new one.
    """

    def __init__(self, episode_dir: Path, max_num_episodes: int) -> None:
        self.episode_dir = episode_dir
        # parents=False: the parent run directory is expected to exist already.
        self.episode_dir.mkdir(parents=False, exist_ok=True)
        self.max_num_episodes = max_num_episodes
        # Highest episode return seen so far; decides when the 'best_*.pt' copy is replaced.
        self.best_return = float('-inf')

    def save(self, episode: Episode, episode_id: int, epoch: int) -> None:
        """Persist `episode` unless saving is disabled (max_num_episodes is None or <= 0)."""
        if self.max_num_episodes is not None and self.max_num_episodes > 0:
            self._save(episode, episode_id, epoch)

    def _save(self, episode: Episode, episode_id: int, epoch: int) -> None:
        # Evict the oldest (lowest-id) episode when the directory is at capacity.
        ep_paths = [p for p in self.episode_dir.iterdir() if p.stem.startswith('episode_')]
        assert len(ep_paths) <= self.max_num_episodes
        if len(ep_paths) == self.max_num_episodes:
            to_remove = min(ep_paths, key=lambda ep_path: int(ep_path.stem.split('_')[1]))
            to_remove.unlink()
        torch.save(episode.__dict__, self.episode_dir / f'episode_{episode_id}_epoch_{epoch}.pt')

        # Maintain a single 'best_*.pt' copy of the highest-return episode.
        ep_return = episode.compute_metrics().episode_return
        if ep_return > self.best_return:
            self.best_return = ep_return
            path_best_ep = [p for p in self.episode_dir.iterdir() if p.stem.startswith('best_')]
            assert len(path_best_ep) in (0, 1)
            if len(path_best_ep) == 1:
                path_best_ep[0].unlink()
            torch.save(episode.__dict__, self.episode_dir / f'best_episode_{episode_id}_epoch_{epoch}.pt')
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class RandomHeuristic:
    """Baseline actor that ignores observations and samples uniformly random actions."""

    def __init__(self, num_actions):
        self.num_actions = num_actions

    def act(self, obs):
        assert obs.ndim == 4  # (N, H, W, C)
        batch_size = obs.size(0)
        return torch.randint(low=0, high=self.num_actions, size=(batch_size,))
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def make_video(fname, fps, frames):
    """Write a stack of frames (T, H, W, 3) to `fname` as an mp4 via OpenCV.

    Assumes RGB channel order and, presumably, uint8 dtype (cv2.VideoWriter
    requirement) — confirm at call sites.
    """
    assert frames.ndim == 4  # (T, H, W, C)
    _, h, w, c = frames.shape
    assert c == 3

    video = cv2.VideoWriter(str(fname), cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for frame in frames:
        # OpenCV expects BGR channel order, hence the channel reversal.
        video.write(frame[:, :, ::-1])
    video.release()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def try_until_no_except(fn: Callable):
    """Call `fn` repeatedly until it completes without raising.

    Intended for flaky best-effort operations (e.g. connecting to a service at
    startup). Only `Exception` subclasses are retried; the previous bare `except:`
    also swallowed KeyboardInterrupt/SystemExit, making the loop un-interruptible.
    """
    while True:
        try:
            fn()
        except Exception:
            continue
        else:
            break
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def symlog(x: torch.Tensor) -> torch.Tensor:
    """Symmetric log transform sign(x) * log(1 + |x|); inverse of `symexp`."""
    magnitude = torch.log(torch.abs(x) + 1)
    return torch.sign(x) * magnitude
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def symexp(x: torch.Tensor) -> torch.Tensor:
    """Inverse of `symlog`: sign(x) * (exp(|x|) - 1)."""
    magnitude = torch.exp(torch.abs(x)) - 1
    return torch.sign(x) * magnitude
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def two_hot(x: torch.FloatTensor, x_min: int = -20, x_max: int = 20, num_buckets: int = 255) -> torch.FloatTensor:
    """Two-hot encode `x` over `num_buckets` uniform buckets spanning [x_min, x_max].

    Each value is represented by weight on its two neighbouring buckets, linearly
    interpolated so the weighted bucket centers reconstruct the value. Out-of-range
    values are clamped. Fixes over the previous version: the input tensor is no
    longer mutated (out-of-place clamp), and x == x_min no longer produces a
    negative bucket index (which made scatter fail).
    """
    x = x.clamp(x_min, x_max - 1e-5)  # out-of-place: do not mutate the caller's tensor
    buckets = torch.linspace(x_min, x_max, num_buckets).to(x.device)
    # Index of the left neighbour; clamp handles x exactly at the lowest bucket edge.
    k = (torch.searchsorted(buckets, x) - 1).clamp(min=0)
    # Linear interpolation weights for buckets k and k+1.
    values = torch.stack((buckets[k + 1] - x, x - buckets[k]), dim=-1) / (buckets[k + 1] - buckets[k]).unsqueeze(-1)
    return torch.scatter(x.new_zeros(*x.size(), num_buckets), dim=-1, index=torch.stack((k, k + 1), dim=-1), src=values)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def compute_softmax_over_buckets(logits: torch.FloatTensor, x_min: int = -20, x_max: int = 20, num_buckets: int = 255) -> torch.FloatTensor:
    """Expected value of the categorical distribution `softmax(logits)` over uniform bucket centers."""
    centers = torch.linspace(x_min, x_max, num_buckets).to(logits.device)
    return F.softmax(logits, dim=-1) @ centers
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def plot_counts(counts: np.ndarray) -> Image:
    """Render `counts` as a line plot and return it as a PIL Image.

    Round-trips through a temporary 'priorities.png' file in the current working
    directory, which is deleted afterwards.
    """
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(counts)
    p = Path('priorities.png')
    fig.savefig(p)
    plt.close(fig)
    im = Image.open(p)
    p.unlink()

    return im
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def compute_mask_after_first_done(ends: torch.LongTensor) -> torch.BoolTensor:
    """Boolean mask keeping every step up to and including the first done flag per row.

    Rows containing no done flag at all are kept entirely.
    """
    assert ends.ndim == 2
    first_done = torch.argmax(ends, dim=1, keepdim=True)
    positions = torch.arange(ends.size(1), device=ends.device).unsqueeze(0)
    up_to_first = positions <= first_done
    no_done_anywhere = ends.sum(dim=1, keepdim=True) == 0
    return up_to_first | no_done_anywhere
|
delta-iris/src/tokenizer.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import math
|
| 3 |
+
from typing import Dict, Tuple
|
| 4 |
+
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from .models.convnet import FrameCnnConfig, FrameEncoder, FrameDecoder
|
| 10 |
+
from .data import Batch
|
| 11 |
+
from .models.tokenizer.quantizer import Quantizer, QuantizerOutput
|
| 12 |
+
from .models.utils import init_weights, LossWithIntermediateLosses
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class TokenizerConfig:
    """Hyperparameters of the delta-tokenizer."""
    image_channels: int
    image_size: int
    num_actions: int
    num_tokens: int                     # tokens per transition; assumed to be a perfect square
    decoder_act_channels: int           # channel planes used to inject the action into the decoder
    codebook_size: int
    codebook_dim: int
    max_codebook_updates_with_revival: int
    encoder_config: FrameCnnConfig      # CNN encoding (x1, action plane, x2)
    decoder_config: FrameCnnConfig      # CNN decoding back to a frame
    frame_cnn_config: FrameCnnConfig    # CNN embedding the previous frame for the decoder
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Tokenizer(nn.Module):
|
| 31 |
+
def __init__(self, config: TokenizerConfig) -> None:
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.config = config
|
| 34 |
+
|
| 35 |
+
self.latent_res = config.image_size // 2 ** sum(config.encoder_config.down)
|
| 36 |
+
self.tokens_grid_res = int(math.sqrt(config.num_tokens))
|
| 37 |
+
self.token_res = self.latent_res // self.tokens_grid_res
|
| 38 |
+
|
| 39 |
+
self.encoder_act_emb = nn.Embedding(config.num_actions, config.image_size ** 2)
|
| 40 |
+
self.decoder_act_emb = nn.Embedding(config.num_actions, config.decoder_act_channels * self.latent_res ** 2)
|
| 41 |
+
|
| 42 |
+
self.quantizer = Quantizer(
|
| 43 |
+
config.codebook_size, config.codebook_dim,
|
| 44 |
+
input_dim=config.encoder_config.latent_dim * self.token_res ** 2,
|
| 45 |
+
max_codebook_updates_with_revival=config.max_codebook_updates_with_revival
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.encoder = FrameEncoder(config.encoder_config)
|
| 49 |
+
self.decoder = FrameDecoder(config.decoder_config)
|
| 50 |
+
self.frame_cnn = FrameEncoder(config.frame_cnn_config)
|
| 51 |
+
|
| 52 |
+
self.apply(init_weights)
|
| 53 |
+
|
| 54 |
+
def __repr__(self) -> str:
|
| 55 |
+
return "tokenizer"
|
| 56 |
+
|
| 57 |
+
def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> QuantizerOutput:
|
| 58 |
+
z = self.encode(x1, a, x2)
|
| 59 |
+
z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)
|
| 60 |
+
|
| 61 |
+
return self.quantizer(z)
|
| 62 |
+
|
| 63 |
+
def compute_loss(self, batch: Batch, **kwargs) -> Tuple[LossWithIntermediateLosses, Dict]:
|
| 64 |
+
x1 = batch.observations[:, :-1]
|
| 65 |
+
a = batch.actions[:, :-1]
|
| 66 |
+
x2 = batch.observations[:, 1:]
|
| 67 |
+
|
| 68 |
+
quantizer_outputs = self(x1, a, x2)
|
| 69 |
+
|
| 70 |
+
r = self.decode(x1, a, rearrange(quantizer_outputs.q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res))
|
| 71 |
+
delta = (x2 - r)
|
| 72 |
+
delta = delta[torch.logical_and(batch.mask_padding[:, 1:], batch.mask_padding[:, :-1])]
|
| 73 |
+
|
| 74 |
+
losses = {
|
| 75 |
+
**quantizer_outputs.loss,
|
| 76 |
+
'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
|
| 77 |
+
'reconstruction_loss_l2': delta.pow(2).mean(),
|
| 78 |
+
'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
return LossWithIntermediateLosses(**losses), quantizer_outputs.metrics
|
| 82 |
+
|
| 83 |
+
def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
|
| 84 |
+
a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
|
| 85 |
+
encoder_input = torch.cat((x1, a_emb, x2), dim=2)
|
| 86 |
+
z = self.encoder(encoder_input)
|
| 87 |
+
|
| 88 |
+
return z
|
| 89 |
+
|
| 90 |
+
def decode(self, x1: torch.FloatTensor, a: torch.LongTensor, q2: torch.FloatTensor, should_clamp: bool = False) -> torch.FloatTensor:
|
| 91 |
+
x1_emb = self.frame_cnn(x1)
|
| 92 |
+
a_emb = rearrange(self.decoder_act_emb(a), 'b t (c h w) -> b t c h w', c=self.config.decoder_act_channels, h=x1_emb.size(3))
|
| 93 |
+
|
| 94 |
+
decoder_input = torch.cat((x1_emb, a_emb, q2), dim=2)
|
| 95 |
+
|
| 96 |
+
r = self.decoder(decoder_input)
|
| 97 |
+
r = torch.clamp(r, 0, 1).mul(255).round().div(255) if should_clamp else r
|
| 98 |
+
|
| 99 |
+
return r
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def encode_decode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.Tensor:
|
| 103 |
+
z = self.encode(x1, a, x2)
|
| 104 |
+
z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', k=self.token_res, l=self.token_res)
|
| 105 |
+
q = rearrange(self.quantizer(z).q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res)
|
| 106 |
+
r = self.decode(x1, a, q, should_clamp=True)
|
| 107 |
+
|
| 108 |
+
return r
|
| 109 |
+
|
| 110 |
+
@torch.no_grad()
def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor) -> torch.LongTensor:
    """Tokenize every transition of an observed trajectory prefix.

    Expects T observations and T-1 actions; returns the latent tokens of
    the T-1 transitions (obs[t], act[t], obs[t+1]).
    """
    assert obs.size(1) == act.size(1) + 1
    return self(obs[:, :-1], act, obs[:, 1:]).tokens
|
delta-iris/src/world_model.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from einops import rearrange, repeat
|
| 4 |
+
from einops.layers.torch import Rearrange
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from .models.convnet import FrameCnnConfig, FrameEncoder
|
| 10 |
+
from .data import Batch
|
| 11 |
+
from .models.slicer import Head
|
| 12 |
+
from .tokenizer import Tokenizer
|
| 13 |
+
from .models.transformer import TransformerEncoder, TransformerConfig
|
| 14 |
+
from .models.utils import init_weights, LossWithIntermediateLosses, symlog, two_hot
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass
class WorldModelOutput:
    """Raw transformer features plus the logits of the three prediction heads."""
    output_sequence: torch.FloatTensor  # transformer output embeddings for the whole token sequence
    logits_latents: torch.FloatTensor   # next-latent-token logits (latent_vocab_size classes)
    logits_rewards: torch.FloatTensor   # reward logits (two-hot bins, or 3 sign classes)
    logits_ends: torch.FloatTensor      # episode-end logits (2 classes)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class WorldModelConfig:
    """Hyperparameters of the world model."""
    latent_vocab_size: int  # size of the discrete latent-token vocabulary
    num_actions: int
    image_channels: int
    image_size: int
    latents_weight: float   # loss weight for latent-token prediction
    rewards_weight: float   # loss weight for reward prediction
    ends_weight: float      # loss weight for episode-end prediction
    two_hot_rews: bool      # if True, rewards use 255-bin two-hot symlog targets instead of sign classes
    transformer_config: TransformerConfig
    frame_cnn_config: FrameCnnConfig
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class WorldModel(nn.Module):
    """Transformer world model over interleaved (frame, action, latent-tokens) blocks.

    Each block of the token sequence holds: one frame embedding, one action
    embedding, then the latent tokens describing the transition to the next
    frame. Three MLP heads read the transformer outputs at fixed positions
    inside each block to predict the next latent token, the reward, and the
    episode end.
    """

    def __init__(self, config: WorldModelConfig) -> None:
        super().__init__()
        self.config = config
        self.transformer = TransformerEncoder(config.transformer_config)

        embed_dim = config.transformer_config.embed_dim

        # The flattened frame feature map must exactly fill one embedding vector.
        spatial_cells = (config.image_size // 2 ** sum(config.frame_cnn_config.down)) ** 2
        assert spatial_cells * config.frame_cnn_config.latent_dim == embed_dim
        self.frame_cnn = nn.Sequential(
            FrameEncoder(config.frame_cnn_config),
            Rearrange('b t c h w -> b t 1 (h w c)'),
            nn.LayerNorm(embed_dim),
        )

        self.act_emb = nn.Embedding(config.num_actions, embed_dim)
        self.latents_emb = nn.Embedding(config.latent_vocab_size, embed_dim)

        # Positions inside a block: 0 = frame, 1 = action, 2..k+1 = latent tokens.
        tokens_per_block = config.transformer_config.tokens_per_block
        mask_action = torch.zeros(tokens_per_block)
        mask_action[1] = 1
        mask_action_and_latents_but_last = torch.zeros(tokens_per_block)
        mask_action_and_latents_but_last[1:-1] = 1

        def make_head(block_mask: torch.Tensor, out_dim: int) -> Head:
            # Two-layer MLP head applied only at the masked positions of each block.
            return Head(
                max_blocks=config.transformer_config.max_blocks,
                block_mask=block_mask,
                head_module=nn.Sequential(
                    nn.Linear(embed_dim, embed_dim),
                    nn.ReLU(),
                    nn.Linear(embed_dim, out_dim),
                ),
            )

        self.head_latents = make_head(mask_action_and_latents_but_last, config.latent_vocab_size)
        self.head_rewards = make_head(mask_action, 255 if config.two_hot_rews else 3)
        self.head_ends = make_head(mask_action, 2)

        self.apply(init_weights)

    def __repr__(self) -> str:
        return "world_model"

    def forward(self, sequence: torch.FloatTensor, use_kv_cache: bool = False) -> WorldModelOutput:
        """Run the transformer and the three prediction heads over `sequence`."""
        past_steps = self.transformer.keys_values.size if use_kv_cache else 0
        steps = sequence.size(1)

        features = self.transformer(sequence, use_kv_cache=use_kv_cache)

        return WorldModelOutput(
            features,
            self.head_latents(features, steps, past_steps),
            self.head_rewards(features, steps, past_steps),
            self.head_ends(features, steps, past_steps),
        )

    def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs) -> LossWithIntermediateLosses:
        """Cross-entropy losses for latent tokens, rewards and episode ends.

        Latent targets come from the tokenizer (no gradient flows back to it).
        Also returns a metrics dict with the latent prediction accuracy.
        """
        # At most one episode end per sampled segment.
        assert torch.all(batch.ends.sum(dim=1) <= 1)

        with torch.no_grad():
            latent_tokens = tokenizer(batch.observations[:, :-1], batch.actions[:, :-1], batch.observations[:, 1:]).tokens

        b, _, k = latent_tokens.size()

        frames_emb = self.frame_cnn(batch.observations)
        act_tokens_emb = self.act_emb(rearrange(batch.actions, 'b t -> b t 1'))
        # The final frame has no next-frame transition: pad with a dummy all-zero latent block.
        latent_tokens_emb = self.latents_emb(torch.cat((latent_tokens, latent_tokens.new_zeros(b, 1, k)), dim=1))
        sequence = rearrange(torch.cat((frames_emb, act_tokens_emb, latent_tokens_emb), dim=2), 'b t p1k e -> b (t p1k) e')

        outputs = self(sequence)

        mask = batch.mask_padding

        labels_latents = latent_tokens[mask[:, :-1]].flatten()
        # Drop the logits of the dummy final block, then keep valid (non-padded) timesteps only.
        logits_latents = outputs.logits_latents[:, :-k][repeat(mask[:, :-1], 'b t -> b (t k)', k=k)]
        latent_acc = (logits_latents.max(dim=-1)[1] == labels_latents).float().mean()
        # Soft two-hot targets over symlog-scaled rewards, or hard sign classes {0, 1, 2}.
        labels_rewards = two_hot(symlog(batch.rewards)) if self.config.two_hot_rews else (batch.rewards.sign() + 1).long()

        loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config.latents_weight
        loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config.rewards_weight
        loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config.ends_weight

        return LossWithIntermediateLosses(loss_latents=loss_latents, loss_rewards=loss_rewards, loss_ends=loss_ends), {'latent_accuracy': latent_acc}

    @torch.no_grad()
    def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor, latent_tokens: torch.LongTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
        """Prime the model (and optionally its KV cache) on an observed prefix.

        Feeds T-1 complete (frame, action, latents) blocks followed by the
        final frame embedding, and returns the transformer output sequence.
        """
        assert obs.size(1) == act.size(1) + 1 == latent_tokens.size(1) + 1

        frame_emb = self.frame_cnn(obs)
        action_emb = rearrange(self.act_emb(act), 'b t e -> b t 1 e')
        latents_emb = self.latents_emb(latent_tokens)
        blocks = rearrange(torch.cat((frame_emb[:, :-1], action_emb, latents_emb), dim=2), 'b t k2 e -> b (t k2) e')
        full_sequence = torch.cat((blocks, frame_emb[:, -1]), dim=1)

        return self(full_sequence, use_kv_cache=use_kv_cache).output_sequence
|