File size: 612 Bytes
f28049f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
"""Utility functions for Chiluka."""
import torch
from munch import Munch
def length_to_mask(lengths, max_len=None):
    """Build a padding mask from a batch of sequence lengths.

    Entry ``[i, j]`` is ``True`` when ``j + 1 > lengths[i]`` (i.e. ``j >=
    lengths[i]`` for integer lengths). The mask therefore flags the positions
    past each sequence's end, and is ``False`` for the valid positions.

    Args:
        lengths: Tensor of per-sequence lengths, expected to be 1-D and hold
            non-negative integers. The position index is created on the same
            device and with the same dtype as ``lengths``.
        max_len: Optional mask width. Defaults to ``lengths.max()``. Pass a
            larger value to pad every row to a fixed length.

    Returns:
        Bool tensor of shape ``(len(lengths), max_len)``. An empty ``lengths``
        yields a mask with zero rows instead of raising.
    """
    if max_len is None:
        # ``Tensor.max()`` raises on an empty tensor; use a zero-width mask.
        max_len = lengths.max() if lengths.numel() > 0 else 0
    # Create the index row directly on ``lengths``' device rather than on the
    # CPU followed by a copy (which is what ``type_as`` used to trigger).
    positions = torch.arange(max_len, device=lengths.device, dtype=lengths.dtype)
    # (1, T) broadcast against (B, 1) gives (B, T); no explicit expand needed.
    return torch.gt(positions.unsqueeze(0) + 1, lengths.unsqueeze(1))
def recursive_munch(d):
    """Return *d* with every nested ``dict`` turned into a ``Munch``.

    Lists are rebuilt element by element so that dicts inside them are also
    converted. Any other value, including tuples, is returned as-is.

    Args:
        d: Arbitrary value, typically a nested mapping.

    Returns:
        A ``Munch`` for dict input, a new ``list`` for list input, or ``d``
        itself for anything else.
    """
    if isinstance(d, list):
        return list(map(recursive_munch, d))
    if not isinstance(d, dict):
        return d
    return Munch({key: recursive_munch(value) for key, value in d.items()})
|