Update delta-iris/src/world_model.py
delta-iris/src/world_model.py
CHANGED
@@ -7,11 +7,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .models.convnet import FrameEncoder
-from .data import Batch
 from .models.slicer import Head
-from .tokenizer import Tokenizer
 from .models.transformer import TransformerEncoder
-from .models.utils import init_weights
+from .models.utils import init_weights
 
 class WorldModel(nn.Module):
     def __init__(self, config: dict) -> None:
@@ -79,34 +77,6 @@ class WorldModel(nn.Module):
             "logits_ends": logits_ends
         }
 
-    def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs) -> LossWithIntermediateLosses:
-        assert torch.all(batch.ends.sum(dim=1) <= 1)
-
-        with torch.no_grad():
-            latent_tokens = tokenizer(batch.observations[:, :-1], batch.actions[:, :-1], batch.observations[:, 1:]).tokens
-
-        b, _, k = latent_tokens.size()
-
-        frames_emb = self.frame_cnn(batch.observations)
-        act_tokens_emb = self.act_emb(rearrange(batch.actions, 'b t -> b t 1'))
-        latent_tokens_emb = self.latents_emb(torch.cat((latent_tokens, latent_tokens.new_zeros(b, 1, k)), dim=1))
-        sequence = rearrange(torch.cat((frames_emb, act_tokens_emb, latent_tokens_emb), dim=2), 'b t p1k e -> b (t p1k) e')
-
-        outputs = self(sequence)
-
-        mask = batch.mask_padding
-
-        labels_latents = latent_tokens[mask[:, :-1]].flatten()
-        logits_latents = outputs.logits_latents[:, :-k][repeat(mask[:, :-1], 'b t -> b (t k)', k=k)]
-        latent_acc = (logits_latents.max(dim=-1)[1] == labels_latents).float().mean()
-        labels_rewards = two_hot(symlog(batch.rewards)) if self.config["two_hot_rews"] else (batch.rewards.sign() + 1).long()
-
-        loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config["latents_weight"]
-        loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config["rewards_weight"]
-        loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config["ends_weight"]
-
-        return LossWithIntermediateLosses(loss_latents=loss_latents, loss_rewards=loss_rewards, loss_ends=loss_ends), {'latent_accuracy': latent_acc}
-
     @torch.no_grad()
     def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor, latent_tokens: torch.LongTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
         assert obs.size(1) == act.size(1) + 1 == latent_tokens.size(1) + 1
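For context on the removed compute_loss: each timestep contributes one frame embedding, one action embedding, and k latent-token embeddings, which are concatenated per step and flattened into a single transformer sequence (the 'b t p1k e -> b (t p1k) e' rearrange above). The last step's latent tokens are zero-padded because t observations yield only t-1 transitions. A minimal sketch of that interleaving, with invented shapes and random stand-ins for the real embedding modules:

    # Sketch of the sequence interleaving from the removed compute_loss.
    # Shapes and tensors are illustrative stand-ins, not delta-iris code.
    import torch
    from einops import rearrange

    b, t, k, e = 2, 4, 3, 8                 # batch, timesteps, latents per step, embed dim

    frames_emb = torch.randn(b, t, 1, e)    # one embedding per frame (frame_cnn output)
    act_emb = torch.randn(b, t, 1, e)       # one embedding per action (act_emb output)
    latents_emb = torch.randn(b, t, k, e)   # k latent-token embeddings per step,
                                            # with the final step zero-padded upstream

    # Each timestep contributes (1 + 1 + k) tokens, flattened into one sequence.
    sequence = rearrange(torch.cat((frames_emb, act_emb, latents_emb), dim=2),
                         'b t p1k e -> b (t p1k) e')
    assert sequence.shape == (b, t * (2 + k), e)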
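The latent loss then lines up per-timestep labels with per-token logits: the model emits k latent logits per step, the final step's logits are dropped ([:, :-k], since the last frame has no following transition), and the per-timestep padding mask is expanded k-fold via repeat(mask[:, :-1], 'b t -> b (t k)', k=k). A toy illustration of that alignment, again with invented shapes:

    # Toy alignment of a per-timestep mask with k-per-step latent logits.
    # Shapes are invented for illustration.
    import torch
    from einops import repeat

    b, t, k, vocab = 2, 4, 3, 16
    latent_tokens = torch.randint(vocab, (b, t - 1, k))       # labels for t-1 transitions
    logits = torch.randn(b, t * k, vocab)                     # k logit vectors per timestep
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]]).bool()                # per-timestep padding mask

    labels = latent_tokens[mask[:, :-1]].flatten()            # keep valid transitions only
    token_mask = repeat(mask[:, :-1], 'b t -> b (t k)', k=k)  # one mask entry per token
    valid_logits = logits[:, :-k][token_mask]                 # drop final step, then mask
    assert valid_logits.shape[0] == labels.shape[0]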
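Finally, the removed loss built reward targets as two_hot(symlog(batch.rewards)) when two_hot_rews is set, and otherwise as three sign classes via (batch.rewards.sign() + 1).long(). Neither helper is shown in this diff; the sketch below assumes the common DreamerV3-style definitions, which may differ from the actual implementations in .models.utils:

    # Assumed definitions of symlog / two_hot; the real helpers are not in this
    # diff and may differ. symlog compresses reward magnitudes, and two_hot
    # spreads each scalar over the two nearest bins of a fixed support.
    import torch

    def symlog(x: torch.Tensor) -> torch.Tensor:
        # Signed log transform: keeps the sign, compresses large magnitudes.
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def two_hot(x: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
        # Weight the two neighboring bins so that the support-weighted
        # expectation of the target recovers x exactly.
        x = x.clamp(min=support[0].item(), max=support[-1].item())
        idx = torch.searchsorted(support, x, right=True).clamp(1, len(support) - 1)
        lo, hi = support[idx - 1], support[idx]
        w_hi = (x - lo) / (hi - lo)
        target = torch.zeros(*x.shape, len(support))
        target.scatter_(-1, (idx - 1).unsqueeze(-1), (1 - w_hi).unsqueeze(-1))
        target.scatter_(-1, idx.unsqueeze(-1), w_hi.unsqueeze(-1))
        return target

    support = torch.linspace(-5.0, 5.0, 11)
    labels = two_hot(symlog(torch.tensor([0.0, 1.0, -2.5])), support)
    # Each row sums to 1; F.cross_entropy accepts such soft targets directly,
    # which matches how the removed code fed labels_rewards into the reward loss.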