HRA / nlu /DeBERTa /training /_utils.py
nvan13's picture
Add files using upload-large-folder tool
ab0f6ec verified
import torch
from collections.abc import Sequence, Mapping
def batch_apply(batch, fn):
    """Recursively apply ``fn`` to every tensor in a (possibly nested) batch.

    Args:
        batch: A ``torch.Tensor``, a ``Sequence`` of batches, or a ``Mapping``
            from keys to batches, nested to any depth.
        fn: Callable applied to each tensor leaf; its return value replaces
            that leaf in the output.

    Returns:
        ``fn(batch)`` when ``batch`` is a tensor. Otherwise a new structure
        with the same nesting, where every sequence (tuples included) is
        rebuilt as a ``list`` and every mapping as a ``dict``.

    Raises:
        NotImplementedError: If a leaf is not a tensor, sequence or mapping.
            ``str`` leaves are rejected here too: a string is a ``Sequence``
            whose items are strings, so recursing into it would never end.
    """
    if isinstance(batch, torch.Tensor):
        return fn(batch)
    # Exclude str: iterating a string yields 1-char strings, which would
    # otherwise recurse until RecursionError instead of reporting the type.
    if isinstance(batch, Sequence) and not isinstance(batch, str):
        return [batch_apply(item, fn) for item in batch]
    if isinstance(batch, Mapping):
        return {key: batch_apply(value, fn) for key, value in batch.items()}
    raise NotImplementedError(f'Type of {type(batch)} are not supported in batch_apply')
def batch_to(batch, device):
    """Move every tensor in a (possibly nested) batch to ``device``.

    Args:
        batch: Any structure accepted by ``batch_apply``.
        device: Target passed unchanged to ``Tensor.to``.

    Returns:
        The structure produced by ``batch_apply``, with each tensor leaf
        replaced by ``leaf.to(device)``.
    """
    def _move(tensor):
        # Bound to the enclosing ``device`` for every leaf.
        return tensor.to(device)

    return batch_apply(batch, _move)