ShaswatRobotics committed · verified · Commit 7c81dfd · Parent: 9147d50

Update delta-iris/src/world_model.py

Files changed (1):
  delta-iris/src/world_model.py (+1 -31)
delta-iris/src/world_model.py CHANGED

@@ -7,11 +7,9 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .models.convnet import FrameEncoder
-from .data import Batch
 from .models.slicer import Head
-from .tokenizer import Tokenizer
 from .models.transformer import TransformerEncoder
-from .models.utils import init_weights, LossWithIntermediateLosses, symlog, two_hot
+from .models.utils import init_weights

 class WorldModel(nn.Module):
     def __init__(self, config: dict) -> None:
@@ -79,34 +77,6 @@ class WorldModel(nn.Module):
             "logits_ends": logits_ends
         }

-    def compute_loss(self, batch: Batch, tokenizer: Tokenizer, **kwargs) -> LossWithIntermediateLosses:
-        assert torch.all(batch.ends.sum(dim=1) <= 1)
-
-        with torch.no_grad():
-            latent_tokens = tokenizer(batch.observations[:, :-1], batch.actions[:, :-1], batch.observations[:, 1:]).tokens
-
-        b, _, k = latent_tokens.size()
-
-        frames_emb = self.frame_cnn(batch.observations)
-        act_tokens_emb = self.act_emb(rearrange(batch.actions, 'b t -> b t 1'))
-        latent_tokens_emb = self.latents_emb(torch.cat((latent_tokens, latent_tokens.new_zeros(b, 1, k)), dim=1))
-        sequence = rearrange(torch.cat((frames_emb, act_tokens_emb, latent_tokens_emb), dim=2), 'b t p1k e -> b (t p1k) e')
-
-        outputs = self(sequence)
-
-        mask = batch.mask_padding
-
-        labels_latents = latent_tokens[mask[:, :-1]].flatten()
-        logits_latents = outputs.logits_latents[:, :-k][repeat(mask[:, :-1], 'b t -> b (t k)', k=k)]
-        latent_acc = (logits_latents.max(dim=-1)[1] == labels_latents).float().mean()
-        labels_rewards = two_hot(symlog(batch.rewards)) if self.config["two_hot_rews"] else (batch.rewards.sign() + 1).long()
-
-        loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * self.config["latents_weight"]
-        loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * self.config["rewards_weight"]
-        loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * self.config["ends_weight"]
-
-        return LossWithIntermediateLosses(loss_latents=loss_latents, loss_rewards=loss_rewards, loss_ends=loss_ends), {'latent_accuracy': latent_acc}
-
     @torch.no_grad()
     def burn_in(self, obs: torch.FloatTensor, act: torch.LongTensor, latent_tokens: torch.LongTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
         assert obs.size(1) == act.size(1) + 1 == latent_tokens.size(1) + 1
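
This commit deletes the training-only compute_loss path from WorldModel, along with the Batch, Tokenizer, LossWithIntermediateLosses, symlog, and two_hot imports that only it used, leaving the model's inference surface (forward, burn_in) intact. If the world-model loss is still needed after this change, the same computation can live outside the class. The sketch below is a reconstruction assembled from the deleted lines, not code present in this commit; the function name compute_world_model_loss and the import paths (src.data, src.tokenizer, src.models.utils) are assumptions, as is the premise that the removed helpers still exist at those locations.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

# Hypothetical import paths -- the commit does not show where these live now.
from src.data import Batch
from src.tokenizer import Tokenizer
from src.models.utils import LossWithIntermediateLosses, symlog, two_hot


def compute_world_model_loss(world_model, batch: Batch, tokenizer: Tokenizer):
    # Mirrors the deleted WorldModel.compute_loss, with `self` replaced by `world_model`.
    assert torch.all(batch.ends.sum(dim=1) <= 1)  # at most one episode end per sequence
    config = world_model.config

    # Tokenize transitions without tracking gradients through the tokenizer.
    with torch.no_grad():
        latent_tokens = tokenizer(batch.observations[:, :-1], batch.actions[:, :-1], batch.observations[:, 1:]).tokens

    b, _, k = latent_tokens.size()

    # Interleave (frame, action, k latent tokens) embeddings into one flat sequence.
    frames_emb = world_model.frame_cnn(batch.observations)
    act_tokens_emb = world_model.act_emb(rearrange(batch.actions, 'b t -> b t 1'))
    latent_tokens_emb = world_model.latents_emb(torch.cat((latent_tokens, latent_tokens.new_zeros(b, 1, k)), dim=1))
    sequence = rearrange(torch.cat((frames_emb, act_tokens_emb, latent_tokens_emb), dim=2), 'b t p1k e -> b (t p1k) e')

    outputs = world_model(sequence)
    mask = batch.mask_padding

    # Latent-token targets and accuracy, restricted to non-padded timesteps.
    labels_latents = latent_tokens[mask[:, :-1]].flatten()
    logits_latents = outputs.logits_latents[:, :-k][repeat(mask[:, :-1], 'b t -> b (t k)', k=k)]
    latent_acc = (logits_latents.max(dim=-1)[1] == labels_latents).float().mean()
    labels_rewards = two_hot(symlog(batch.rewards)) if config["two_hot_rews"] else (batch.rewards.sign() + 1).long()

    loss_latents = F.cross_entropy(logits_latents, target=labels_latents) * config["latents_weight"]
    loss_rewards = F.cross_entropy(outputs.logits_rewards[mask], target=labels_rewards[mask]) * config["rewards_weight"]
    loss_ends = F.cross_entropy(outputs.logits_ends[mask], target=batch.ends[mask]) * config["ends_weight"]

    return LossWithIntermediateLosses(loss_latents=loss_latents, loss_rewards=loss_rewards, loss_ends=loss_ends), {'latent_accuracy': latent_acc}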
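
For reference, symlog and two_hot, which drop out of the .models.utils import in this commit, are only referenced, never defined, in the diff. A minimal sketch of Dreamer-v3-style definitions follows; treating them this way is an assumption, and the bin count and value range below are placeholders rather than values from this repository.

import torch

def symlog(x: torch.Tensor) -> torch.Tensor:
    # Symmetric log squashing: sign(x) * log(1 + |x|).
    return torch.sign(x) * torch.log1p(torch.abs(x))

def two_hot(x: torch.Tensor, num_bins: int = 255, low: float = -20.0, high: float = 20.0) -> torch.Tensor:
    # Soft target over `num_bins` bins: each scalar puts its weight on the two nearest bins.
    x = x.clamp(low, high)
    bins = torch.linspace(low, high, num_bins, device=x.device)
    idx_hi = torch.searchsorted(bins, x.contiguous()).clamp(1, num_bins - 1)
    idx_lo = idx_hi - 1
    w_hi = (x - bins[idx_lo]) / (bins[idx_hi] - bins[idx_lo])
    target = torch.zeros(*x.shape, num_bins, device=x.device)
    target.scatter_(-1, idx_lo.unsqueeze(-1), (1.0 - w_hi).unsqueeze(-1))
    target.scatter_(-1, idx_hi.unsqueeze(-1), w_hi.unsqueeze(-1))
    return target

F.cross_entropy accepts soft probability targets like these directly (PyTorch 1.10+), which is how the deleted loss_rewards term consumed the two-hot labels.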