import torch

from BasePythonModelV3 import BasePythonModelV3


class ECGRecoverPythonModelV3(BasePythonModelV3):
    """Prediction wrapper that runs ``self.network`` on a pair of input tensors.

    NOTE(review): ``self.device``, ``self.dtype`` and ``self.network`` are not
    defined in this file; presumably they are provided by
    ``BasePythonModelV3`` -- confirm against the base class.
    """

    def predict(self, context, model_inputs: list[torch.Tensor]):
        """Run the network on ``model_inputs`` and return the activated output.

        Args:
            context: Not used by this method; accepted to keep the existing
                call signature.
            model_inputs: A list of exactly two 3-D tensors,
                ``[digitized, mask]``, each documented by the validation
                message as having shape ``(batch_size, n_leads, length)``.

        Returns:
            The result of ``self.network.activation`` applied to the network
            output, moved to the CPU and converted to a NumPy array.

        Raises:
            AssertionError: If ``model_inputs`` is not a list of two 3-D
                tensors.
        """
        # Validate with an explicit raise instead of an ``assert`` statement:
        # ``assert`` is stripped when Python runs with ``-O``, which would let
        # malformed input reach the network. AssertionError is kept so that
        # callers relying on the previous exception type keep working.
        is_valid = (
            isinstance(model_inputs, list)
            and len(model_inputs) == 2
            and all(
                # The isinstance guard makes non-tensor elements fail with the
                # validation message below rather than an AttributeError.
                isinstance(each_input, torch.Tensor) and each_input.dim() == 3
                for each_input in model_inputs
            )
        )
        if not is_valid:
            raise AssertionError(
                "expect list of 2 tensors as input: [digitized, mask], each tensor shape: (batch_size, n_leads, length)"
            )

        with torch.inference_mode():
            # Move/cast every input to the device and dtype stored on the model.
            model_inputs = [
                each_input.to(self.device, self.dtype) for each_input in model_inputs
            ]
            # Put the network in evaluation mode before the forward pass.
            self.network.eval()
            logit = self.network(model_inputs)
            output = self.network.activation(logit).cpu().numpy()
            return output