Spaces:
Running
Running
| from typing import List, Optional, Dict | |
| import os | |
| import torch | |
| from utils.common.log import logger | |
| import hashlib | |
def get_dataset_cache_path(root_dir: str,
                           classes: List[str], ignore_classes: List[str], idx_map: Optional[Dict[int, int]],
                           cache_dir: str = '/tmp') -> str:
    """Return a deterministic cache-file path for a dataset configuration.

    The file name embeds an MD5 digest derived from ``root_dir``, ``classes``,
    ``ignore_classes`` and ``idx_map``, so the same configuration always maps
    to the same file name.

    Args:
        root_dir: dataset root directory.
        classes: class names; sorted before hashing.
        ignore_classes: class names to ignore; sorted before hashing.
        idx_map: optional index remapping; its keys are sorted before hashing.
            ``None`` is hashed via its string form ``'None'``.
        cache_dir: directory the cache file lives in. Defaults to ``'/tmp'``,
            the previously hard-coded location.

    Returns:
        ``os.path.join(cache_dir, 'zql_data_cache_<md5 hex>.cache')``.
    """
    def _hash(o) -> str:
        # Put containers in a canonical order so that logically equal
        # configurations (e.g. the same classes listed differently) share a key.
        # NOTE(review): sorting makes the key independent of class order —
        # confirm the cached status does not depend on that order.
        if isinstance(o, list):
            o = sorted(o)
        elif isinstance(o, dict):
            o = {k: o[k] for k in sorted(o)}
        elif isinstance(o, set):
            o = sorted(o)
        # MD5 serves only as a fingerprint of str(o) here, not for security.
        return hashlib.md5(str(o).encode('utf-8')).hexdigest()

    cache_key = _hash(
        f'zql_data_{_hash(root_dir)}_{_hash(classes)}_{_hash(ignore_classes)}_{_hash(idx_map)}.cache'
    )
    return os.path.join(cache_dir, f'zql_data_cache_{cache_key}.cache')
def cache_dataset_status(status, cache_file_path, dataset_name):
    """Persist ``status`` to ``cache_file_path`` with ``torch.save``.

    The object is first written to a sibling temporary file and then moved
    into place with ``os.replace``. An interrupted or failed save therefore
    never leaves a truncated file at ``cache_file_path``.

    Args:
        status: object to serialize (anything ``torch.save`` can handle).
        cache_file_path: destination path of the cache file.
        dataset_name: name used only in the log message.

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raises; the temporary file is
        removed (best-effort) before the error propagates.
    """
    logger.info(f'cache dataset status: {dataset_name}')
    # The temp file sits in the same directory as the target, so os.replace
    # is a same-filesystem rename rather than a copy.
    tmp_path = f'{cache_file_path}.{os.getpid()}.tmp'
    try:
        torch.save(status, tmp_path)
        os.replace(tmp_path, cache_file_path)
    except BaseException:
        # Catch BaseException too so that a KeyboardInterrupt mid-save also
        # cleans up the partial temp file before re-raising.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def read_cached_dataset_status(cache_file_path, dataset_name):
    """Load a dataset status object previously written with ``torch.save``.

    Args:
        cache_file_path: path of the cache file to load.
        dataset_name: name used only in the log message.

    Returns:
        The object ``torch.load`` deserializes from ``cache_file_path``.
    """
    log_message = f'read dataset cache: {dataset_name}'
    logger.info(log_message)
    # NOTE(review): torch.load unpickles the file. If the cache lives in a
    # shared, world-writable directory under a predictable name, a planted
    # file could run code on load — confirm this is acceptable. Newer PyTorch
    # releases also changed the default of torch.load's `weights_only`
    # argument; verify the saved status still loads on the target version.
    loaded_status = torch.load(cache_file_path)
    return loaded_status