#!/usr/bin/env python3
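"""Keep dataset config filenames in sync with their prompt hashes.

Dataset config files are named ``{dataset}_{mode}_{hash}.py``, where ``mode``
is one of ``gen``, ``ppl``, ``ll`` or ``mixed`` and ``hash`` is the first 6
hex digits of the SHA-256 hash of the dataset's prompt configuration. This
script recomputes that hash for every config under ``--root_folder`` (or for
the files passed on the command line) and renames any file whose suffix has
gone stale.

Usage (script path illustrative):
    python tools/update_dataset_suffix.py --root_folder configs/datasets
"""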
import argparse
import glob
import hashlib
import json
import os
import re
from multiprocessing import Pool
from typing import List, Union

from mmengine.config import Config, ConfigDict

# from opencompass.utils import get_prompt_hash


# copied from opencompass.utils.get_prompt_hash, for easy use in ci
def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Get the hash of the prompt configuration.

    Args:
        dataset_cfg (ConfigDict or list[ConfigDict]): The dataset
            configuration.

    Returns:
        str: The hash of the prompt configuration.
    """
    if isinstance(dataset_cfg, list):
        if len(dataset_cfg) == 1:
            dataset_cfg = dataset_cfg[0]
        else:
            hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
            hash_object = hashlib.sha256(hashes.encode())
            return hash_object.hexdigest()
    # for custom datasets
    if 'infer_cfg' not in dataset_cfg:
        dataset_cfg.pop('abbr', '')
        dataset_cfg.pop('path', '')
        d_json = json.dumps(dataset_cfg.to_dict(), sort_keys=True)
        hash_object = hashlib.sha256(d_json.encode())
        return hash_object.hexdigest()
    # for regular datasets
    if 'reader_cfg' in dataset_cfg.infer_cfg:
        # new config
        reader_cfg = dict(type='DatasetReader',
                          input_columns=dataset_cfg.reader_cfg.input_columns,
                          output_column=dataset_cfg.reader_cfg.output_column)
        dataset_cfg.infer_cfg.reader = reader_cfg
        if 'train_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'index_split'] = dataset_cfg.infer_cfg['reader_cfg'][
                    'train_split']
        if 'test_split' in dataset_cfg.infer_cfg.reader_cfg:
            dataset_cfg.infer_cfg.retriever[
                'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
        for k, v in dataset_cfg.infer_cfg.items():
            dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
    # A compromise for the hash consistency
    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
    d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
    hash_object = hashlib.sha256(d_json.encode())
    return hash_object.hexdigest()
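
# Note (inferred from the logic above): configs that differ only in
# `abbr`/`path` (custom datasets) or only in the module path of their `type`
# fields hash identically, so such cosmetic changes do not force a rename.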


# Assuming get_hash is a function that computes the hash of a file
# from get_hash import get_hash
def get_hash(path):
    cfg = Config.fromfile(path)
    for k in cfg.keys():
        if k.endswith('_datasets'):
            return get_prompt_hash(cfg[k])[:6]
    print(f'Could not find *_datasets in {path}')
    return None
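
# Example (illustrative): a config file defining `mmlu_datasets = [...]`
# would yield a 6-character prefix such as 'a484b3'.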


def check_and_rename(filepath):
    base_name = os.path.basename(filepath)
    match = re.match(r'(.*)_(gen|ppl|ll|mixed)_(.*).py', base_name)
    if match:
        dataset, mode, old_hash = match.groups()
        try:
            new_hash = get_hash(filepath)
        except Exception:
            print(f'Failed to get hash for {filepath}')
            raise ModuleNotFoundError
        if not new_hash:
            return None, None
        if old_hash != new_hash:
            new_name = f'{dataset}_{mode}_{new_hash}.py'
            new_file = os.path.join(os.path.dirname(filepath), new_name)
            print(f'Rename {filepath} to {new_file}')
            return filepath, new_file
    return None, None
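
# Example (illustrative hashes): for
# 'configs/datasets/gsm8k/gsm8k_gen_1d7fe4.py', a recomputed hash of '17d0dc'
# yields the pair ('.../gsm8k_gen_1d7fe4.py', '.../gsm8k_gen_17d0dc.py') for
# os.rename below; unchanged or unmatched files yield (None, None).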


# def update_imports(data):
#     python_file, name_pairs = data
#     for filepath, new_file in name_pairs:
#         old_name = os.path.basename(filepath)[:-3]
#         new_name = os.path.basename(new_file)[:-3]
#         if not os.path.exists(python_file):
#             return
#         with open(python_file, 'r') as file:
#             filedata = file.read()
#         # Replace the old name with new name
#         new_data = filedata.replace(old_name, new_name)
#         if filedata != new_data:
#             with open(python_file, 'w') as file:
#                 file.write(new_data)
#             # print(f"Updated imports in {python_file}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('python_files', nargs='*')
    # Could be opencompass/configs/datasets and configs/datasets
    parser.add_argument('--root_folder', default='configs/datasets')
    args = parser.parse_args()

    root_folder = args.root_folder
    if args.python_files:
        python_files = [
            i for i in args.python_files if i.startswith(root_folder)
        ]
    else:
        python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)

    # Use multiprocessing to speed up the check and rename process
    with Pool(16) as p:
        name_pairs = p.map(check_and_rename, python_files)
    name_pairs = [pair for pair in name_pairs if pair[0] is not None]
    if not name_pairs:
        return
    with Pool(16) as p:
        p.starmap(os.rename, name_pairs)

    # root_folder = 'configs'
    # python_files = glob.glob(f'{root_folder}/**/*.py', recursive=True)
    # update_data = [(python_file, name_pairs) for python_file in python_files]
    # with Pool(16) as p:
    #     p.map(update_imports, update_data)


if __name__ == '__main__':
    main()