| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import torch | |
@dataclass
class SegmentId:
    """Identifier of a segment: an episode id plus a ``start``/``stop`` range.

    Declared as a dataclass so instances can be built with
    ``SegmentId(episode_id, start, stop)`` and get a readable ``repr`` and
    field-wise equality.
    """

    # Id of the episode the segment comes from.
    episode_id: int
    # Presumably the step indices delimiting the segment inside the episode;
    # NOTE(review): confirm whether ``stop`` is exclusive or inclusive.
    start: int
    stop: int
@dataclass(eq=False)
class Segment:
    """A slice of an episode: per-step tensors plus the id that locates it.

    Declared as a dataclass so the fields can be passed to the constructor.
    ``eq=False`` keeps identity-based ``==`` and ``hash``. A generated
    field-wise ``__eq__`` would compare tensors with ``==``, and truth-testing
    a multi-element result raises ``RuntimeError``.
    """

    # NOTE(review): the shapes of the per-step tensors below are not visible
    # here; presumably they share a leading time dimension — confirm.
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    # Boolean mask; ``effective_size`` counts its True entries. Presumably
    # True marks real (non-padded) steps — confirm with the producer.
    mask_padding: torch.BoolTensor
    # Locates the segment within its source episode.
    id: SegmentId

    def effective_size(self) -> int:
        """Return the number of True entries in ``mask_padding`` as a Python int."""
        # ``.item()`` turns the 0-dim sum tensor into a plain Python number.
        return self.mask_padding.sum().item()