Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from typing import Dict, List | |
| def disassemble_dict(d, keep_dim=False): | |
| ''' | |
| Unpack a dictionary into a list of dictionaries. The values should be in the same length. | |
| If not keep dim: {k: [...] * N} -> [{k: [...]}] * N. | |
| If keep dim: {k: [...] * N} -> [{k: [[...]]}] * N. | |
| ''' | |
| Ls = [len(v) for v in d.values()] | |
| assert len(set(Ls)) == 1, 'The lengths of the values should be the same!' | |
| N = Ls[0] | |
| if keep_dim: | |
| return [{k: v[[i]] for k, v in d.items()} for i in range(N)] | |
| else: | |
| return [{k: v[i] for k, v in d.items()} for i in range(N)] | |
| def assemble_dict(d, expand_dim=False, keys=None): | |
| ''' | |
| Pack a list of dictionaries into one dictionary. | |
| If expand dim, perform stack, else, perform concat. | |
| ''' | |
| keys = list(d[0].keys()) if keys is None else keys | |
| if isinstance(d[0][keys[0]], np.ndarray): | |
| if expand_dim: | |
| return {k: np.stack([v[k] for v in d], axis=0) for k in keys} | |
| else: | |
| return {k: np.concatenate([v[k] for v in d], axis=0) for k in keys} | |
| elif isinstance(d[0][keys[0]], torch.Tensor): | |
| if expand_dim: | |
| return {k: torch.stack([v[k] for v in d], dim=0) for k in keys} | |
| else: | |
| return {k: torch.cat([v[k] for v in d], dim=0) for k in keys} | |
| def filter_dict(d:Dict, keys:List, full:bool=False, strict:bool=False): | |
| ''' | |
| Use path-like syntax to filter the embedded dictionary. | |
| The `'*'` string is regarded as a wildcard, and will return the matched keys. | |
| For control flags: | |
| - If `full`, return the full path, otherwise, only return the matched values. | |
| - If `strict`, raise error if the key is not found, otherwise, simply ignore. | |
| Eg. | |
| - `x = {'fruit': {'yellow': 'banana', 'red': 'apple'}, 'recycle': {'yellow': 'trash', 'blue': 'recyclable'}}` | |
| - `filter_dict(x, ['*', 'yellow'])` -> `{'fruit': 'banana', 'recycle': 'trash'}` | |
| - `filter_dict(x, ['*', 'yellow'], full=True)` -> `{'fruit': {'yellow': 'banana'}, 'recycle': {'yellow': 'trash'}}` | |
| - `filter_dict(x, ['*', 'blue'])` -> `{'recycle': 'recyclable'}` | |
| - `filter_dict(x, ['*', 'blue'], strict=True)` -> `KeyError: 'blue'` | |
| ''' | |
| ret = {} | |
| if keys: | |
| cur_key, rest_keys = keys[0], keys[1:] | |
| if cur_key == '*': | |
| for match in d.keys(): | |
| try: | |
| res = filter_dict(d[match], rest_keys, full=full, strict=strict) | |
| if res: | |
| ret[match] = res | |
| except Exception as e: | |
| if strict: | |
| raise e | |
| else: | |
| try: | |
| res = filter_dict(d[cur_key], rest_keys, full=full, strict=strict) | |
| if res: | |
| ret = { cur_key : res } if full else res | |
| except Exception as e: | |
| if strict: | |
| raise e | |
| else: | |
| ret = d | |
| return ret |