from __future__ import annotations

from dataclasses import dataclass
from typing import List

import torch

from .segment import SegmentId


@dataclass
class Batch:
    """Bundle of tensors plus a list of segment identifiers.

    Every attribute except ``segment_ids`` is expected to expose the
    ``pin_memory()`` and ``to(device)`` methods (the annotations declare
    them as torch tensors); ``segment_ids`` is carried along unchanged.
    """

    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    mask_padding: torch.BoolTensor
    segment_ids: List[SegmentId]

    def pin_memory(self) -> Batch:
        """Return a new ``Batch`` with ``pin_memory()`` applied to each tensor attribute.

        Returns:
            A new ``Batch``; ``segment_ids`` is the same list object as on
            ``self``, every other attribute is the result of its
            ``pin_memory()`` call.
        """
        pinned = {}
        for name, value in self.__dict__.items():
            if name == 'segment_ids':
                # Not a tensor: forwarded as-is.
                pinned[name] = value
            else:
                pinned[name] = value.pin_memory()
        return Batch(**pinned)

    def to(self, device: torch.device) -> Batch:
        """Return a new ``Batch`` with ``to(device)`` applied to each tensor attribute.

        Args:
            device: Target passed verbatim to each attribute's ``to`` method.

        Returns:
            A new ``Batch``; ``segment_ids`` is the same list object as on
            ``self``, every other attribute is the result of its
            ``to(device)`` call.
        """
        moved = {}
        for name, value in self.__dict__.items():
            if name == 'segment_ids':
                # Not a tensor: forwarded as-is.
                moved[name] = value
            else:
                moved[name] = value.to(device)
        return Batch(**moved)