File size: 3,048 Bytes
06acd95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
import logging
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union

import torch

from .biosignals_coca_model import BiosignalsCoCa
from .model import get_cast_dtype, convert_weights_to_lp
from .tokenizer import SimpleTokenizer, DEFAULT_CONTEXT_LENGTH

_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}


def _rescan_model_configs():
    """Scan every directory in ``_MODEL_CONFIG_PATHS`` and (re)populate ``_MODEL_CONFIGS``.

    Each ``*.json`` file is registered under its filename stem, but only when it
    carries all of the required top-level keys: ``embed_dim``, ``biosignals_cfg``
    and ``text_cfg``. Files missing any of them are silently skipped.
    """
    global _MODEL_CONFIGS
    required_keys = ("embed_dim", "biosignals_cfg", "text_cfg")
    for config_dir in _MODEL_CONFIG_PATHS:
        if not config_dir.is_dir():
            continue
        for config_file in config_dir.glob("*.json"):
            with config_file.open("r") as f:
                cfg = json.load(f)
            if all(key in cfg for key in required_keys):
                _MODEL_CONFIGS[config_file.stem] = cfg


# Populate the registry once at import time.
_rescan_model_configs()


def get_model_config(model_name: str):
    """Return a deep copy of the registered config for *model_name*, or ``None``.

    The copy keeps callers from mutating the shared ``_MODEL_CONFIGS`` registry.
    """
    cfg = _MODEL_CONFIGS.get(model_name)
    return None if cfg is None else deepcopy(cfg)


def create_model(
    model_name: str,
    precision: str = "fp32",
    device: Union[str, torch.device] = "cpu",
    **model_kwargs,
) -> BiosignalsCoCa:
    """Build a ``BiosignalsCoCa`` model from a named JSON config.

    Args:
        model_name: Key into the scanned config registry.
        precision: One of ``fp32`` (default), ``fp16``/``bf16`` (mixed: weights
            selectively converted after moving to *device*) or
            ``pure_fp16``/``pure_bf16`` (entire model cast to the low dtype).
        device: Target device, as a string or ``torch.device``.
        **model_kwargs: Overrides merged into the config before construction.

    Returns:
        The constructed model, moved to *device*, with ``output_dict`` enabled.

    Raises:
        RuntimeError: If no config is registered under *model_name*.
    """
    if isinstance(device, str):
        device = torch.device(device)

    cfg = get_model_config(model_name)
    if cfg is None:
        raise RuntimeError(f"Model config for '{model_name}' not found. Available: {list(_MODEL_CONFIGS.keys())}")

    # "custom_text" may appear in configs but is not a constructor argument.
    cfg.pop("custom_text", None)
    cfg.update(model_kwargs)

    model = BiosignalsCoCa(**cfg, cast_dtype=get_cast_dtype(precision))

    mixed = {"fp16": torch.float16, "bf16": torch.bfloat16}
    pure = {"pure_fp16": torch.float16, "pure_bf16": torch.bfloat16}
    if precision in mixed:
        # Mixed precision: move first, then convert only the eligible weights.
        model.to(device=device)
        convert_weights_to_lp(model, dtype=mixed[precision])
    elif precision in pure:
        model.to(device=device, dtype=pure[precision])
    else:
        model.to(device=device)

    model.output_dict = True
    return model


def load_checkpoint(model, checkpoint_path: str, device="cpu"):
    """Load a checkpoint's weights into *model* (non-strict).

    Accepts either a raw state dict or a dict wrapping one under ``"state_dict"``.
    A DistributedDataParallel ``"module."`` prefix, if present, is stripped.

    Args:
        model: Module whose parameters will be loaded.
        checkpoint_path: Path to the ``torch.save``-ed checkpoint.
        device: ``map_location`` used when deserializing tensors.

    Returns:
        The ``IncompatibleKeys`` result of ``load_state_dict(strict=False)``
        (missing and unexpected key lists).
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Guard against an empty state dict: next(iter(...)) would raise StopIteration.
    if state_dict and next(iter(state_dict)).startswith("module."):
        # Strip the DDP wrapper prefix so keys match a bare module.
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    incompatible = model.load_state_dict(state_dict, strict=False)
    return incompatible


def get_tokenizer(model_name: str = "", context_length: Optional[int] = None, **kwargs):
    """Construct a ``SimpleTokenizer`` for *model_name*.

    When *context_length* is not given, it is taken from the model config's
    ``text_cfg`` (falling back to ``DEFAULT_CONTEXT_LENGTH``).
    """
    if context_length is None:
        text_cfg = (get_model_config(model_name) or {}).get("text_cfg", {})
        context_length = text_cfg.get("context_length", DEFAULT_CONTEXT_LENGTH)
    return SimpleTokenizer(context_length=context_length, **kwargs)


def get_input_dtype(precision: str):
    """Map a precision string to the dtype model inputs should be cast to.

    Returns ``torch.bfloat16`` for ``bf16``/``pure_bf16``, ``torch.float16``
    for ``fp16``/``pure_fp16``, and ``None`` otherwise (inputs left as-is).
    """
    dtype_by_precision = {
        "bf16": torch.bfloat16,
        "pure_bf16": torch.bfloat16,
        "fp16": torch.float16,
        "pure_fp16": torch.float16,
    }
    return dtype_by_precision.get(precision)