import os
import sys
import json
from typing import Any, Optional
import importlib
import numpy as np
import mlflow
import torch
from pydantic import BaseModel, ConfigDict
import pickle

from mai.sigproc.sigproc_composer import SigprocComposer
from mai.sigproc.datamodel import SigprocConfig


# Duplicated with config.datamodel. See: https://git.medicalai.com:50001/team-ai/solver/solver2/-/issues/526#note_148134
class MAIBaseModel(BaseModel):
    """Pydantic model that also accepts and exposes undeclared fields.

    Keys not declared as model fields are written straight into the instance
    ``__dict__`` so they are reachable as plain attributes. ``dict`` values
    (and ``dict`` items directly inside a ``list``) are wrapped in
    ``MAIBaseModel`` so nested keys support attribute access too.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, **data: Any):
        super().__init__(**data)
        # Store fields that are not explicitly declared on the model.
        # ``model_fields`` is read from the class: accessing it through the
        # instance is deprecated in recent Pydantic releases.
        declared_fields = type(self).model_fields
        for key, value in data.items():
            if key not in declared_fields:
                self.__dict__[key] = self._convert_to_model(value)

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails.
        if name in self.__dict__:
            return self.__dict__[name]
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def __setattr__(self, name: str, value: Any) -> None:
        # Declared fields go through Pydantic; anything else is stored
        # directly, converting dicts to models as in ``__init__``.
        if name in type(self).model_fields:
            super().__setattr__(name, value)
        else:
            self.__dict__[name] = self._convert_to_model(value)

    def _convert_to_model(self, value: Any) -> Any:
        """Wrap dicts (and dict items directly inside a list) in MAIBaseModel."""
        if isinstance(value, dict):
            return MAIBaseModel(**value)
        elif isinstance(value, list):
            return [
                self._convert_to_model(item) if isinstance(item, dict) else item
                for item in value
            ]
        return value


class ActivationConfig(BaseModel):
    # Name of the activation and its keyword parameters.
    name: str
    params: Optional[dict] = {}


class HeadConfig(BaseModel):
    # One head of a multi-task network.
    output_size: int
    activation: ActivationConfig
    loss_idx: int


class NetworkParams(MAIBaseModel):
    # Constructor parameters passed to the network class; undeclared keys are
    # kept as extra attributes (see MAIBaseModel).
    num_leads: int
    output_size: int
    activation: Optional[ActivationConfig] = None
    task: Optional[str] = "binary"
    scope: Optional[str] = None
    code_from_local: Optional[bool] = False
    num_aux: Optional[int] = 0
    multitask_head: Optional[list[HeadConfig]] = None


class NetworkConfig(BaseModel):
    name: str
    params: Optional[NetworkParams] = None
    code: Optional[str] = None
    weight: Optional[dict] = None


class BasePythonModelV3(mlflow.pyfunc.PythonModel):
    """
    BasePythonModelV3 is a base class for all Python models.

    Covers every model that follows the pipeline
    (input array) -> single preprocessor -> (input tensor) -> model -> (output) -> activation.
    """

    def load_context(self, context):
        """Load the sigproc config, the network and the optional calibrator.

        The target device comes from the ``DEVICE`` environment variable
        (default ``"cpu"``). Artifacts read from ``context.artifacts``:
        ``sigproc_json_path``, ``params_json_path``, ``network_weights_path``
        and, optionally, ``calib_model_path``.
        """
        self.device = os.getenv("DEVICE", "cpu")
        self.calib_model_path = context.artifacts.get("calib_model_path")

        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with open(context.artifacts["sigproc_json_path"], "r", encoding="utf-8") as f:
            self.sigproc_config = json.load(f)
        with open(context.artifacts["params_json_path"], "r", encoding="utf-8") as f:
            params = json.load(f)

        self._load_network(params["network"], context.artifacts["network_weights_path"])
        self._load_calibrator()

    def _load_network(self, network_config, network_weight_path):
        """Instantiate the network named in ``network_config`` and load its weights.

        Args:
            network_config: dict with ``"name"`` (used both as the module name
                to import and the class name inside that module) and
                ``"params"`` (keyword arguments for ``NetworkParams``).
            network_weight_path: file loaded with ``torch.load``; its content is
                merged into the network's state dict, so it is presumably a
                (possibly partial) state dict — TODO confirm.
        """
        net_name = network_config["name"]
        net_params = NetworkParams(**network_config["params"])
        net_module = self._import_module(net_name)
        net_cls = getattr(net_module, net_name)
        self.network = net_cls(net_params)

        # load weight: merge the checkpoint over the freshly initialized state,
        # so entries missing from the checkpoint keep their initial values.
        model_state = self.network.state_dict()
        model_state.update(torch.load(network_weight_path, map_location=self.device))
        self.network.load_state_dict(model_state)
        self.network.to(self.device).eval()

        # Inputs are cast to the dtype of the network's (first) parameter.
        self.dtype = next(self.network.parameters()).dtype

    def _load_calibrator(self):
        """Load the optional calibrator pickled at ``self.calib_model_path``.

        The unpickled object is indexed with ``"model"`` to get the calibrator.
        ``self.calibrator`` is ``None`` when no path is configured.
        """
        if self.calib_model_path:
            # NOTE(security): pickle can execute arbitrary code on load; only
            # use calibrator artifacts from a trusted source.
            with open(self.calib_model_path, "rb") as f:
                model = pickle.load(f)
            self.calibrator = model["model"]
        else:
            self.calibrator = None

    # See https://git.medicalai.com:50001/team-ai/solver/solver2/-/issues/431#note_169113
    def _import_module(self, module_name):
        """Import ``module_name``, reloading it if it is already in ``sys.modules``."""
        if module_name in sys.modules:
            return importlib.reload(sys.modules[module_name])
        else:
            return importlib.import_module(module_name)

    def predict(self, context, model_inputs: torch.Tensor):
        """Run the network, the activation and the optional calibrator on a batch.

        Args:
            context: MLflow context (not used here).
            model_inputs: 3-D tensor ``(batch_size, n_leads, length)``; moved to
                ``self.device`` and cast to ``self.dtype``.

        Returns:
            numpy array of activation outputs. With a calibrator, the calibrated
            values ``p`` are returned as ``np.stack([1 - p, p]).T``.

        Raises:
            ValueError: if ``model_inputs`` is not 3-D.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if model_inputs.dim() != 3:
            raise ValueError(
                "BasePythonModelV3 expect 3D tensor as input: (batch_size, n_leads, length)"
            )
        with torch.inference_mode():
            model_inputs = model_inputs.to(self.device, self.dtype)
            self.network.eval()
            model_output = self.network(model_inputs)
            # When the network returns a tuple/list, only the first element is
            # used as the logit.
            if isinstance(model_output, (tuple, list)):
                logit = model_output[0]
            else:
                logit = model_output
            output = self.network.activation(logit).cpu().numpy()

            if self.calibrator is not None:
                output = self.calibrator.transform(output.astype(dtype=np.float64))
                output = np.stack([1 - output, output]).T
            return output

    def preprocess(self, signal_dicts: list[dict], sigproc_config: Optional[dict] = None):
        """Convert raw signal dicts into a batched input tensor for ``predict``.

        signal_dicts: list[dict] = [signal_dict, ...]
        signal_dict: dict = {
            "ecg": {
                "sampling_rate": 250,
                "waveform": {
                    "data": {
                        "I": [1,2,3,4,5,6,7,8,9,10,11,12],
                        "II": [1,2,3,4,5,6,7,8,9,10,11,12],
                        "III": [1,2,3,4,5,6,7,8,9,10,11,12],
                    },
                }
            },
            "ppg": {
                "sampling_rate": 100,
                "waveform": {
                    "data": {
                        "IR": [1,2,3,4,5,6,7,8,9,10,11,12],
                    },
                }
            }
        }

        Args:
            signal_dicts: records to preprocess, shaped as above.
            sigproc_config: per-signal preprocessing config; defaults to the one
                loaded in ``load_context``.

        Returns:
            Tensor on ``self.device`` with dtype ``self.dtype``: the per-record
            stacks concatenated along dim 0.

        Example1:
            model_inputs = [signal_dict, signal_dict]
            loaded_model = mlflow.pyfunc.load_model(model_uri=model_uri)
            model_inputs = loaded_model.unwrap_python_model().preprocess(model_inputs)
            model_outputs = loaded_model.predict(model_inputs)
        """
        if sigproc_config is None:
            sigproc_config = self.sigproc_config
        processor = self.load_sigprocessor(sigproc_config)

        batch_input = []
        for signal_dict in signal_dicts:
            single_input = []
            for signal_name, signal_data in signal_dict.items():
                sampling_rate = signal_data["sampling_rate"]
                data = signal_data["waveform"]["data"]
                data, _ = processor[signal_name](data, sampling_rate)
                # One row per channel, in the processor's output order.
                data = np.array(list(data.values()))
                single_input.append(
                    torch.from_numpy(data).to(device=self.device, dtype=self.dtype)
                )
            # NOTE(review): torch.stack requires every signal of one record to
            # have the same shape, so mixed signals with different channel
            # counts (e.g. the ecg/ppg example above) would raise here — confirm
            # the intended multi-signal layout.
            batch_input.append(torch.stack(single_input))
        # All tensors already live on self.device with self.dtype, so the
        # concatenation does too.
        return torch.cat(batch_input, dim=0)

    @staticmethod
    def load_sigprocessor(preproc_config):
        """Build one signal processor per signal name.

        Example:
            preproc_config: dict = {
                "ecg": [
                    {"name": "Fill_lead", "params": {}}
                ],
                "ppg": [
                    {"name": "Bandpass", "params": {}}
                ],
            }

        Returns:
            dict mapping each signal name to the processor created by
            ``SigprocComposer().create``.
        """
        processor = {}
        # Distinct loop name: the original shadowed the ``preproc_config`` argument.
        for wave_name, wave_steps in preproc_config.items():
            config = SigprocConfig(wave_steps)
            processor[wave_name] = SigprocComposer().create(config)
        return processor