ShaswatRobotics committed on
Commit c67657d · verified · 1 Parent(s): 702745a

Delete iris/world_model.py

Files changed (1)
  1. iris/world_model.py +0 -93
iris/world_model.py DELETED
@@ -1,93 +0,0 @@
-from typing import Any, Optional, Tuple
-
-from einops import rearrange
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from models.kv_caching import KeysValues
-from models.slicer import Embedder, Head
-from models.transformer import Transformer
-
-class WorldModel(nn.Module):
-    def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: dict) -> None:
-        super().__init__()
-        self.obs_vocab_size, self.act_vocab_size = obs_vocab_size, act_vocab_size
-        self.config = config
-        self.transformer = Transformer(config)
-
-        all_but_last_obs_tokens_pattern = torch.ones(config["tokens_per_block"])
-        all_but_last_obs_tokens_pattern[-2] = 0  # skip the last observation slot of each block
-        act_tokens_pattern = torch.zeros(self.config["tokens_per_block"])
-        act_tokens_pattern[-1] = 1  # the action token closes each block
-        obs_tokens_pattern = 1 - act_tokens_pattern
-
-        self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])
-
-        self.embedder = Embedder(
-            max_blocks=config["max_blocks"],
-            block_masks=[act_tokens_pattern, obs_tokens_pattern],
-            embedding_tables=nn.ModuleList([nn.Embedding(act_vocab_size, config["embed_dim"]), nn.Embedding(obs_vocab_size, config["embed_dim"])])
-        )
-
-        self.head_observations = Head(
-            max_blocks=config["max_blocks"],
-            block_mask=all_but_last_obs_tokens_pattern,
-            head_module=nn.Sequential(
-                nn.Linear(config["embed_dim"], config["embed_dim"]),
-                nn.ReLU(),
-                nn.Linear(config["embed_dim"], obs_vocab_size)
-            )
-        )
-
-        self.head_rewards = Head(
-            max_blocks=config["max_blocks"],
-            block_mask=act_tokens_pattern,
-            head_module=nn.Sequential(
-                nn.Linear(config["embed_dim"], config["embed_dim"]),
-                nn.ReLU(),
-                nn.Linear(config["embed_dim"], 3)  # 3 classes: clipped reward in {-1, 0, 1}
-            )
-        )
-
-        self.head_ends = Head(
-            max_blocks=config["max_blocks"],
-            block_mask=act_tokens_pattern,
-            head_module=nn.Sequential(
-                nn.Linear(config["embed_dim"], config["embed_dim"]),
-                nn.ReLU(),
-                nn.Linear(config["embed_dim"], 2)  # 2 classes: episode done / not done
-            )
-        )
-
-    def __repr__(self) -> str:
-        return "world_model"
-
-    def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValues] = None) -> dict:
-
-        num_steps = tokens.size(1)  # (B, T)
-        assert num_steps <= self.config["max_tokens"]
-        prev_steps = 0 if past_keys_values is None else past_keys_values.size
-
-        sequences = self.embedder(tokens, num_steps, prev_steps) + self.pos_emb(prev_steps + torch.arange(num_steps, device=tokens.device))
-
-        x = self.transformer(sequences, past_keys_values)
-
-        logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps)
-        logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
-        logits_ends = self.head_ends(x, num_steps=num_steps, prev_steps=prev_steps)
-        return {
-            "output_sequence": x,
-            "logits_observations": logits_observations,
-            "logits_rewards": logits_rewards,
-            "logits_ends": logits_ends
-
-        }
-
-    def compute_labels_world_model(self, obs_tokens: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        assert torch.all(ends.sum(dim=1) <= 1)  # at most 1 done per sequence
-        mask_fill = torch.logical_not(mask_padding)
-        labels_observations = rearrange(obs_tokens.masked_fill(mask_fill.unsqueeze(-1).expand_as(obs_tokens), -100), 'b t k -> b (t k)')[:, 1:]
-        labels_rewards = (rewards.sign() + 1).masked_fill(mask_fill, -100).long()  # Rewards clipped to {-1, 0, 1} -> class ids {0, 1, 2}
-        labels_ends = ends.masked_fill(mask_fill, -100)
-        return labels_observations.reshape(-1), labels_rewards.reshape(-1), labels_ends.reshape(-1)
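
For reference, a minimal self-contained sketch (not from the deleted file) of the two mechanisms it implemented: the per-block token masks built in WorldModel.__init__, and the -100 label masking in compute_labels_world_model. The value tokens_per_block = 4 is an assumption for illustration only; the sketch needs nothing but torch.

import torch

tokens_per_block = 4  # hypothetical block size, for illustration only

# Same pattern construction as the deleted WorldModel.__init__: each block
# holds (tokens_per_block - 1) observation tokens followed by 1 action token.
all_but_last_obs_tokens_pattern = torch.ones(tokens_per_block)
all_but_last_obs_tokens_pattern[-2] = 0   # skip the last observation slot
act_tokens_pattern = torch.zeros(tokens_per_block)
act_tokens_pattern[-1] = 1                # action token closes each block
obs_tokens_pattern = 1 - act_tokens_pattern

print(obs_tokens_pattern)               # tensor([1., 1., 1., 0.])
print(act_tokens_pattern)               # tensor([0., 0., 0., 1.])
print(all_but_last_obs_tokens_pattern)  # tensor([1., 1., 0., 1.])

# Label construction as in compute_labels_world_model: sign() + 1 maps the
# clipped rewards {-1, 0, 1} to class ids {0, 1, 2}; padded steps become
# -100, the default ignore_index of PyTorch's cross-entropy loss.
rewards = torch.tensor([[-0.5, 0.0, 2.0]])
mask_padding = torch.tensor([[True, True, False]])  # last step is padding
labels_rewards = (rewards.sign() + 1).masked_fill(~mask_padding, -100).long()
print(labels_rewards)  # tensor([[   0,    1, -100]])

So the observation head reads every position except the second-to-last slot of each block (whose successor is the action token rather than an observation token), while the reward and end heads read only the action positions.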