File size: 824 Bytes
12f8999 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | import torch
from BasePythonModelV3 import BasePythonModelV3
class ECGRecoverPythonModelV3(BasePythonModelV3):
    """Inference wrapper that runs ``self.network`` on a ``[digitized, mask]`` pair.

    ``self.network``, ``self.device`` and ``self.dtype`` are not defined in this
    class; presumably ``BasePythonModelV3`` provides them -- TODO confirm.
    """

    def predict(self, context, model_inputs: list[torch.Tensor]):
        """Run the network on two 3-D input tensors and return activated outputs.

        Args:
            context: Not used by this method (presumably required by the
                base-class / serving ``predict`` signature -- confirm).
            model_inputs: A list of exactly two tensors, ``[digitized, mask]``,
                each of shape ``(batch_size, n_leads, length)``.

        Returns:
            ``self.network.activation(logits)`` moved to CPU and converted with
            ``Tensor.numpy()``, where ``logits = self.network(model_inputs)``.

        Raises:
            AssertionError: If ``model_inputs`` is not a list of exactly two
                ``torch.Tensor`` objects that each have 3 dimensions.
        """
        # Validate with an explicit raise instead of an ``assert`` statement:
        # ``assert`` is removed under ``python -O``, which would let malformed
        # input reach the network unchecked. AssertionError is kept so callers
        # observe the same exception type as before. The isinstance guard makes
        # non-tensor elements fail with this message rather than an
        # AttributeError from ``.dim()``.
        is_valid = (
            isinstance(model_inputs, list)
            and len(model_inputs) == 2
            and all(
                isinstance(each_input, torch.Tensor) and each_input.dim() == 3
                for each_input in model_inputs
            )
        )
        if not is_valid:
            raise AssertionError(
                "expect list of 2 tensors as input: [digitized, mask], "
                "each tensor shape: (batch_size, n_leads, length)"
            )

        with torch.inference_mode():
            # Both inputs (including the mask) are cast to the model's device
            # and dtype.
            tensors = [
                each_input.to(self.device, self.dtype) for each_input in model_inputs
            ]
            # Put the network in eval mode before the forward pass (assumes
            # ``self.network`` is a ``torch.nn.Module`` -- confirm).
            self.network.eval()
            # The network receives the whole two-element list, not unpacked args.
            logits = self.network(tensors)
            return self.network.activation(logits).cpu().numpy()
|