from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torcheval.metrics.functional import multiclass_confusion_matrix

from .blocks import Conv3x3, Downsample, ResBlocks
from ..data import Batch
from ..utils import init_lstm, LossAndLogs


@dataclass
class RewEndModelConfig:
    """Hyperparameters for RewEndModel.

    `depths`, `channels`, and `attn_depths` must have equal lengths (one
    entry per encoder stage; enforced in RewEndEncoder). `num_actions`
    defaults to None so it can be filled in later from the environment,
    but it must be set before the model is constructed.
    """

    lstm_dim: int
    img_channels: int
    img_size: int
    cond_channels: int
    depths: List[int]
    channels: List[int]
    attn_depths: List[int]
    num_actions: Optional[int] = None


class RewEndModel(nn.Module):
    """Predicts per-step reward class and episode-end class.

    Each (obs, act, next_obs) transition is encoded by a convolutional
    encoder conditioned on the action embedding, then a sequence-level
    LSTM aggregates the per-step features, and a small MLP head emits
    3 reward logits (sign classes {-1, 0, +1}) followed by 2 end logits.
    """

    def __init__(self, cfg: RewEndModelConfig) -> None:
        super().__init__()
        # nn.Embedding would fail with an opaque TypeError on None;
        # fail fast with a clear message instead.
        assert cfg.num_actions is not None, "RewEndModelConfig.num_actions must be set before building RewEndModel"
        self.cfg = cfg
        # Encoder sees obs and next_obs concatenated on the channel axis,
        # hence 2 * img_channels input channels.
        self.encoder = RewEndEncoder(2 * cfg.img_channels, cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)
        self.act_emb = nn.Embedding(cfg.num_actions, cfg.cond_channels)
        # Each of len(depths) - 1 downsampling steps halves the spatial size,
        # so the flattened encoder output per step has this dimensionality.
        input_dim_lstm = cfg.channels[-1] * (cfg.img_size // 2 ** (len(cfg.depths) - 1)) ** 2
        self.lstm = nn.LSTM(input_dim_lstm, cfg.lstm_dim, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(cfg.lstm_dim, cfg.lstm_dim),
            nn.SiLU(),
            nn.Linear(cfg.lstm_dim, 3 + 2, bias=False),  # 3 reward logits + 2 end logits
        )
        init_lstm(self.lstm)

    def predict_rew_end(
        self,
        obs: Tensor,
        act: Tensor,
        next_obs: Tensor,
        hx_cx: Optional[Tuple[Tensor, Tensor]] = None,
    ) -> Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]:
        """Compute reward and end logits for a batch of transition sequences.

        Args:
            obs: observations, shape (b, t, c, h, w).
            act: integer actions, shape (b, t).
            next_obs: next observations, shape (b, t, c, h, w).
            hx_cx: optional LSTM (hidden, cell) state carried across calls.

        Returns:
            Tuple of reward logits (b, t, 3), end logits (b, t, 2), and the
            final LSTM state.
        """
        b, t, c, h, w = obs.shape
        # Fold time into the batch axis for the (stateless) conv encoder.
        obs, act, next_obs = obs.reshape(b * t, c, h, w), act.reshape(b * t), next_obs.reshape(b * t, c, h, w)
        x = self.encoder(torch.cat((obs, next_obs), dim=1), self.act_emb(act))
        x = x.reshape(b, t, -1)  # (b t) e h w -> b t (e h w)
        x, hx_cx = self.lstm(x, hx_cx)
        logits = self.head(x)
        # First 3 logits are reward classes, last 2 are end classes.
        return logits[:, :, :-2], logits[:, :, -2:], hx_cx

    def forward(self, batch: Batch) -> LossAndLogs:
        """Compute summed reward/end cross-entropy loss and logging metrics."""
        # Shift by one step: predict reward/end of transition (obs_t -> obs_{t+1}).
        obs = batch.obs[:, :-1]
        act = batch.act[:, :-1]
        next_obs = batch.obs[:, 1:]
        rew = batch.rew[:, :-1]
        end = batch.end[:, :-1]
        mask = batch.mask_padding[:, :-1]

        # When dead, replace frame (gray padding) by true final obs
        dead = end.bool().any(dim=1)
        if dead.any():
            final_obs = torch.stack([i["final_observation"] for i, d in zip(batch.info, dead) if d]).to(obs.device)
            # end[dead].argmax(dim=1) locates the terminal step in each dead sequence
            # (assumes at most one end flag per sequence — TODO confirm against data pipeline).
            next_obs[dead, end[dead].argmax(dim=1)] = final_obs

        logits_rew, logits_end, _ = self.predict_rew_end(obs, act, next_obs)
        # Boolean mask drops padded steps and flattens (b, t) into one axis.
        logits_rew = logits_rew[mask]
        logits_end = logits_end[mask]
        target_rew = rew[mask].sign().long().add(1)  # clipped to {-1, 0, 1}
        target_end = end[mask]

        loss_rew = F.cross_entropy(logits_rew, target_rew)
        loss_end = F.cross_entropy(logits_end, target_end)
        loss = loss_rew + loss_end

        metrics = {
            "loss_rew": loss_rew.detach(),
            "loss_end": loss_end.detach(),
            "loss_total": loss.detach(),
            "confusion_matrix": {
                "rew": multiclass_confusion_matrix(logits_rew, target_rew, num_classes=3),
                "end": multiclass_confusion_matrix(logits_end, target_end, num_classes=2),
            },
        }

        return loss, metrics


class RewEndEncoder(nn.Module):
    """Convolutional encoder: stem conv, then per-stage ResBlocks with a
    downsample between stages, then a final attention-equipped stage.
    """

    def __init__(
        self,
        in_channels: int,
        cond_channels: int,
        depths: List[int],
        channels: List[int],
        attn_depths: List[int],
    ) -> None:
        super().__init__()
        assert len(depths) == len(channels) == len(attn_depths)
        self.conv_in = Conv3x3(in_channels, channels[0])
        blocks = []
        for i, n in enumerate(depths):
            # First block of a stage transitions from the previous stage's
            # channel count (stage 0 keeps channels[0]).
            c1 = channels[max(0, i - 1)]
            c2 = channels[i]
            blocks.append(
                ResBlocks(
                    list_in_channels=[c1] + [c2] * (n - 1),
                    list_out_channels=[c2] * n,
                    cond_channels=cond_channels,
                    attn=attn_depths[i],
                )
            )
        # Extra final stage at the last width, always with attention.
        blocks.append(
            ResBlocks(
                list_in_channels=[channels[-1]] * 2,
                list_out_channels=[channels[-1]] * 2,
                cond_channels=cond_channels,
                attn=True,
            )
        )
        self.blocks = nn.ModuleList(blocks)
        # Identity before the first stage and before the extra final stage;
        # a Downsample before every other stage.
        self.downsamples = nn.ModuleList([nn.Identity()] + [Downsample(c) for c in channels[:-1]] + [nn.Identity()])

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """Encode x conditioned on cond; returns the final feature map."""
        x = self.conv_in(x)
        for block, down in zip(self.blocks, self.downsamples):
            x = down(x)
            x, _ = block(x, cond)
        return x