"""Utility functions for Chiluka.""" import torch from munch import Munch def length_to_mask(lengths): """Convert lengths to attention mask.""" mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) mask = torch.gt(mask + 1, lengths.unsqueeze(1)) return mask def recursive_munch(d): """Recursively convert dict to Munch for dot notation access.""" if isinstance(d, dict): return Munch((k, recursive_munch(v)) for k, v in d.items()) elif isinstance(d, list): return [recursive_munch(v) for v in d] else: return d