ShaswatRobotics commited on
Commit
f77d622
·
verified ·
1 Parent(s): 6d2be1e

Update delta-iris/src/tokenizer.py

Browse files
Files changed (1) hide show
  1. delta-iris/src/tokenizer.py +2 -2
delta-iris/src/tokenizer.py CHANGED
@@ -8,7 +8,7 @@ import torch.nn as nn
8
 
9
  from .models.convnet import FrameEncoder, FrameDecoder
10
  from .data import Batch
11
- from .models.tokenizer.quantizer import Quantizer, QuantizerOutput
12
  from .models.utils import init_weights, LossWithIntermediateLosses
13
 
14
  class Tokenizer(nn.Module):
@@ -38,7 +38,7 @@ class Tokenizer(nn.Module):
38
  def __repr__(self) -> str:
39
  return "tokenizer"
40
 
41
- def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> QuantizerOutput:
42
  z = self.encode(x1, a, x2)
43
  z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)
44
 
 
8
 
9
  from .models.convnet import FrameEncoder, FrameDecoder
10
  from .data import Batch
11
+ from .models.tokenizer.quantizer import Quantizer
12
  from .models.utils import init_weights, LossWithIntermediateLosses
13
 
14
  class Tokenizer(nn.Module):
 
38
  def __repr__(self) -> str:
39
  return "tokenizer"
40
 
41
+ def forward(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> dict:
42
  z = self.encode(x1, a, x2)
43
  z = rearrange(z, 'b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res)
44