import json
from copy import deepcopy
from pathlib import Path
from typing import Optional, Union

import torch

from .biosignals_coca_model import BiosignalsCoCa
from .model import get_cast_dtype, convert_weights_to_lp
from .tokenizer import SimpleTokenizer, DEFAULT_CONTEXT_LENGTH

_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}


def _rescan_model_configs():
    """Populate _MODEL_CONFIGS from the JSON files under model_configs/."""
    global _MODEL_CONFIGS
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_dir():
            config_files.extend(config_path.glob("*.json"))
    for cf in config_files:
        with open(cf, "r") as f:
            model_cfg = json.load(f)
        # Only register configs that define all required top-level keys.
        if all(a in model_cfg for a in ("embed_dim", "biosignals_cfg", "text_cfg")):
            _MODEL_CONFIGS[cf.stem] = model_cfg


_rescan_model_configs()


def get_model_config(model_name: str):
    """Return a deep copy of the named model config, or None if unknown."""
    return deepcopy(_MODEL_CONFIGS.get(model_name))


def create_model(
    model_name: str,
    precision: str = "fp32",
    device: Union[str, torch.device] = "cpu",
    **model_kwargs,
) -> BiosignalsCoCa:
    """Build a BiosignalsCoCa model from a named config at the requested precision."""
    if isinstance(device, str):
        device = torch.device(device)

    model_cfg = get_model_config(model_name)
    if model_cfg is None:
        raise RuntimeError(
            f"Model config for '{model_name}' not found. "
            f"Available: {list(_MODEL_CONFIGS.keys())}"
        )

    model_cfg.pop("custom_text", None)
    model_cfg.update(model_kwargs)

    cast_dtype = get_cast_dtype(precision)
    model = BiosignalsCoCa(**model_cfg, cast_dtype=cast_dtype)

    if precision in ("fp16", "bf16"):
        # Mixed precision: move to device first, then convert weights in place.
        dtype = torch.float16 if "fp16" in precision else torch.bfloat16
        model.to(device=device)
        convert_weights_to_lp(model, dtype=dtype)
    elif precision in ("pure_fp16", "pure_bf16"):
        # Pure low precision: cast the entire model, parameters and buffers alike.
        dtype = torch.float16 if "fp16" in precision else torch.bfloat16
        model.to(device=device, dtype=dtype)
    else:
        model.to(device=device)

    model.output_dict = True
    return model


def load_checkpoint(model, checkpoint_path: str, device="cpu"):
    """Load a checkpoint into the model and return the incompatible keys
    reported by load_state_dict (missing_keys / unexpected_keys)."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip the "module." prefix left by DistributedDataParallel, if present.
    if next(iter(state_dict)).startswith("module."):
        state_dict = {k[len("module."):]: v for k, v in state_dict.items()}
    incompatible = model.load_state_dict(state_dict, strict=False)
    return incompatible


def get_tokenizer(model_name: str = "", context_length: Optional[int] = None, **kwargs):
    """Create a SimpleTokenizer whose context length matches the model's text config."""
    config = get_model_config(model_name) or {}
    text_cfg = config.get("text_cfg", {})
    if context_length is None:
        context_length = text_cfg.get("context_length", DEFAULT_CONTEXT_LENGTH)
    return SimpleTokenizer(context_length=context_length, **kwargs)


def get_input_dtype(precision: str):
    """Map a precision string to the dtype inputs should be cast to (None for fp32)."""
    input_dtype = None
    if precision in ("bf16", "pure_bf16"):
        input_dtype = torch.bfloat16
    elif precision in ("fp16", "pure_fp16"):
        input_dtype = torch.float16
    return input_dtype
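

# --- Usage sketch ------------------------------------------------------------
# A minimal smoke test of the factory functions above. Assumptions: a config
# named "biosignals_coca_base" exists under model_configs/, and the checkpoint
# path below is hypothetical.
if __name__ == "__main__":
    model = create_model("biosignals_coca_base", precision="fp32", device="cpu")
    incompatible = load_checkpoint(model, "checkpoints/biosignals_coca_base.pt")
    # With strict=False, mismatched keys are reported rather than raised.
    print("missing keys:", incompatible.missing_keys)
    print("unexpected keys:", incompatible.unexpected_keys)

    # Tokenizer context length is read from the model's text_cfg by default.
    tokenizer = get_tokenizer("biosignals_coca_base")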