Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # coding=utf-8 | |
| import torch | |
| import torch.nn.functional as F | |
class Batch:
    """Utility namespace for collating dicts of tensors into batches.

    A "batch" is a dict mapping field names to stacked tensors (or to
    tuples/lists of tensors).  All methods are stateless helpers, so they
    are exposed as staticmethods and typically invoked as ``Batch.build(...)``.
    """

    @staticmethod
    def build(data):
        """Collate a list of example dicts into a single batch dict.

        Args:
            data: non-empty list of dicts that all share the same keys;
                  values are tensors or tuples of tensors.

        Returns:
            dict mapping each field to a padded+stacked tensor, or to a
            tuple of such tensors when the field holds a tuple.
        """
        fields = list(data[0].keys())
        transposed = {}
        for field in fields:
            if isinstance(data[0][field], tuple):
                # Transpose the tuple: stack the i-th component across all examples.
                transposed[field] = tuple(
                    Batch._stack(field, [example[field][i] for example in data])
                    for i in range(len(data[0][field]))
                )
            else:
                transposed[field] = Batch._stack(field, [example[field] for example in data])
        return transposed

    @staticmethod
    def _stack(field: str, examples):
        """Right-pad `examples` to a common shape and stack along a new batch dim.

        The "anchored_labels" field is ragged and is passed through unstacked.
        """
        if field == "anchored_labels":
            return examples
        dim = examples[0].dim()
        if dim == 0:
            # Scalars need no padding.
            return torch.stack(examples)
        # Per-dimension maximum size across the batch.
        lengths = [max(example.size(i) for example in examples) for i in range(dim)]
        if any(length == 0 for length in lengths):
            # Zero-element result: only the shape matters, contents are never read.
            # NOTE(review): dtype is forced to int64 here regardless of the
            # examples' dtype — confirm downstream code expects that.
            return torch.LongTensor(len(examples), *lengths)
        examples = [F.pad(example, Batch._pad_size(example, lengths)) for example in examples]
        return torch.stack(examples)

    @staticmethod
    def _pad_size(example, total_size):
        """Build the `F.pad` pad argument that grows `example` to `total_size`.

        `F.pad` expects (left, right) pairs starting from the LAST dimension,
        hence the reversed iteration; we only ever pad on the right.
        """
        return [p for i, l in enumerate(total_size[::-1]) for p in (0, l - example.size(-1 - i))]

    @staticmethod
    def index_select(batch, indices):
        """Select rows `indices` along dim 0 from every field of `batch`."""
        filtered_batch = {}
        for key, examples in batch.items():
            if isinstance(examples, (list, tuple)):
                filtered_batch[key] = [example.index_select(0, indices) for example in examples]
            else:
                filtered_batch[key] = examples.index_select(0, indices)
        return filtered_batch

    @staticmethod
    def to_str(batch):
        """Return a short multi-line summary of `batch`, one field per line."""
        string = "\n".join([f"\t{name}: {Batch._short_str(item)}" for name, item in batch.items()])
        return string

    @staticmethod
    def to(batch, device):
        """Return a copy of `batch` with every tensor moved to `device`."""
        converted = {}
        for field in batch.keys():
            converted[field] = Batch._to(batch[field], device)
        return converted

    @staticmethod
    def _short_str(tensor):
        """Compact one-line description of a tensor (or nested tuple/list)."""
        # unwrap variable to tensor
        if not torch.is_tensor(tensor):
            # (1) unpack variable
            if hasattr(tensor, "data"):
                tensor = getattr(tensor, "data")
            # (2) handle include_lengths
            elif isinstance(tensor, (tuple, list)):
                return str(tuple(Batch._short_str(t) for t in tensor))
            # (3) fallback to default str
            else:
                return str(tensor)
        # copied from torch _tensor_str
        size_str = "x".join(str(size) for size in tensor.size())
        device_str = "" if not tensor.is_cuda else " (GPU {})".format(tensor.get_device())
        strt = "[{} of size {}{}]".format(torch.typename(tensor), size_str, device_str)
        return strt

    @staticmethod
    def _to(tensor, device):
        """Recursively move a tensor / tuple / list of tensors to `device`.

        Raises:
            Exception: for any value that is not a tensor, tuple, or list.
        """
        if not torch.is_tensor(tensor):
            if isinstance(tensor, tuple):
                return tuple(Batch._to(t, device) for t in tensor)
            elif isinstance(tensor, list):
                return [Batch._to(t, device) for t in tensor]
            else:
                raise Exception(f"unsupported type of {tensor} to be casted to cuda")
        return tensor.to(device, non_blocking=True)