|
|
from importlib import import_module |
|
|
from omegaconf import OmegaConf |
|
|
import os |
|
|
from pathlib import Path |
|
|
import shutil |
|
|
from omegaconf import DictConfig |
|
|
from lightning.pytorch.utilities import rank_zero_info |
|
|
|
|
|
def instantiate(config: DictConfig, instantiate_module=True): |
|
|
"""Get arguments from config.""" |
|
|
module = import_module(config.module_name) |
|
|
class_ = getattr(module, config.class_name) |
|
|
if instantiate_module: |
|
|
init_args = {k: v for k, v in config.items() if k not in ["module_name", "class_name"]} |
|
|
return class_(**init_args) |
|
|
else: |
|
|
return class_ |
|
|
|
|
|
def instantiate_motion_gen(module_name, class_name, cfg, hfstyle=False, **init_args): |
|
|
module = import_module(module_name) |
|
|
class_ = getattr(module, class_name) |
|
|
if hfstyle: |
|
|
config_class = class_.config_class |
|
|
cfg = config_class(config_obj=cfg) |
|
|
return class_(cfg, **init_args) |
|
|
|
|
|
def save_config_and_codes(config, save_dir): |
|
|
os.makedirs(save_dir, exist_ok=True) |
|
|
sanity_check_dir = os.path.join(save_dir, 'sanity_check') |
|
|
os.makedirs(sanity_check_dir, exist_ok=True) |
|
|
with open(os.path.join(sanity_check_dir, f'{config.exp_name}.yaml'), 'w') as f: |
|
|
OmegaConf.save(config, f) |
|
|
current_dir = Path.cwd() |
|
|
for py_file in current_dir.rglob('*.py'): |
|
|
dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir) |
|
|
dest_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
shutil.copy(py_file, dest_path) |
|
|
|
|
|
def print_model_size(model): |
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
rank_zero_info(f"Total parameters: {total_params:,}") |
|
|
rank_zero_info(f"Trainable parameters: {trainable_params:,}") |
|
|
rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}") |
|
|
|
|
|
def load_metrics(file_path): |
|
|
metrics = {} |
|
|
with open(file_path, "r") as f: |
|
|
for line in f: |
|
|
key, value = line.strip().split(": ") |
|
|
try: |
|
|
metrics[key] = float(value) |
|
|
except ValueError: |
|
|
metrics[key] = value |
|
|
return metrics |