Spaces:
Runtime error
Runtime error
| from typing import List, Set | |
| import torch | |
def sorted_list(s: Set[str]) -> List[str]:
    """Return the distinct elements of *s* as an ascending sorted list.

    ``set(s)`` removes duplicates, so any iterable of strings also works.
    ``sorted`` already returns a new list, so no intermediate ``list()`` is needed.
    """
    return sorted(set(s))
def device():
    """Pick the torch device to run on.

    Returns ``torch.device("cuda")`` when CUDA is available,
    otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def nested_to_device(s):
    """Move a tensor, or every value of a mapping, onto the preferred device.

    Args:
        s: Either a ``torch.Tensor`` or a dict-like object with ``.items()``.
            Every value must expose ``.to(device)``, e.g. tensors.

    Returns:
        The moved tensor, or a new ``dict`` with the same keys and moved
        values. The input mapping itself is not modified.
    """
    # Resolve the target device once rather than once per value:
    # device() re-checks CUDA availability every time it is called.
    target = device()
    if isinstance(s, torch.Tensor):
        return s.to(target)
    return {k: v.to(target) for k, v in s.items()}
def nested_apply(h, s):
    """Apply the unary function *h* to every string leaf of a nested structure.

    Strings are leaves and are passed straight to *h*. Any other value is
    iterated, and each element is processed recursively. The rebuilt container
    is a ``tuple`` when *s* is a tuple, a ``set`` when *s* is a set, and a
    ``list`` otherwise. A frozenset therefore comes back as a list.
    A non-string leaf that is not iterable raises ``TypeError`` during
    iteration.
    """
    if isinstance(s, str):
        return h(s)
    children = list(map(lambda child: nested_apply(h, child), s))
    # Rebuild tuples and sets with their own type; everything else becomes a list.
    for container in (tuple, set):
        if isinstance(s, container):
            return container(children)
    return children