File size: 1,589 Bytes
ee3e701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from internlm.initialize.launch import get_config_value
from internlm.utils.logger import get_logger

# Module-level logger obtained from internlm's get_logger for this file
# (not referenced by the functions defined below).
logger = get_logger(__file__)


def auto_resume_sanity_check(ckpt_config):
    """Report whether auto-resume is enabled, derived from ``load_given_ckpt``.

    Args:
        ckpt_config: Checkpoint config (presumably the ``ckpt`` section of the
            training config), queried via ``get_config_value``.

    Returns:
        bool: ``True`` when ``load_given_ckpt`` is unset (``None``) or falsy,
        ``False`` when it is truthy.
    """
    # ``not None`` evaluates to True, so an absent flag yields the default of True.
    given_flag = get_config_value(ckpt_config, "load_given_ckpt", None)
    return not given_flag


def ckpt_info_sanity_check(ckpt_config):
    """Validate checkpoint-loading options and build a load descriptor.

    Reads ``load_ckpt_folder``, ``load_model_only_folder`` and ``load_optimizer``
    from ``ckpt_config``.

    Args:
        ckpt_config: Checkpoint config (presumably the ``ckpt`` section of the
            training config), queried via ``get_config_value``.

    Returns:
        dict | None: ``dict(path=..., content=(...), ckpt_type="internlm")``.
        ``content`` is ``("model",)`` when ``load_model_only_folder`` is set.
        When only ``load_ckpt_folder`` is set, it is
        ``("model", "sampler", "optimizer")`` if ``load_optimizer`` (default
        ``True``) is truthy, else ``("model", "sampler")``.
        Returns ``None`` when neither folder is configured.

    Raises:
        AssertionError: If ``load_model_only_folder`` and ``load_ckpt_folder``
            are both set.
        TypeError: If ``load_ckpt_folder`` is set but is not a ``str``.
    """
    load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
    load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)

    if load_model_only_folder is not None:
        # The two folder options are mutually exclusive.
        assert load_ckpt_folder is None, (
            "Detected 'load_ckpt_folder' and 'load_model_only_folder' set at the same time; "
            "only one of them may be specified."
        )
        return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")

    load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)

    if load_ckpt_folder is None:
        return None

    if not isinstance(load_ckpt_folder, str):
        # Raise explicitly: ``assert "<non-empty string>"`` is always truthy and
        # would never fire, silently returning None for bad config values.
        raise TypeError(
            f"Unsupported data type: '{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'"
        )

    content = ("model", "sampler", "optimizer") if load_optimizer else ("model", "sampler")
    return dict(path=load_ckpt_folder, content=content, ckpt_type="internlm")