File size: 2,238 Bytes
9ad5b1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from importlib import import_module
from omegaconf import OmegaConf
import os
from pathlib import Path
import shutil
from omegaconf import DictConfig
from lightning.pytorch.utilities import rank_zero_info

def instantiate(config: DictConfig, instantiate_module=True):
    """Get arguments from config."""
    module = import_module(config.module_name)
    class_ = getattr(module, config.class_name)
    if instantiate_module:
        init_args = {k: v for k, v in config.items() if k not in ["module_name", "class_name"]}
        return class_(**init_args)
    else:
        return class_

def instantiate_motion_gen(module_name, class_name, cfg, hfstyle=False, **init_args):
    module = import_module(module_name)
    class_ = getattr(module, class_name)
    if hfstyle:
        config_class = class_.config_class
        cfg = config_class(config_obj=cfg)
    return class_(cfg, **init_args)
    
def save_config_and_codes(config, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    sanity_check_dir = os.path.join(save_dir, 'sanity_check')
    os.makedirs(sanity_check_dir, exist_ok=True)
    with open(os.path.join(sanity_check_dir, f'{config.exp_name}.yaml'), 'w') as f:
        OmegaConf.save(config, f)
    current_dir = Path.cwd()
    for py_file in current_dir.rglob('*.py'):
        dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(py_file, dest_path)
    
def print_model_size(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    rank_zero_info(f"Total parameters: {total_params:,}")
    rank_zero_info(f"Trainable parameters: {trainable_params:,}")
    rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")
    
def load_metrics(file_path):
    metrics = {}
    with open(file_path, "r") as f:
        for line in f:
            key, value = line.strip().split(": ")
            try:
                metrics[key] = float(value)  # Convert to float if possible
            except ValueError:
                metrics[key] = value  # Keep as string if conversion fails
    return metrics