Spaces:
Running
on
Zero
Running
on
Zero
| from typing import * | |
| import time | |
| from pathlib import Path | |
| from numbers import Number | |
def catch_exception(fn):
    """
    Decorator that reports exceptions raised by ``fn`` instead of propagating them.

    The wrapper returns whatever ``fn`` returns. If ``fn`` raises an ``Exception``,
    the failing call (function name and arguments) and the traceback are printed,
    the wrapper pauses for 0.1 seconds, and ``None`` is returned.

    Args:
        fn: The callable to guard.

    Returns:
        A wrapper with the same call signature and metadata (name, docstring) as ``fn``.
    """
    import functools
    import traceback

    @functools.wraps(fn)  # keep fn's __name__/__doc__ visible on the wrapper
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Build one argument list so no dangling ", " appears when either
            # the positional or the keyword arguments are empty.
            call_args = [repr(arg) for arg in args]
            call_args.extend(f'{k}={v!r}' for k, v in kwargs.items())
            fn_name = getattr(fn, '__name__', repr(fn))
            print(f"Exception in {fn_name}({', '.join(call_args)})")
            traceback.print_exc(chain=False)
            time.sleep(0.1)
            return None
    return wrapper
class CallbackOnException:
    """
    Context manager that swallows one kind of exception and runs a callback instead.

    If the body raises an instance of ``exception``, ``callback()`` is invoked with
    no arguments and the exception is suppressed. Any other exception (or no
    exception at all) passes through without calling the callback.
    """

    def __init__(self, callback: Callable, exception: type):
        self.callback = callback
        self.exception = exception

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Guard clause: anything that is not the watched exception propagates.
        if not isinstance(exc_val, self.exception):
            return False
        self.callback()
        return True
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
    """
    Yield the key path (as a tuple) of every non-dict value in a nested dictionary.

    Values that are ``dict`` instances are descended into; an empty nested dict
    therefore contributes no paths.
    """
    for key, value in d.items():
        if not isinstance(value, dict):
            yield (key, )
            continue
        # Prefix every path found below with the current key.
        yield from ((key, ) + tail for tail in traverse_nested_dict_keys(value))
def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
    """
    Look up a value in a nested dictionary by following a path of keys.

    Args:
        d: The (possibly nested) dictionary to read from.
        keys: The path to follow, outermost key first. An empty path returns ``d``.
        default: Returned when the path cannot be fully resolved, i.e. a key is
            missing or a non-mapping value is reached before the path ends.

    Returns:
        The value stored at the path (which may itself be ``None``), or ``default``.
    """
    from collections.abc import Mapping

    # A private sentinel distinguishes "key missing" from a stored value that
    # happens to equal ``default``.
    missing = object()
    node = d
    for key in keys:
        if not isinstance(node, Mapping):
            # The path continues but there is nothing left to descend into.
            return default
        node = node.get(key, missing)
        if node is missing:
            return default
    return node
def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
    """
    Store ``value`` in a nested dictionary at the given key path, in place.

    Intermediate dictionaries are created on demand for all keys but the last;
    the last key receives ``value``.
    """
    parents, leaf = keys[:-1], keys[-1]
    node = d
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value
def key_average(list_of_dicts: list) -> Dict[str, Any]:
    """
    Average the values found at each key path across a list of nested dictionaries.

    The union of all key paths over the inputs is collected; for each path the
    non-``None`` values are averaged. A path with no usable values maps to NaN.
    The result is a nested dictionary with one entry per path.
    """
    key_paths = set()
    for entry in list_of_dicts:
        for path in traverse_nested_dict_keys(entry):
            key_paths.add(path)

    averaged = {}
    for path in sorted(key_paths):
        present = []
        for entry in list_of_dicts:
            found = get_nested_dict(entry, path)
            if found is not None:
                present.append(found)
        if present:
            mean = sum(present) / len(present)
        else:
            mean = float('nan')
        set_nested_dict(averaged, path, mean)
    return averaged
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """
    Collapse a nested dictionary into one level whose keys are key-path tuples.

    ``parent_key`` is the path prefix prepended to every produced key; it defaults
    to the empty tuple. Values that are ``MutableMapping`` instances are descended
    into, everything else becomes a leaf entry.
    """
    prefix = () if parent_key is None else parent_key
    flat = {}
    for name, value in d.items():
        path = prefix + (name, )
        if isinstance(value, MutableMapping):
            flat.update(flatten_nested_dict(value, path))
        else:
            flat[path] = value
    return flat
def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Rebuild a nested dictionary from a single-level one keyed by key-path tuples.

    Each key of ``d`` is read as a path: all elements but the last name nested
    dictionaries (created when absent) and the last element receives the value.
    """
    nested = {}
    for path, value in d.items():
        cursor = nested
        for part in path[:-1]:
            cursor = cursor.setdefault(part, {})
        cursor[path[-1]] = value
    return nested
def read_jsonl(file):
    """
    Read a JSON Lines file into a list of decoded objects.

    Args:
        file: Path of the file to read; each non-blank line must hold one JSON value.

    Returns:
        A list with one decoded value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    import json
    # JSON text is UTF-8 by specification, so do not depend on the locale default.
    with open(file, 'r', encoding='utf-8') as f:
        # Iterate the file lazily and skip blank/whitespace-only lines, which
        # would otherwise make json.loads fail.
        return [json.loads(line) for line in f if line.strip()]
def write_jsonl(data: List[dict], file):
    """
    Write each item of ``data`` to ``file`` as one JSON document per line.

    The file is opened in write mode, so existing content is replaced.
    """
    import json
    with open(file, 'w') as out:
        out.writelines(json.dumps(record) + '\n' for record in data)
def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]):
    """
    Write ``all_metrics`` to ``save_path`` as JSON indented by 4 spaces.

    Args:
        save_path: Destination file; existing content is replaced.
        all_metrics: JSON-serializable metrics to store.
    """
    # Only json is needed here; the previously imported pandas was never used
    # and made this function fail in environments without pandas.
    import json
    with open(save_path, 'w') as f:
        json.dump(all_metrics, f, indent=4)
def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """
    Build a pandas DataFrame with hierarchical (MultiIndex) columns from nested dicts.

    Each input dictionary is flattened with ``flatten_nested_dict`` and becomes one
    row; columns are sorted and then converted to a ``MultiIndex`` built from the
    key-path tuples.
    """
    import pandas as pd
    flat_rows = [flatten_nested_dict(row) for row in data]
    frame = pd.DataFrame(flat_rows).sort_index(axis=1)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    return frame
def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """
    Apply substring replacements to every string inside a nested list/dict structure.

    For a string, each ``old -> new`` pair of ``mapping`` is applied in the
    mapping's iteration order and the new string is returned. Lists and dicts are
    updated in place (element by element / value by value) and returned. Any other
    object is returned unchanged.
    """
    if isinstance(d, str):
        replaced = d
        for old, new in mapping.items():
            replaced = replaced.replace(old, new)
        return replaced
    if isinstance(d, list):
        for index, element in enumerate(d):
            d[index] = recursive_replace(element, mapping)
    elif isinstance(d, dict):
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key in d:
            d[key] = recursive_replace(d[key], mapping)
    return d
class timeit:
    """
    Wall-clock timer usable as a context manager or as a function decorator.

    As a context manager it records ``time.time()`` on entry and exit. As a
    decorator it wraps sync or async functions so each call is timed.

    Args:
        name: Label used in printed messages and as the history key. In decorator
            use, the function's ``__qualname__`` is used when no name is given.
        verbose: When true, print the elapsed time on exit.
        multiple: When true, every finished timer is appended to a class-level
            history under ``name`` and the printed figure is the average over that
            history instead of the single measurement.
    """

    # Finished timers created with ``multiple=True``, grouped by name. This is
    # shared by all instances and is never cleared by this class.
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.multiple = multiple
        if multiple and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        """Return ``func`` wrapped so that every call runs inside a fresh timer."""
        import functools
        import inspect

        def make_timer() -> 'timeit':
            # Forward verbose/multiple so the decorator honours the same options
            # as the context-manager form.
            return timeit(self.name or func.__qualname__, verbose=self.verbose, multiple=self.multiple)

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                with make_timer():
                    return await func(*args, **kwargs)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with make_timer():
                    return func(*args, **kwargs)
        return wrapper

    def __enter__(self):
        self.start = time.time()
        # Returning self makes ``with timeit(...) as t`` bind the timer.
        return self

    def time(self) -> float:
        """Return the elapsed seconds between entering and exiting the timer."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    def history(self) -> List['timeit']:
        """Return the timers recorded under this timer's name (empty if none)."""
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        if self.multiple:
            timeit._history[self.name].append(self)
        if self.verbose:
            if self.multiple:
                runs = timeit._history[self.name]
                # ``time`` is a method: call it, otherwise sum() receives bound methods.
                avg = sum(t.time() for t in runs) / len(runs)
                print(f"{self.name or 'It'} took {avg} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time()} seconds.")
def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """
    Remove the longest prefix and suffix shared by all strings.

    The common prefix is removed first; the common suffix is then computed on what
    remains, so the two can never overlap.

    Args:
        strings: The strings to trim. May be empty or contain strings of
            different lengths.

    Returns:
        A new list with the trimmed strings in the same order. A single-element
        list yields ``['']`` because the whole string is its own common prefix.
    """
    # commonprefix compares character by character, despite living in os.path.
    from os.path import commonprefix

    if not strings:
        return []
    prefix_len = len(commonprefix(strings))
    remainders = [s[prefix_len:] for s in strings]
    # The common suffix is the common prefix of the reversed remainders.
    suffix_len = len(commonprefix([r[::-1] for r in remainders]))
    return [r[:len(r) - suffix_len] for r in remainders]
def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    """
    Build a decorator that immediately runs the decorated function over ``inputs``
    on a thread pool, advancing a progress bar after each item.

    Args:
        inputs: Items passed one at a time to the decorated function.
        num_workers: Maximum number of worker threads.
        pbar: Optional progress bar. It must be usable as a context manager and
            provide ``refresh()`` and ``update()``; its ``total`` is set to
            ``len(inputs)`` (or ``None`` when ``inputs`` has no length). When
            omitted, a ``tqdm`` bar is created.

    Returns:
        A decorator. Applying it runs all work and returns ``None``. Return values
        of the decorated function are discarded, and because the iterator returned
        by ``executor.map`` is never consumed, exceptions raised in workers are not
        re-raised here.
    """
    from concurrent.futures import ThreadPoolExecutor

    total = len(inputs) if hasattr(inputs, '__len__') else None
    if pbar is not None:
        pbar.total = total
    else:
        # Imported lazily so a caller-supplied bar works without tqdm installed.
        from tqdm import tqdm
        pbar = tqdm(total=total)

    def decorator(fn: Callable):
        with ThreadPoolExecutor(max_workers=num_workers) as executor, pbar:
            pbar.refresh()

            def _fn(item):
                ret = fn(item)
                pbar.update()
                return ret

            executor.map(_fn, inputs)
            # Wait here, while the bar is still open: on leaving the ``with``,
            # ``pbar`` is exited before the executor is.
            executor.shutdown(wait=True)
    return decorator