Spaces:
Running
Running
File size: 3,027 Bytes
382733a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import json
import os
import os.path as osp
import pickle
import re
from pathlib import Path
from typing import Any
import torch
import random
import numpy as np
def make_dir(dir_path: str) -> None:
    """Create the directory ``dir_path`` (with any missing parents).

    If anything (directory or regular file) already exists at ``dir_path``,
    the function returns without touching it.
    """
    target = Path(dir_path)
    if target.exists():
        return
    target.mkdir(parents=True, exist_ok=True)
def ensure_dir(path: str) -> None:
    """
    Ensure that the directory ``path`` exists; create it (and any missing
    parents) if it does not.

    An empty ``path`` denotes the current working directory, so the call is a
    no-op. This makes ``ensure_dir(os.path.dirname(filename))`` safe for a bare
    filename. If any entry (directory or file) already exists at ``path``,
    nothing is done.
    """
    if not path:
        # os.makedirs('') raises FileNotFoundError; '' means "current directory",
        # which always exists.
        return
    if not osp.exists(path):
        # exist_ok=True tolerates another process/thread creating the directory
        # between the exists() check above and this call (check-then-act race).
        os.makedirs(path, exist_ok=True)
def assert_dir(path: str) -> None:
    """Raise ``AssertionError`` unless ``path`` exists.

    Unlike a bare ``assert`` statement, this check still runs when Python is
    started with ``-O``. Only existence is checked, so a regular file at
    ``path`` also passes.

    Raises:
        AssertionError: if ``path`` does not exist.
    """
    if not osp.exists(path):
        raise AssertionError(f"Path does not exist: {path}")
def load_pkl_data(filename: str) -> Any:
    """Read the pickle file ``filename`` and return the object it contains.

    Warning: unpickling can execute arbitrary code, so only call this on
    files from a trusted source.
    """
    with open(filename, mode='rb') as pkl_file:
        return pickle.load(pkl_file)
def write_pkl_data(data_dict: Any, filename: str) -> None:
    """Pickle ``data_dict`` into ``filename`` using the newest pickle protocol.

    An existing ``filename`` is overwritten.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(filename, mode='wb') as out_file:
        pickle.dump(data_dict, out_file, protocol=protocol)
def load_json(filename: str) -> Any:
    """Load and return the parsed contents of the JSON file ``filename``.

    The file is opened in a ``with`` block, so it is closed even when parsing
    fails. The original open/close pair leaked the handle on error. The file
    is decoded as UTF-8, the encoding RFC 8259 mandates for JSON, instead of
    the platform-dependent locale default.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    with open(filename, encoding='utf-8') as file:
        return json.load(file)
def write_json(data_dict: Any, filename: str) -> None:
    """Write ``data_dict`` to ``filename`` as JSON indented by 4 spaces.

    The object is serialized before the file is opened, so an unserializable
    object raises without truncating an existing ``filename``.
    """
    text = json.dumps(data_dict, indent=4)
    with open(filename, mode="w") as out:
        out.write(text)
def get_print_format(value: Any) -> str:
    """Return a ``str.format`` spec suited to displaying ``value``.

    - ``int`` (including ``bool``) -> ``'d'``
    - ``str`` -> ``'s'``
    - exactly zero -> ``'.3f'``
    - non-zero numbers with magnitude below 1e-6 -> ``'.3e'`` (scientific)
    - any other number -> ``'.6f'``

    The magnitude test uses ``abs(value)``. Without it, every negative number
    (e.g. -5.0) would be sent to scientific notation instead of only the
    tiny ones.
    """
    if isinstance(value, int):
        return 'd'
    if isinstance(value, str):
        return 's'
    if value == 0:
        return '.3f'
    if abs(value) < 1e-6:
        return '.3e'
    # The former `< 1e-3` branch also returned '.6f'; it was dead and is folded here.
    return '.6f'
def get_format_strings(kv_pairs: list) -> list:
    """Render each ``(key, value)`` pair as the string ``'key: value'``.

    The value's format spec comes from ``get_print_format``. The result is a
    list with one string per pair, in input order.
    """
    return [f'{key}: {value:{get_print_format(value)}}' for key, value in kv_pairs]
def get_first_index_batch(x: Any) -> Any:
    """Strip the batch dimension from ``x``.

    - ``list``: returns its first element (no recursion into it).
    - ``torch.Tensor``: returns ``x.squeeze(0)``. This removes dim 0 only when
      it has size 1; otherwise the tensor comes back with its shape unchanged.
      It therefore acts as "first index" only for a batch size of 1.
    - ``dict``: returns a new dict with this function applied to every value.
    - anything else (tuples, scalars, ...): returned unchanged.
    """
    if isinstance(x, list):
        return x[0]
    if isinstance(x, torch.Tensor):
        return x.squeeze(0)
    if isinstance(x, dict):
        return {k: get_first_index_batch(v) for k, v in x.items()}
    return x
def split_sentence(sentence: str) -> list:
"""Splits a sentence into individual sentences based on periods."""
sentence = re.split(r'[.]', sentence)
sentence = [s.strip() for s in sentence]
sentence = [s for s in sentence if len(s) > 0]
return sentence
def set_random_seed(seed: int) -> None:
    """Seed all random number generators used here and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's default
    generator, and every CUDA device. It then tells cuDNN to use deterministic
    algorithms and disables its benchmark autotuner, for run-to-run
    reproducibility (possibly at some speed cost).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False