Spaces:
Sleeping
Sleeping
File size: 4,745 Bytes
c64c726 17fd5e3 c64c726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torcheval.metrics.functional import multiclass_confusion_matrix
from .blocks import Conv3x3, Downsample, ResBlocks
from ..data import Batch
from ..utils import init_lstm, LossAndLogs
@dataclass
class RewEndModelConfig:
    """Hyper-parameters for ``RewEndModel`` and the ``RewEndEncoder`` it builds."""

    # Hidden size of the LSTM and width of the first linear layer of the head.
    lstm_dim: int
    # Channels of a single frame; the encoder input has 2x this (current and next frame stacked).
    img_channels: int
    # Frame height/width; used to size the LSTM input (that formula squares it, i.e. assumes square frames).
    img_size: int
    # Width of the action embedding that conditions every encoder ResBlocks stage.
    cond_channels: int
    # Depth of each encoder stage (length of the in/out channel lists handed to ResBlocks).
    depths: List[int]
    # Output channels of each encoder stage; the encoder asserts it has the same length as `depths`.
    channels: List[int]
    # Per-stage value passed to ResBlocks as `attn`; the encoder asserts it has the same length as `depths`.
    attn_depths: List[int]
    # Size of the discrete action space. Defaults to None, but RewEndModel passes it straight
    # to nn.Embedding, so it must be set to an int before the model is constructed.
    num_actions: Optional[int] = None
class RewEndModel(nn.Module):
    """Predict reward sign and episode termination for consecutive frame pairs.

    For each step ``t`` the frames ``obs_t`` and ``obs_{t+1}`` are concatenated on
    the channel axis and encoded by ``RewEndEncoder``, which is conditioned on an
    embedding of ``act_t``. The result is flattened and run through an LSTM over
    time. An MLP head then emits, per step, 3 reward logits (classes for reward
    sign -1 / 0 / +1) and 2 termination logits.
    """

    def __init__(self, cfg: RewEndModelConfig) -> None:
        super().__init__()
        self.cfg = cfg
        # Current and next frame are stacked on the channel axis, hence 2x img_channels.
        self.encoder = RewEndEncoder(2 * cfg.img_channels, cfg.cond_channels, cfg.depths, cfg.channels, cfg.attn_depths)
        # NOTE: cfg.num_actions defaults to None in the config but must be an int here.
        self.act_emb = nn.Embedding(cfg.num_actions, cfg.cond_channels)
        # Flattened encoder output size. This assumes each of the len(depths) - 1
        # Downsample modules halves H and W, and that frames are square.
        input_dim_lstm = cfg.channels[-1] * (cfg.img_size // 2 ** (len(cfg.depths) - 1)) ** 2
        self.lstm = nn.LSTM(input_dim_lstm, cfg.lstm_dim, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(cfg.lstm_dim, cfg.lstm_dim),
            nn.SiLU(),
            nn.Linear(cfg.lstm_dim, 3 + 2, bias=False),  # 3 reward classes + 2 end classes
        )
        init_lstm(self.lstm)

    def predict_rew_end(
        self,
        obs: Tensor,
        act: Tensor,
        next_obs: Tensor,
        hx_cx: Optional[Tuple[Tensor, Tensor]] = None,
    ) -> Tuple[Tensor, Tensor, Tuple[Tensor, Tensor]]:
        """Run encoder, LSTM and head over a (b, t) grid of transitions.

        Args:
            obs: current frames, shape (b, t, c, h, w).
            act: action indices for nn.Embedding, b * t elements (reshaped to (b * t,)).
            next_obs: next frames, same shape as ``obs``.
            hx_cx: optional LSTM ``(h, c)`` state to continue from.

        Returns:
            ``(logits_rew, logits_end, hx_cx)`` with shapes (b, t, 3) and (b, t, 2),
            plus the updated LSTM state.
        """
        b, t, c, h, w = obs.shape
        obs, act, next_obs = obs.reshape(b * t, c, h, w), act.reshape(b * t), next_obs.reshape(b * t, c, h, w)
        x = self.encoder(torch.cat((obs, next_obs), dim=1), self.act_emb(act))
        x = x.reshape(b, t, -1)  # (b t) e h w -> b t (e h w)
        x, hx_cx = self.lstm(x, hx_cx)
        logits = self.head(x)
        return logits[:, :, :-2], logits[:, :, -2:], hx_cx

    def forward(self, batch: Batch) -> LossAndLogs:
        """Compute reward and termination cross-entropy losses on a batch of sequences.

        Transition ``t`` uses ``batch.obs[:, t]`` -> ``batch.obs[:, t + 1]`` with
        ``batch.act[:, t]``. Steps where ``batch.mask_padding`` is False are excluded
        from the losses. The caller's ``batch`` is not modified.

        Returns:
            ``(loss, metrics)``. ``loss`` is ``loss_rew + loss_end``. ``metrics`` holds
            the detached loss terms and a confusion matrix for each head.
        """
        obs = batch.obs[:, :-1]
        act = batch.act[:, :-1]
        next_obs = batch.obs[:, 1:]
        rew = batch.rew[:, :-1]
        end = batch.end[:, :-1]
        mask = batch.mask_padding[:, :-1]

        # When dead, replace frame (gray padding) by true final obs.
        dead = end.bool().any(dim=1)
        if dead.any():
            final_obs = torch.stack([i["final_observation"] for i, d in zip(batch.info, dead.tolist()) if d])
            # `next_obs` is a view into `batch.obs` (and overlaps `obs`). Clone it so the
            # write below does not silently mutate the caller's batch.
            next_obs = next_obs.clone()
            # argmax returns the first index of the maximal `end` value, i.e. the first
            # terminal step of each dead sequence when `end` holds 0/1 flags.
            next_obs[dead, end[dead].argmax(dim=1)] = final_obs.to(device=next_obs.device, dtype=next_obs.dtype)

        logits_rew, logits_end, _ = self.predict_rew_end(obs, act, next_obs)
        logits_rew = logits_rew[mask]
        logits_end = logits_end[mask]

        target_rew = rew[mask].sign().long().add(1)  # clipped to {-1, 0, 1}, shifted to classes {0, 1, 2}
        target_end = end[mask].long()  # cross_entropy class-index targets must be int64

        loss_rew = F.cross_entropy(logits_rew, target_rew)
        loss_end = F.cross_entropy(logits_end, target_end)
        loss = loss_rew + loss_end

        metrics = {
            "loss_rew": loss_rew.detach(),
            "loss_end": loss_end.detach(),
            "loss_total": loss.detach(),
            "confusion_matrix": {
                "rew": multiclass_confusion_matrix(logits_rew, target_rew, num_classes=3),
                "end": multiclass_confusion_matrix(logits_end, target_end, num_classes=2),
            },
        }
        return loss, metrics
class RewEndEncoder(nn.Module):
    """Multi-stage conditional residual encoder used by ``RewEndModel``.

    Pipeline: ``conv_in`` -> stage 0 -> (Downsample, stage 1) -> ... ->
    (Downsample, stage L-1) -> a final two-block attention stage at the last
    resolution. Every ResBlocks stage receives the same conditioning tensor.
    """

    def __init__(
        self,
        in_channels: int,
        cond_channels: int,
        depths: List[int],
        channels: List[int],
        attn_depths: List[int],
    ) -> None:
        super().__init__()
        assert len(depths) == len(channels) == len(attn_depths)
        self.conv_in = Conv3x3(in_channels, channels[0])

        # Construction order (conv_in, stages, final stage, downsamples) is kept
        # fixed because it determines random parameter init under a fixed seed.
        stages = []
        for idx, depth in enumerate(depths):
            out_ch = channels[idx]
            # Only the first block of a stage adapts from the previous stage's width.
            in_ch = channels[idx - 1] if idx > 0 else out_ch
            stages.append(
                ResBlocks(
                    list_in_channels=[in_ch, *([out_ch] * (depth - 1))],
                    list_out_channels=[out_ch] * depth,
                    cond_channels=cond_channels,
                    attn=attn_depths[idx],
                )
            )

        width = channels[-1]
        stages.append(
            ResBlocks(
                list_in_channels=[width, width],
                list_out_channels=[width, width],
                cond_channels=cond_channels,
                attn=True,
            )
        )
        self.blocks = nn.ModuleList(stages)

        # One resampler per stage. There is none before stage 0 or before the final
        # stage, and a Downsample sits between each pair of configured stages.
        downs = [nn.Identity()]
        downs.extend(Downsample(ch) for ch in channels[:-1])
        downs.append(nn.Identity())
        self.downsamples = nn.ModuleList(downs)

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        """Encode ``x`` with the stem conv, then each (resample, stage) pair conditioned on ``cond``."""
        h = self.conv_in(x)
        for stage, resample in zip(self.blocks, self.downsamples):
            h, _ = stage(resample(h), cond)
        return h
|