| import argparse | |
| import datetime | |
| import hashlib | |
| import importlib | |
| import json | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| import shrinker as shrinker_module | |
| import yaml | |
| from accelerate import Accelerator | |
| from accelerate.utils import InitProcessGroupKwargs | |
| from lmms_eval.utils import simple_parse_args_string | |
# Maps the CLI ``--shrinker`` value to the name of the class that is looked up
# on the ``shrinker`` module. (The constant's misspelled name is kept as-is
# because other code refers to it.)
AVAILABEL_SHRINKER = dict(embed="Embed_Shrinker")
def parse_arguments(argv=None):
    """Parse the command-line options of the dataset-shrinking script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the previous behaviour.

    Returns:
        An ``argparse.Namespace`` with the attributes ``shrinker``,
        ``num_items``, ``tasks``, ``push_to_hub`` and ``shrinker_kwargs``.
        ``num_items`` and ``tasks`` stay raw comma-separated strings; the
        caller is responsible for splitting them.

    Raises:
        SystemExit: (from argparse) when a required option is missing or the
            arguments are malformed.
    """
    parser = argparse.ArgumentParser()
    # These three options are mandatory: without them there is nothing to
    # shrink, so argparse should report a usage error up front instead of the
    # script crashing later on ``None.split(",")``.
    parser.add_argument("--shrinker", type=str, required=True, help="The type of shrinker you want to use")
    parser.add_argument("--num_items", type=str, required=True, help="The number of items you want in your shrinked dataset")
    parser.add_argument("--tasks", type=str, required=True, help="The task you want to shrink. Separate each task with comma, will be parsed in to list")
    parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push the shrinked dataset to hub")
    # Default to an empty string (not None) so the value is always a string
    # that can be handed straight to simple_parse_args_string.
    parser.add_argument("--shrinker_kwargs", type=str, default="", help="In args=xxx,args2=xxx format. Will be parsed into dict")
    return parser.parse_args(argv)
def _split_csv(value):
    """Split a comma-separated CLI value into stripped, non-empty items."""
    return [item.strip() for item in value.split(",") if item.strip()]


def _expand_num_items(raw_num_items, num_tasks):
    """Return one float item budget per task.

    ``raw_num_items`` is a comma-separated string holding either a single
    value, which is broadcast to every task, or exactly ``num_tasks`` values.

    Raises:
        ValueError: if the number of values is neither 1 nor ``num_tasks``,
            or a value is not parseable as a float.
    """
    values = _split_csv(raw_num_items)
    if len(values) == 1:
        return [float(values[0])] * num_tasks
    if len(values) != num_tasks:
        raise ValueError("Either provide one num items for all task or one num item for each task")
    return [float(v) for v in values]


if __name__ == "__main__":
    args = parse_arguments()
    shrinker_name = args.shrinker

    # Validate the shrinker choice before doing any other work. This uses an
    # explicit raise rather than assert so the check survives ``python -O``.
    if shrinker_name not in AVAILABEL_SHRINKER:
        raise ValueError(f"Unavailable shrinker {shrinker_name}. You can choose from {AVAILABEL_SHRINKER.keys()}")
    if not args.tasks or not args.num_items:
        raise ValueError("Both --tasks and --num_items must be provided")

    # --shrinker_kwargs is optional; only parse it when something was given,
    # so an omitted option yields no extra keyword arguments.
    shrinker_kwargs = simple_parse_args_string(args.shrinker_kwargs) if args.shrinker_kwargs else {}

    tasks = _split_csv(args.tasks)
    if not tasks:
        raise ValueError("--tasks must name at least one task")
    num_items = _expand_num_items(args.num_items, len(tasks))
    push_to_hub = args.push_to_hub

    # Very long process-group timeout (60000 s), presumably because shrinking
    # a task can keep other ranks waiting for a long time. TODO: confirm.
    kwargs_handler = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=60000))
    accelerator = Accelerator(kwargs_handlers=[kwargs_handler])

    # The shrinker class does not depend on the task, so resolve it once.
    shrinker_cls = getattr(shrinker_module, AVAILABEL_SHRINKER[shrinker_name])
    for task, task_num_items in zip(tasks, num_items):
        shrinker = shrinker_cls(task=task, num_items=task_num_items, push_to_hub=push_to_hub, name=shrinker_name, **shrinker_kwargs)
        shrinker.shrink()
        # Keep all processes in step before moving on to the next task.
        accelerator.wait_for_everyone()