|
|
import difflib |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def get_layer(l_name, library=torch.nn):
    """Return a layer object handler from ``library`` (e.g. from ``torch.nn``).

    E.g. if ``l_name == "elu"``, returns ``torch.nn.ELU``.

    Args:
        l_name (str): Case-insensitive name of the layer in ``library``
            (e.g. ``"elu"``).
        library (module): Library/module in which to search for the object
            handler named ``l_name``. Defaults to ``torch.nn``.

    Returns:
        layer_handler (object): Handler for the requested layer
            (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches
            ``l_name`` case-insensitively (the message lists the closest
            lowercased names), or if more than one attribute matches
            (the message lists every matching name).
    """
    # Search the requested library (not always torch.nn) so that the name
    # lookup and the final getattr operate on the same namespace.
    all_layers = dir(library)
    target = l_name.lower()
    match = [x for x in all_layers if x.lower() == target]

    if not match:
        # Compare the lowercased query against lowercased candidates so the
        # suggestions are not skewed by the caller's capitalization.
        close_matches = difflib.get_close_matches(
            target, [x.lower() for x in all_layers]
        )
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )

    if len(match) > 1:
        # Ambiguous lookup: report the actual case-insensitive collisions
        # rather than fuzzy suggestions.
        raise NotImplementedError(
            "Multiple matches for layer with name {} found in {}.\n "
            "All matches: {}".format(l_name, str(library), match)
        )

    layer_handler = getattr(library, match[0])
    return layer_handler