Spaces:
Sleeping
Sleeping
| """ | |
| Define collate functions for new data types here | |
| """ | |
| from functools import partial | |
| from itertools import chain | |
| import dgl | |
| import torch | |
| from torch.nn.utils.rnn import pad_sequence | |
| from torch.utils.data._utils.collate import default_collate_fn_map, collate_tensor_fn, collate | |
| import torch_geometric | |
def collate_pyg_fn(batch, collate_fn_map=None):
    """
    Merge a list of PyTorch Geometric data objects into one ``torch_geometric.data.Batch``.

    ``collate_fn_map`` is not used; it is part of the signature so this function can be
    registered as an entry of a ``collate_fn_map`` passed to ``collate``.
    """
    batch_cls = torch_geometric.data.Batch
    merged = batch_cls.from_data_list(batch)
    return merged
def collate_dgl_fn(batch, collate_fn_map=None):
    """
    Merge a list of DGL graphs into a single batched graph using ``dgl.batch``.

    ``collate_fn_map`` is not used; it is part of the signature so this function can be
    registered as an entry of a ``collate_fn_map`` passed to ``collate``.
    """
    graphs = batch
    batched_graph = dgl.batch(graphs)
    return batched_graph
def pad_collate_tensor_fn(batch, padding_value=0.0, collate_fn_map=None):
    """
    Collate a list of tensors into one padded batch tensor plus the original leading lengths.

    Similar to ``pad_packed_sequence(pack_sequence(batch, enforce_sorted=False), batch_first=True)``,
    but additionally supports tensors whose non-leading dimensions differ too, e.g. square
    tensors of size ``(L x L x ...)``.

    Strategy (chosen from explicit shape comparison):
      * all shapes identical            -> plain stacking, no padding;
      * only the leading dim differs    -> ``pad_sequence`` along dim 0;
      * any trailing dim differs as well -> every dim is padded to the batch-wide maximum.

    NOTE(review): 0-d (scalar) tensors are not supported because ``size(0)`` is undefined for
    them; when this function is registered for ``torch.Tensor`` (see ``collate_fn``) that also
    applies to scalar label tensors — confirm whether this is intended.

    :param batch: list of tensors to collate; assumed to all have the same number of dimensions.
    :param padding_value: value written into padded positions.
    :param collate_fn_map: unused; accepted so the function fits the ``collate_fn_map`` calling
        convention of ``torch.utils.data._utils.collate.collate``.
    :return: ``(padded_batch, lengths)`` where ``padded_batch`` has shape ``(B, *max_sizes)`` and
        ``lengths`` is a 1-D tensor holding ``tensor.size(0)`` of every input tensor.
    """
    lengths = [tensor.size(0) for tensor in batch]
    shapes = [tensor.shape for tensor in batch]

    if all(shape == shapes[0] for shape in shapes[1:]):
        # Identical shapes: nothing to pad.
        padded = collate_tensor_fn(batch)
    elif all(shape[1:] == shapes[0][1:] for shape in shapes[1:]):
        # Only the leading (sequence) dimension differs: pad along it.
        padded = pad_sequence(batch, batch_first=True, padding_value=padding_value)
    else:
        # Trailing dimensions differ as well (e.g. L x L tensors): pad every dimension up to the
        # largest size in the batch. Deciding this from the shapes, rather than waiting for
        # pad_sequence to raise, matters because pad_sequence copies each tensor into a buffer
        # shaped after the first one and that copy broadcasts size-1 trailing dims silently.
        max_sizes = [max(tensor.size(dim) for tensor in batch) for dim in range(batch[0].dim())]
        padded = collate_tensor_fn([
            torch.nn.functional.pad(
                tensor,
                # F.pad expects (left, right) pairs starting from the LAST dimension.
                tuple(chain.from_iterable(
                    (0, max_sizes[dim] - tensor.size(dim)) for dim in reversed(range(tensor.dim()))
                )),
                mode='constant',
                value=padding_value,
            )
            for tensor in batch
        ])

    # Return the padded batch tensor and the original leading lengths
    return padded, torch.as_tensor(lengths)
# Collation map used by collate_fn: PyTorch's default entries extended with
# graph-batching functions for PyTorch Geometric data objects and DGL graphs.
COLLATE_FN_MAP = {
    **default_collate_fn_map,
    torch_geometric.data.data.BaseData: collate_pyg_fn,
    dgl.DGLGraph: collate_dgl_fn,
}
def collate_fn(batch, automatic_padding=False, padding_value=0):
    """
    Collate a list of samples into a batch, dispatching on element type via ``COLLATE_FN_MAP``.

    :param batch: list of samples to collate.
    :param automatic_padding: if True, tensors are collated with ``pad_collate_tensor_fn``, so each
        tensor field becomes a ``(padded_batch, lengths)`` tuple instead of a stacked tensor.
    :param padding_value: fill value for padded positions; only used when ``automatic_padding`` is True.
    :return: the collated batch.
    """
    collate_fn_map = COLLATE_FN_MAP
    if automatic_padding:
        # Use a per-call copy instead of updating the shared module-level map: mutating it would make
        # every later call pad tensors too, even ones made with automatic_padding=False.
        collate_fn_map = {
            **COLLATE_FN_MAP,
            torch.Tensor: partial(pad_collate_tensor_fn, padding_value=padding_value),
        }
    return collate(batch, collate_fn_map=collate_fn_map)
| # class VariableLengthSequence(torch.Tensor): | |
| # """ | |
| # A custom PyTorch Tensor class that is similar to PackedSequence, except it can be directly used as a batch tensor, | |
| # and it has an attribute called lengths, which signifies the length of each original sequence in the batch. | |
| # """ | |
| # | |
| # def __new__(cls, data, lengths): | |
| # """ | |
| # Creates a new VariableLengthSequence object from the given data and lengths. | |
| # Args: | |
| # data (torch.Tensor): The batch collated tensor of shape (batch_size, max_length, *). | |
| # lengths (torch.Tensor): The lengths of each original sequence in the batch of shape (batch_size,). | |
| # Returns: | |
| # VariableLengthSequence: A new VariableLengthSequence object. | |
| # """ | |
| # # Check the validity of the inputs | |
| # assert isinstance(data, torch.Tensor), "data must be a torch.Tensor" | |
| # assert isinstance(lengths, torch.Tensor), "lengths must be a torch.Tensor" | |
| # assert data.dim() >= 2, "data must have at least two dimensions" | |
| # assert lengths.dim() == 1, "lengths must have one dimension" | |
| # assert data.size(0) == lengths.size(0), "data and lengths must have the same batch size" | |
| # assert lengths.min() > 0, "lengths must be positive" | |
| # assert lengths.max() <= data.size(1), "lengths must not exceed the max length of data" | |
| # | |
| # # Create a new tensor object from data | |
| # obj = super().__new__(cls, data) | |
| # | |
| # # Set the lengths attribute | |
| # obj.lengths = lengths | |
| # | |
| # return obj | |
| # class VariableLengthSequence(torch.Tensor): | |
| # _lengths = torch.Tensor() | |
| # | |
| # def __new__(cls, data, lengths, *args, **kwargs): | |
| # self = super().__new__(cls, data, *args, **kwargs) | |
| # self.lengths = lengths | |
| # return self | |
| # | |
| # def clone(self, *args, **kwargs): | |
| # return VariableLengthSequence(super().clone(*args, **kwargs), self.lengths.clone()) | |
| # | |
| # def new_empty(self, *size): | |
| # return VariableLengthSequence(super().new_empty(*size), self.lengths) | |
| # | |
| # def to(self, *args, **kwargs): | |
| # return VariableLengthSequence(super().to(*args, **kwargs), self.lengths.to(*args, **kwargs)) | |
| # | |
| # def __format__(self, format_spec): | |
| # # Convert self to a string or a number here, depending on what you need | |
| # return self.item().__format__(format_spec) | |
| # | |
| # @property | |
| # def lengths(self): | |
| # return self._lengths | |
| # | |
| # @lengths.setter | |
| # def lengths(self, lengths): | |
| # self._lengths = lengths | |
| # | |
| # def cpu(self, *args, **kwargs): | |
| # return VariableLengthSequence(super().cpu(*args, **kwargs), self.lengths.cpu(*args, **kwargs)) | |
| # | |
| # def cuda(self, *args, **kwargs): | |
| # return VariableLengthSequence(super().cuda(*args, **kwargs), self.lengths.cuda(*args, **kwargs)) | |
| # | |
| # def pin_memory(self): | |
| # return VariableLengthSequence(super().pin_memory(), self.lengths.pin_memory()) | |
| # | |
| # def share_memory_(self): | |
| # super().share_memory_() | |
| # self.lengths.share_memory_() | |
| # return self | |
| # | |
| # def detach_(self, *args, **kwargs): | |
| # super().detach_(*args, **kwargs) | |
| # self.lengths.detach_(*args, **kwargs) | |
| # return self | |
| # | |
| # def detach(self, *args, **kwargs): | |
| # return VariableLengthSequence(super().detach(*args, **kwargs), self.lengths.detach(*args, **kwargs)) | |
| # | |
| # def record_stream(self, *args, **kwargs): | |
| # super().record_stream(*args, **kwargs) | |
| # self.lengths.record_stream(*args, **kwargs) | |
| # return self | |
| # @classmethod | |
| # def __torch_function__(cls, func, types, args=(), kwargs=None): | |
| # return super().__torch_function__(func, types, args, kwargs) \ | |
| # if cls.lengths is not None else torch.Tensor.__torch_function__(func, types, args, kwargs) | |