| import torch |
| from BasePythonModelV3 import BasePythonModelV3 |
|
|
|
|
class ECGRecoverPythonModelV3(BasePythonModelV3):
    """Inference wrapper that runs ``self.network`` on ``[digitized, mask]`` inputs.

    Reads ``self.network``, ``self.device`` and ``self.dtype``. These are not
    set in this class — presumably ``BasePythonModelV3`` provides them when the
    model is loaded (TODO confirm).
    """

    def predict(self, context, model_inputs: list[torch.Tensor]):
        """Validate inputs, run the network, and return the activated output.

        Args:
            context: Not used by this method. It is kept for the caller-facing
                signature, which looks like the MLflow ``PythonModel.predict``
                convention (TODO confirm).
            model_inputs: Exactly two 3-D tensors ``[digitized, mask]``, each
                shaped ``(batch_size, n_leads, length)``. A tuple is accepted
                as well as a list.

        Returns:
            numpy.ndarray: ``self.network.activation(logit)`` moved to the CPU.
            A ``torch.bfloat16`` result is upcast to float32 first, because
            numpy has no bfloat16 dtype.

        Raises:
            AssertionError: If ``model_inputs`` is not a list/tuple of exactly
                two 3-D ``torch.Tensor`` objects.
        """
        message = (
            "expect list of 2 tensors as input: [digitized, mask], "
            "each tensor shape: (batch_size, n_leads, length)"
        )
        inputs_ok = (
            isinstance(model_inputs, (list, tuple))
            and len(model_inputs) == 2
            and all(
                # Check the type first so a non-tensor element reports the
                # validation error instead of an AttributeError from .dim().
                isinstance(each_input, torch.Tensor) and each_input.dim() == 3
                for each_input in model_inputs
            )
        )
        if not inputs_ok:
            # Explicit raise instead of ``assert``: asserts are stripped under
            # ``python -O``. AssertionError is kept so existing callers that
            # catch it continue to work.
            raise AssertionError(message)

        with torch.inference_mode():
            model_inputs = [
                each_input.to(self.device, self.dtype) for each_input in model_inputs
            ]
            self.network.eval()
            logit = self.network(model_inputs)
            output = self.network.activation(logit)
            # Tensor.numpy() cannot represent bfloat16, so upcast to float32.
            if output.dtype == torch.bfloat16:
                output = output.float()
            return output.cpu().numpy()
|
|