ShaswatRobotics committed on
Commit
9147d50
·
verified ·
1 Parent(s): a29a97c

Update delta-iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. delta-iris/src/tokenizer.py +1 -22
delta-iris/src/tokenizer.py CHANGED
@@ -7,9 +7,8 @@ import torch
7
  import torch.nn as nn
8
 
9
  from .models.convnet import FrameEncoder, FrameDecoder
10
- from .data import Batch
11
  from .models.tokenizer.quantizer import Quantizer
12
- from .models.utils import init_weights, LossWithIntermediateLosses
13
 
14
  class Tokenizer(nn.Module):
15
  def __init__(self, config: dict) -> None:
@@ -44,26 +43,6 @@ class Tokenizer(nn.Module):
44
 
45
  return self.quantizer(z)
46
 
47
- def compute_loss(self, batch: Batch, **kwargs) -> Tuple[LossWithIntermediateLosses, Dict]:
48
- x1 = batch.observations[:, :-1]
49
- a = batch.actions[:, :-1]
50
- x2 = batch.observations[:, 1:]
51
-
52
- quantizer_outputs = self(x1, a, x2)
53
-
54
- r = self.decode(x1, a, rearrange(quantizer_outputs.q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res))
55
- delta = (x2 - r)
56
- delta = delta[torch.logical_and(batch.mask_padding[:, 1:], batch.mask_padding[:, :-1])]
57
-
58
- losses = {
59
- **quantizer_outputs.loss,
60
- 'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
61
- 'reconstruction_loss_l2': delta.pow(2).mean(),
62
- 'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
63
- }
64
-
65
- return LossWithIntermediateLosses(**losses), quantizer_outputs.metrics
66
-
67
  def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
68
  a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
69
  encoder_input = torch.cat((x1, a_emb, x2), dim=2)
 
7
  import torch.nn as nn
8
 
9
  from .models.convnet import FrameEncoder, FrameDecoder
 
10
  from .models.tokenizer.quantizer import Quantizer
11
+ from .models.utils import init_weights
12
 
13
  class Tokenizer(nn.Module):
14
  def __init__(self, config: dict) -> None:
 
43
 
44
  return self.quantizer(z)
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
47
  a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
48
  encoder_input = torch.cat((x1, a_emb, x2), dim=2)