Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import os.path as osp | |
| import pickle | |
| import re | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| import random | |
| import numpy as np | |
def make_dir(dir_path: str) -> None:
    """Create ``dir_path``, including missing parents, unless the path already exists.

    Args:
        dir_path: Path of the directory to create.
    """
    target = Path(dir_path)
    if target.exists():
        # Whatever already lives at this path (directory or not) is left alone.
        return
    target.mkdir(parents=True, exist_ok=True)
def ensure_dir(path: str) -> None:
    """Ensure that ``path`` exists, creating it (and missing parents) if needed.

    If anything already exists at ``path`` the function does nothing.

    Args:
        path: Path of the directory to create.
    """
    if not osp.exists(path):
        # exist_ok=True closes the gap between the check above and the
        # creation below: if another process or thread creates the directory
        # in between, makedirs no longer raises FileExistsError.
        os.makedirs(path, exist_ok=True)
def assert_dir(path: str) -> None:
    """Check that ``path`` exists on disk.

    The check is an explicit ``raise`` rather than an ``assert`` statement so
    that it still runs when Python is started with ``-O``, which strips
    ``assert`` statements.

    Args:
        path: Path expected to exist.

    Raises:
        AssertionError: If nothing exists at ``path``.
    """
    if not osp.exists(path):
        raise AssertionError(f"Path does not exist: {path}")
def load_pkl_data(filename: str) -> Any:
    """Open ``filename`` in binary mode and return the object unpickled from it.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The object stored in the file.
    """
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(filename, 'rb') as pkl_file:
        return pickle.load(pkl_file)
def write_pkl_data(data_dict: Any, filename: str) -> None:
    """Pickle ``data_dict`` into ``filename`` using the highest pickle protocol.

    The file is opened in binary write mode, so existing content is replaced.

    Args:
        data_dict: Object to serialize.
        filename: Path of the pickle file to write.
    """
    with open(filename, 'wb') as pkl_file:
        pickler = pickle.Pickler(pkl_file, protocol=pickle.HIGHEST_PROTOCOL)
        pickler.dump(data_dict)
def load_json(filename: str) -> Any:
    """Parse the JSON document stored in ``filename`` and return it.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The decoded JSON value.
    """
    # The context manager closes the file even when json.load raises,
    # e.g. on malformed JSON.
    with open(filename) as file:
        return json.load(file)
def write_json(data_dict: Any, filename: str) -> None:
    """Serialize ``data_dict`` as JSON indented by four spaces and write it to ``filename``.

    Args:
        data_dict: JSON-serializable object to write.
        filename: Path of the output file; existing content is replaced.
    """
    # Serialization happens before the file is opened, so an object that
    # cannot be encoded fails without touching the target file.
    serialized = json.dumps(data_dict, indent=4)
    with open(filename, "w") as json_file:
        json_file.write(serialized)
def get_print_format(value: Any) -> str:
    """Pick a ``str.format`` format spec suited to ``value``.

    Args:
        value: The value to be printed. Expected to be an ``int``, a ``str``
            or a number that supports ``abs`` and comparison with floats.

    Returns:
        ``'d'`` for ``int`` instances (this includes ``bool``), ``'s'`` for
        strings, ``'.3f'`` for a non-int zero, ``'.3e'`` for non-zero values
        whose magnitude is below ``1e-6``, and ``'.6f'`` otherwise.
    """
    if isinstance(value, int):
        return 'd'
    if isinstance(value, str):
        return 's'
    if value == 0:
        return '.3f'
    # Compare the magnitude, not the signed value: otherwise every negative
    # number (e.g. -0.5) would be classified as "tiny" and printed in
    # scientific notation.
    if abs(value) < 1e-6:
        return '.3e'
    return '.6f'
def get_format_strings(kv_pairs: list) -> list:
    """Render each ``(key, value)`` pair as a ``"key: value"`` string.

    The value's format spec is chosen by ``get_print_format``.

    Args:
        kv_pairs: Iterable of two-item ``(key, value)`` entries.

    Returns:
        A list with one formatted string per pair, in input order.
    """
    def render(key, value):
        # Build the template for this value's spec, then fill it in.
        spec = get_print_format(value)
        template = '{}: {:' + spec + '}'
        return template.format(key, value)

    return [render(key, value) for key, value in kv_pairs]
def get_first_index_batch(x: Any) -> Any:
    """Strip the leading batch level from ``x``.

    Args:
        x: A list, a ``torch.Tensor``, a dict of such values, or anything else.

    Returns:
        * list: its first element, returned as-is (no recursion into it);
        * tensor: ``x.squeeze(0)``, which removes dimension 0 only when that
          dimension has size 1;
        * dict: a new dict with the same keys and each value processed
          recursively;
        * any other type: ``x`` unchanged.
    """
    if isinstance(x, list):
        return x[0]
    if isinstance(x, torch.Tensor):
        return x.squeeze(0)
    if isinstance(x, dict):
        return {name: get_first_index_batch(item) for name, item in x.items()}
    return x
def split_sentence(sentence: str) -> list:
    """Split text on periods into trimmed, non-empty pieces.

    Args:
        sentence: Text to split.

    Returns:
        The period-separated pieces with surrounding whitespace stripped;
        pieces that are empty after stripping are dropped.
    """
    trimmed = (piece.strip() for piece in re.split(r'[.]', sentence))
    return [piece for piece in trimmed if piece]
def set_random_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Also sets ``torch.backends.cudnn.deterministic`` to ``True`` and
    ``torch.backends.cudnn.benchmark`` to ``False``.

    Args:
        seed: Seed value passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False