Spaces:
Sleeping
Sleeping
File size: 604 Bytes
c64c726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Union
import torch
@dataclass
class SegmentId:
    """Identify a segment of an episode by episode id and start/stop positions.

    This is a plain record: the three fields are stored as given and nothing
    is validated on construction.

    Attributes:
        episode_id: Identifier of the source episode; either an int or a str.
        start: Start position of the segment within the episode.
        stop: Stop position of the segment within the episode.
    """

    episode_id: Union[int, str]
    start: int
    # NOTE(review): nothing here says whether ``stop`` is exclusive
    # (slice-style) or inclusive — confirm against the code that builds ids.
    stop: int
@dataclass
class Segment:
    """Bundle the tensors of one episode segment with its metadata and id.

    All fields are required. The only behavior beyond storage is the
    ``effective_size`` property.
    """

    # Per-step tensors. The meanings below are inferred from the field names
    # only; shapes and dtypes are not enforced here — confirm with the producer.
    obs: torch.FloatTensor  # presumably observations
    act: torch.LongTensor  # presumably actions
    rew: torch.FloatTensor  # presumably rewards
    end: torch.ByteTensor  # presumably episode-termination flags
    trunc: torch.ByteTensor  # presumably truncation flags
    mask_padding: torch.BoolTensor  # counted by ``effective_size``
    states: torch.FloatTensor
    ego_state: torch.FloatTensor
    # Free-form extra data keyed by string.
    info: Dict[str, Any]
    # NOTE: field is named ``id`` (same as the builtin); kept as-is because
    # it is part of the public constructor / attribute API.
    id: SegmentId

    @property
    def effective_size(self) -> int:
        """Return the number of ``True`` entries in ``mask_padding``.

        Presumably this is the count of real (non-padded) steps in the
        segment — confirm against whatever builds the mask.
        """
        n_true = self.mask_padding.sum()  # 0-dim tensor
        return n_true.item()  # unwrap to a plain Python scalar
|