| import os |
| import sys |
| import json |
| from typing import Any, Optional |
| import importlib |
| import numpy as np |
| import mlflow |
| import torch |
| from pydantic import BaseModel, ConfigDict |
| import pickle |
| from mai.sigproc.sigproc_composer import SigprocComposer |
| from mai.sigproc.datamodel import SigprocConfig |
|
|
|
|
| |
class MAIBaseModel(BaseModel):
    """Pydantic model that accepts undeclared keys and exposes them as attributes.

    Keys that are not declared as fields are written to the instance
    ``__dict__``. A ``dict`` value is wrapped in a nested ``MAIBaseModel``;
    ``dict`` items found directly inside a ``list`` are wrapped the same way,
    and every other value is stored unchanged.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, **data: Any) -> None:
        """Validate the declared fields, then store every undeclared key.

        Args:
            **data: Field values plus any additional, undeclared keys.
        """
        super().__init__(**data)

        # Read the field table from the class: access through the instance is
        # deprecated by Pydantic 2.11.
        declared_fields = type(self).model_fields
        for key, value in data.items():
            if key not in declared_fields:
                self.__dict__[key] = self._convert_to_model(value)

    def __getattr__(self, name: str) -> Any:
        """Return ``name`` from the instance ``__dict__``.

        Raises:
            AttributeError: If ``name`` is not stored in ``__dict__``.
        """
        try:
            return self.__dict__[name]
        except KeyError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign a declared field through pydantic, anything else directly.

        Undeclared names skip ``BaseModel.__setattr__`` and are written to
        ``__dict__`` after passing through ``_convert_to_model``.
        """
        if name in type(self).model_fields:
            super().__setattr__(name, value)
        else:
            self.__dict__[name] = self._convert_to_model(value)

    def _convert_to_model(self, value: Any) -> Any:
        """Wrap ``dict`` values (and ``dict`` items of a list) in ``MAIBaseModel``.

        Only the first level of a list is inspected: non-``dict`` items,
        including nested lists, are kept as they are.
        """
        if isinstance(value, dict):
            return MAIBaseModel(**value)
        if isinstance(value, list):
            return [
                self._convert_to_model(item) if isinstance(item, dict) else item
                for item in value
            ]
        return value
|
|
|
|
class ActivationConfig(BaseModel):
    """Configuration entry for an activation: a name plus optional parameters."""

    # Name of the activation; how it is resolved is not visible in this file.
    name: str
    # Parameter mapping for the activation; defaults to an empty dict.
    params: Optional[dict] = {}
|
|
|
|
class HeadConfig(BaseModel):
    """Configuration of one entry of ``NetworkParams.multitask_head``."""

    # Output size of this head.
    output_size: int
    # Activation configuration attached to this head.
    activation: ActivationConfig
    # NOTE(review): presumably the index of the loss paired with this head -
    # confirm against the training code.
    loss_idx: int
|
|
|
|
class NetworkParams(MAIBaseModel):
    """Parameters handed to a network constructor.

    Built on ``MAIBaseModel``, so keys beyond the fields declared here are
    accepted and stored as well.
    """

    # Required sizes.
    num_leads: int
    output_size: int

    # Optional settings with their defaults.
    activation: Optional[ActivationConfig] = None
    task: Optional[str] = "binary"
    scope: Optional[str] = None
    code_from_local: Optional[bool] = False
    num_aux: Optional[int] = 0
    # Per-head configuration; None when no multitask heads are configured.
    multitask_head: Optional[list[HeadConfig]] = None
|
|
|
|
class NetworkConfig(BaseModel):
    """Description of a network: its name plus optional params, code and weight."""

    # Network name (required).
    name: str
    # Constructor parameters for the network.
    params: Optional[NetworkParams] = None
    # NOTE(review): presumably source code or a path to it - confirm with the
    # code that fills this field.
    code: Optional[str] = None
    # NOTE(review): presumably a description of where the weights live -
    # confirm with the code that fills this field.
    weight: Optional[dict] = None
|
|
|
|
class BasePythonModelV3(mlflow.pyfunc.PythonModel):
    """Base class for all Python models.

    Covers every model that follows the sequence
    (input array) -> single preprocessor -> (input tensor) -> model ->
    (output) -> activation.
    """

    def load_context(self, context):
        """Load the signal-processing config, the network and the calibrator.

        The target device is read from the ``DEVICE`` environment variable
        (``"cpu"`` when unset).

        Args:
            context: MLflow model context. Its ``artifacts`` must provide
                ``sigproc_json_path``, ``params_json_path`` and
                ``network_weights_path``; ``calib_model_path`` is optional.
        """
        self.device = os.getenv("DEVICE", "cpu")
        artifacts = context.artifacts
        self.calib_model_path = artifacts.get("calib_model_path")

        # JSON is read as UTF-8 explicitly so the result does not depend on
        # the platform's locale encoding.
        with open(artifacts["sigproc_json_path"], "r", encoding="utf-8") as f:
            self.sigproc_config = json.load(f)

        with open(artifacts["params_json_path"], "r", encoding="utf-8") as f:
            params = json.load(f)

        self._load_network(params["network"], artifacts["network_weights_path"])
        self._load_calibrator()

    def _load_network(self, network_config, network_weight_path):
        """Build the network, load its weights and remember its dtype.

        The module named ``network_config["name"]`` is imported and the
        attribute of the same name is instantiated with a ``NetworkParams``
        built from ``network_config["params"]``.

        Args:
            network_config: Mapping with the keys ``"name"`` and ``"params"``.
            network_weight_path: Path of the file passed to ``torch.load``.
        """
        net_name = network_config["name"]
        net_params = NetworkParams(**network_config["params"])

        net_module = self._import_module(net_name)
        net_cls = getattr(net_module, net_name)
        self.network = net_cls(net_params)

        # Start from the freshly built state dict and overlay the checkpoint,
        # so keys absent from the checkpoint keep their initial values.
        model_state = self.network.state_dict()
        # NOTE(review): torch.load unpickles the file - load trusted
        # checkpoints only (consider weights_only=True where supported).
        model_state.update(torch.load(network_weight_path, map_location=self.device))
        self.network.load_state_dict(model_state)
        self.network.to(self.device).eval()

        # Inputs are later cast to the dtype of the first network parameter.
        param_sample = next(self.network.parameters())
        self.dtype = param_sample.dtype

    def _load_calibrator(self):
        """Set ``self.calibrator`` from the pickled artifact, or to ``None``.

        The pickle is expected to hold a mapping whose ``"model"`` entry is
        the calibrator.
        """
        if not self.calib_model_path:
            self.calibrator = None
            return

        # NOTE(review): pickle.load executes arbitrary code from the file -
        # the calibrator artifact must come from a trusted source.
        with open(self.calib_model_path, "rb") as f:
            model = pickle.load(f)
        self.calibrator = model["model"]

    def _import_module(self, module_name):
        """Import ``module_name``, reloading it when it is already imported."""
        if module_name in sys.modules:
            return importlib.reload(sys.modules[module_name])
        return importlib.import_module(module_name)

    def predict(self, context, model_inputs: torch.Tensor):
        """Run the network on a batch and return the activated output.

        Args:
            context: MLflow model context (not used here).
            model_inputs: 3D tensor ``(batch_size, n_leads, length)``.

        Returns:
            NumPy array with the output of ``self.network.activation``. When
            a calibrator is loaded, the calibrated values ``p`` are returned
            stacked as ``[1 - p, p]`` and transposed.

        Raises:
            AssertionError: If ``model_inputs`` is not 3-dimensional.
        """
        with torch.inference_mode():
            model_inputs = model_inputs.to(self.device, self.dtype)

            assert (
                model_inputs.dim() == 3
            ), "BasePythonModelV3 expect 3D tensor as input: (batch_size, n_leads, length)"

            self.network.eval()
            model_output = self.network(model_inputs)
            # Networks may return several values; the first one is the logit.
            if isinstance(model_output, (tuple, list)):
                logit = model_output[0]
            else:
                logit = model_output

            output = self.network.activation(logit).cpu().numpy()
            # Compare with None explicitly: the truth value of a calibrator
            # object is not a reliable "is configured" signal.
            if self.calibrator is not None:
                output = self.calibrator.transform(output.astype(dtype=np.float64))
                # NOTE(review): assumed to apply to calibrated output only -
                # confirm against the original indentation.
                output = np.stack([1 - output, output]).T
            return output

    def preprocess(self, signal_dicts: list[dict], sigproc_config: Optional[dict] = None):
        """Turn raw signal dictionaries into the tensor consumed by ``predict``.

        signal_dicts: list[dict] = [signal_dict,...]

        signal_dict: dict = {
            "ecg": {
                "sampling_rate": 250,
                "waveform": {
                    "data":{
                        "I":[1,2,3,4,5,6,7,8,9,10,11,12],
                        "II":[1,2,3,4,5,6,7,8,9,10,11,12],
                        "III":[1,2,3,4,5,6,7,8,9,10,11,12],
                    },
                }
            },
            "ppg": {
                "sampling_rate": 100,
                "waveform": {
                    "data":{
                        "IR":[1,2,3,4,5,6,7,8,9,10,11,12],
                    },
                }
            }
        }

        Example1:
            model_inputs = [signal_dict,signal_dict]
            loaded_model = mlflow.pyfunc.load_model(model_uri=model_uri)
            model_inputs = loaded_model.unwrap_python_model().preprocess(model_inputs)
            model_outputs = loaded_model.predict(model_inputs)

        Args:
            signal_dicts: One dictionary per sample, keyed by signal name.
            sigproc_config: Processing config per signal name; the config
                loaded in ``load_context`` is used when omitted.

        Returns:
            Tensor on ``self.device`` with dtype ``self.dtype``: the signals
            of each sample are stacked, then all samples are concatenated
            along dimension 0.
        """
        if sigproc_config is None:
            sigproc_config = self.sigproc_config
        processor = self.load_sigprocessor(sigproc_config)

        batch_input = []
        for signal_dict in signal_dicts:
            single_input = []
            for signal_name, signal_data in signal_dict.items():
                sampling_rate = signal_data["sampling_rate"]
                data = signal_data["waveform"]["data"]
                data, _ = processor[signal_name](data, sampling_rate)
                data = np.array(list(data.values()))
                # Move and cast once here; stack/cat below keep device and dtype.
                single_input.append(
                    torch.from_numpy(data).to(device=self.device, dtype=self.dtype)
                )
            batch_input.append(torch.stack(single_input))

        return torch.cat(batch_input, dim=0)

    @staticmethod
    def load_sigprocessor(preproc_config):
        """Create one signal processor per signal name.

        Example:
            preproc_config: dict = {
                "ecg": [
                    {"name": "Fill_lead", "params": {}}
                ],
                "ppg": [
                    {"name": "Bandpass", "params": {}}
                ],
            }

        Args:
            preproc_config: Mapping from signal name to its processing config.

        Returns:
            Dict mapping each signal name to the object returned by
            ``SigprocComposer().create``.
        """
        processor = {}
        # The loop variable has its own name so it no longer shadows the
        # ``preproc_config`` parameter.
        for wave_name, wave_config in preproc_config.items():
            config = SigprocConfig(wave_config)
            processor[wave_name] = SigprocComposer().create(config)

        return processor
|
|