Update delta-iris/src/tokenizer.py
delta-iris/src/tokenizer.py  CHANGED  (+1 -22)
@@ -7,9 +7,8 @@ import torch
 import torch.nn as nn
 
 from .models.convnet import FrameEncoder, FrameDecoder
-from .data import Batch
 from .models.tokenizer.quantizer import Quantizer
-from .models.utils import init_weights
+from .models.utils import init_weights
 
 class Tokenizer(nn.Module):
     def __init__(self, config: dict) -> None:
@@ -44,26 +43,6 @@ class Tokenizer(nn.Module):
 
         return self.quantizer(z)
 
-    def compute_loss(self, batch: Batch, **kwargs) -> Tuple[LossWithIntermediateLosses, Dict]:
-        x1 = batch.observations[:, :-1]
-        a = batch.actions[:, :-1]
-        x2 = batch.observations[:, 1:]
-
-        quantizer_outputs = self(x1, a, x2)
-
-        r = self.decode(x1, a, rearrange(quantizer_outputs.q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res))
-        delta = (x2 - r)
-        delta = delta[torch.logical_and(batch.mask_padding[:, 1:], batch.mask_padding[:, :-1])]
-
-        losses = {
-            **quantizer_outputs.loss,
-            'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
-            'reconstruction_loss_l2': delta.pow(2).mean(),
-            'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
-        }
-
-        return LossWithIntermediateLosses(**losses), quantizer_outputs.metrics
-
     def encode(self, x1: torch.FloatTensor, a: torch.LongTensor, x2: torch.FloatTensor) -> torch.FloatTensor:
         a_emb = rearrange(self.encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=x1.size(3))
         encoder_input = torch.cat((x1, a_emb, x2), dim=2)
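Note on the removed compute_loss: the einops pattern 'b t (h w) (k l e) -> b t e (h k) (w l)' unfolds each frame's h*w token grid, where every token carries a k x l patch of e-dimensional features, back into a decoder-ready spatial feature map. A minimal standalone sketch; the sizes below are illustrative, not the repo's actual config values:

import torch
from einops import rearrange

# Illustrative sizes (not the repo's config): 4x4 token grid (h, w),
# each token covering a 2x2 patch (k, l) of 8-dim features (e).
b, t, h, w, k, l, e = 2, 3, 4, 4, 2, 2, 8

q = torch.randn(b, t, h * w, k * l * e)  # quantized tokens: (b, t, tokens, token_dim)

# Unfold each token into its k x l patch, yielding a (b, t, e, h*k, w*l) map,
# i.e. the layout the removed compute_loss handed to self.decode.
z = rearrange(q, 'b t (h w) (k l e) -> b t e (h k) (w l)', h=h, k=k, l=l)

assert z.shape == (b, t, e, h * k, w * l)  # (2, 3, 8, 8, 8)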
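The removed loss combined the quantizer's own losses with three reconstruction terms over delta = x2 - r, restricted to timesteps where both frames of a transition are unpadded. A hedged sketch of just those terms, with made-up shapes and a stand-in for the padding-mask conjunction:

import torch
from einops import rearrange

# Made-up shapes: batch 2, 4 transitions, 3x16x16 frames.
x2 = torch.randn(2, 4, 3, 16, 16)  # target next frames
r = torch.randn(2, 4, 3, 16, 16)   # reconstructions
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])  # stand-in for the mask_padding conjunction

# Boolean indexing over the (b, t) dims keeps only valid transitions: (N, c, h, w).
delta = (x2 - r)[mask]

losses = {
    'reconstruction_loss_l1': 0.1 * torch.abs(delta).mean(),
    'reconstruction_loss_l2': delta.pow(2).mean(),
    # Worst-pixel term: largest squared error per sample, averaged over samples.
    'reconstruction_loss_l2_worst_pixel': 0.01 * rearrange(delta, 'b c h w -> b (c h w)').pow(2).max(dim=-1)[0].mean(),
}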
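For context on the surviving encode method: the discrete action is embedded to one value per pixel and concatenated as an extra channel between the two frames. A sketch assuming encoder_act_emb is an nn.Embedding with h*w output features; that definition is not part of this diff:

import torch
import torch.nn as nn
from einops import rearrange

num_actions, h, w = 4, 16, 16  # illustrative sizes
encoder_act_emb = nn.Embedding(num_actions, h * w)  # assumption: one value per pixel

x1 = torch.randn(2, 3, 3, h, w)            # (b, t, c, h, w) current frames
x2 = torch.randn(2, 3, 3, h, w)            # next frames
a = torch.randint(0, num_actions, (2, 3))  # (b, t) discrete action ids

# Reshape the h*w-dim action embedding into a one-channel action plane.
a_emb = rearrange(encoder_act_emb(a), 'b t (h w) -> b t 1 h w', h=h)

# Channel-stack current frame, action plane, and next frame: (b, t, 2c+1, h, w).
encoder_input = torch.cat((x1, a_emb, x2), dim=2)
assert encoder_input.shape == (2, 3, 7, h, w)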