import json
from copy import deepcopy
from pathlib import Path
from typing import Optional, Union

import torch

from .biosignals_coca_model import BiosignalsCoCa
from .model import get_cast_dtype, convert_weights_to_lp
from .tokenizer import SimpleTokenizer, DEFAULT_CONTEXT_LENGTH

# Model config JSON files are discovered next to this module.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}


def _rescan_model_configs():
    """Populate _MODEL_CONFIGS from every valid JSON file in _MODEL_CONFIG_PATHS."""
    global _MODEL_CONFIGS
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_dir():
            config_files.extend(config_path.glob("*.json"))
    for cf in config_files:
        with open(cf, "r") as f:
            model_cfg = json.load(f)
        # Only register configs that define all required top-level sections.
        if all(a in model_cfg for a in ("embed_dim", "biosignals_cfg", "text_cfg")):
            _MODEL_CONFIGS[cf.stem] = model_cfg


# Scan once at import time so configs are available immediately.
_rescan_model_configs()
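
# A config file is expected to carry at least the keys checked above. A hypothetical
# model_configs/biosignals_coca_base.json might look like this (sketch only; the file
# name and values are illustrative, not taken from this repo):
#
#   {
#       "embed_dim": 512,
#       "biosignals_cfg": { ... },
#       "text_cfg": { "context_length": 77, ... }
#   }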


def get_model_config(model_name: str):
    """Return a deep copy of the named config, or None if it is unknown."""
    return deepcopy(_MODEL_CONFIGS.get(model_name))


def create_model(
    model_name: str,
    precision: str = "fp32",
    device: Union[str, torch.device] = "cpu",
    **model_kwargs,
) -> BiosignalsCoCa:
    """Instantiate a BiosignalsCoCa model from a named config at the given precision."""
    if isinstance(device, str):
        device = torch.device(device)

    model_cfg = get_model_config(model_name)
    if model_cfg is None:
        raise RuntimeError(
            f"Model config for '{model_name}' not found. "
            f"Available: {list(_MODEL_CONFIGS.keys())}"
        )

    model_cfg.pop("custom_text", None)  # config-only key, not a constructor argument
    model_cfg.update(model_kwargs)

    cast_dtype = get_cast_dtype(precision)
    model = BiosignalsCoCa(**model_cfg, cast_dtype=cast_dtype)

    if precision in ("fp16", "bf16"):
        # Mixed precision: move to device first, then convert eligible weights in place.
        dtype = torch.float16 if "fp16" in precision else torch.bfloat16
        model.to(device=device)
        convert_weights_to_lp(model, dtype=dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        # Pure low precision: cast the entire module to the target dtype.
        dtype = torch.float16 if "fp16" in precision else torch.bfloat16
        model.to(device=device, dtype=dtype)
    else:
        model.to(device=device)

    model.output_dict = True  # have forward() return a dict of named outputs
    return model
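
# Illustrative call (sketch): the model name must match a JSON stem under
# model_configs/; "biosignals_coca_base" is a hypothetical example.
#
#   model = create_model("biosignals_coca_base", precision="bf16", device="cuda")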


def load_checkpoint(model, checkpoint_path: str, device="cpu"):
    """Load weights into ``model``, tolerating missing or unexpected keys."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Checkpoints may be a bare state dict or wrap one under a "state_dict" key.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip the "module." prefix left behind by DataParallel / DDP wrappers.
    if next(iter(state_dict), "").startswith("module."):
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    incompatible = model.load_state_dict(state_dict, strict=False)
    return incompatible
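
# Typical flow (sketch; the checkpoint path is hypothetical). With strict=False,
# load_state_dict returns an object exposing missing_keys and unexpected_keys:
#
#   incompatible = load_checkpoint(model, "checkpoints/epoch_latest.pt")
#   if incompatible.missing_keys:
#       print("missing keys:", incompatible.missing_keys)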


def get_tokenizer(model_name: str = "", context_length: Optional[int] = None, **kwargs):
    """Build a SimpleTokenizer, taking context_length from the model config if unset."""
    config = get_model_config(model_name) or {}
    text_cfg = config.get("text_cfg", {})
    if context_length is None:
        context_length = text_cfg.get("context_length", DEFAULT_CONTEXT_LENGTH)
    return SimpleTokenizer(context_length=context_length, **kwargs)
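
# Example (sketch): tokenizing captions for the text tower. This assumes
# SimpleTokenizer instances are callable on a list of strings and return token
# tensors of shape [batch, context_length]; the captions are made up:
#
#   tokenizer = get_tokenizer("biosignals_coca_base")
#   tokens = tokenizer(["resting-state recording", "eyes-closed baseline"])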


def get_input_dtype(precision: str):
    """Return the dtype inputs should be cast to, or None for full precision."""
    input_dtype = None
    if precision in ("bf16", "pure_bf16"):
        input_dtype = torch.bfloat16
    elif precision in ("fp16", "pure_fp16"):
        input_dtype = torch.float16
    return input_dtype
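
# Example (sketch): casting an input batch to match the model's precision before
# the forward pass; "batch" is a hypothetical input tensor:
#
#   input_dtype = get_input_dtype("pure_bf16")
#   if input_dtype is not None:
#       batch = batch.to(dtype=input_dtype)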