GuideFlow3D / lib /util /common.py
suvadityamuk's picture
feat: add initial files for space
382733a
import json
import os
import os.path as osp
import pickle
import re
from pathlib import Path
from typing import Any
import torch
import random
import numpy as np
def make_dir(dir_path: str) -> None:
    """Create ``dir_path``, including missing parents, unless the path already exists."""
    target = Path(dir_path)
    if target.exists():
        # Whatever already lives at this path is left untouched.
        return
    target.mkdir(parents=True, exist_ok=True)
def ensure_dir(path: str) -> None:
    """
    Ensures that a directory exists; creates it if it does not.

    Args:
        path: Directory to create. Missing parent directories are created too.

    If anything already exists at ``path`` the function does nothing.
    """
    if not osp.exists(path):
        # exist_ok=True covers the window between the check above and the
        # creation: without it, a directory created in between by someone
        # else would make makedirs raise FileExistsError.
        os.makedirs(path, exist_ok=True)
def assert_dir(path: str) -> None:
    """Asserts that ``path`` exists.

    The check is ``os.path.exists``, so any existing filesystem entry passes,
    not only directories.

    Args:
        path: Filesystem path to check.

    Raises:
        AssertionError: If nothing exists at ``path``.
    """
    # Raised explicitly instead of using an `assert` statement so that the
    # check is still performed when Python runs with -O.
    if not osp.exists(path):
        raise AssertionError(f"Path does not exist: {path!r}")
def load_pkl_data(filename: str) -> Any:
    """Open ``filename`` in binary mode and return the object unpickled from it."""
    with open(filename, 'rb') as pkl_file:
        return pickle.load(pkl_file)
def write_pkl_data(data_dict: Any, filename: str) -> None:
    """Pickle ``data_dict`` into ``filename`` using the highest available protocol."""
    newest_protocol = pickle.HIGHEST_PROTOCOL
    with open(filename, 'wb') as pkl_file:
        pickle.dump(data_dict, pkl_file, protocol=newest_protocol)
def load_json(filename: str) -> Any:
    """Loads data from a JSON file.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The Python object decoded from the file's JSON content.
    """
    # The context manager closes the file even when json.load raises on
    # malformed content, so the handle is never leaked.
    with open(filename) as file:
        return json.load(file)
def write_json(data_dict: Any, filename: str) -> None:
    """Serialize ``data_dict`` as JSON indented by four spaces and write it to ``filename``."""
    # Serialize first so the target file is only opened once the data is
    # known to be JSON-encodable.
    serialized = json.dumps(data_dict, indent=4)
    with open(filename, "w") as json_file:
        json_file.write(serialized)
def get_print_format(value: Any) -> str:
    """Determines the appropriate format string for a given value.

    Args:
        value: The value to be formatted.

    Returns:
        A format spec: 'd' for ``int`` instances, 's' for ``str`` instances,
        '.3f' for a value equal to zero, '.3e' for a non-zero value whose
        magnitude is below 1e-6, and '.6f' for everything else.
    """
    if isinstance(value, int):
        return 'd'
    if isinstance(value, str):
        return 's'
    if value == 0:
        return '.3f'
    # Compare the magnitude rather than the signed value; a signed comparison
    # would send every negative number down the scientific-notation branch.
    if abs(value) < 1e-6:
        return '.3e'
    return '.6f'
def get_format_strings(kv_pairs: list) -> list:
    """Render every (key, value) pair as a 'key: value' string.

    The value part is formatted with the spec that ``get_print_format``
    returns for that value.
    """
    return [
        ('{}: {:' + get_print_format(value) + '}').format(key, value)
        for key, value in kv_pairs
    ]
def get_first_index_batch(x: Any) -> Any:
    """Strip the leading batch entry from ``x`` according to its type.

    A list yields its first element, a tensor is squeezed along dimension 0,
    a dict is processed value by value (recursively), and any other object is
    returned unchanged.
    """
    if isinstance(x, list):
        return x[0]
    if isinstance(x, torch.Tensor):
        return x.squeeze(0)
    if isinstance(x, dict):
        return {name: get_first_index_batch(item) for name, item in x.items()}
    return x
def split_sentence(sentence: str) -> list:
    """Break text at every period and return the trimmed, non-empty pieces in order."""
    trimmed = (piece.strip() for piece in re.split(r'[.]', sentence))
    return [piece for piece in trimmed if piece]
def set_random_seed(seed: int) -> None:
    """Apply ``seed`` to Python's ``random``, NumPy and PyTorch, and pin the cuDNN flags.

    After seeding, ``torch.backends.cudnn.deterministic`` is set to True and
    ``torch.backends.cudnn.benchmark`` to False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False