File size: 824 Bytes
12f8999
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
from BasePythonModelV3 import BasePythonModelV3


class ECGRecoverPythonModelV3(BasePythonModelV3):
    """Inference wrapper that runs ``self.network`` on ``[digitized, mask]`` inputs.

    ``self.network``, ``self.device`` and ``self.dtype`` are not defined in this
    class; presumably ``BasePythonModelV3`` provides them -- confirm there.
    """

    @staticmethod
    def _validate_inputs(model_inputs) -> None:
        """Raise ValueError unless ``model_inputs`` holds exactly two 3-D tensors."""
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let malformed input reach the network.
        # The isinstance check also avoids an AttributeError on ``.dim()``
        # when an element is not a tensor.
        is_valid = (
            isinstance(model_inputs, (list, tuple))
            and len(model_inputs) == 2
            and all(
                isinstance(each_input, torch.Tensor) and each_input.dim() == 3
                for each_input in model_inputs
            )
        )
        if not is_valid:
            raise ValueError(
                "expect list of 2 tensors as input: [digitized, mask], "
                "each tensor shape: (batch_size, n_leads, length)"
            )

    def predict(self, context, model_inputs: list[torch.Tensor]):
        """Run the network on ``[digitized, mask]`` and return the activated output.

        Args:
            context: Not used by this method. Presumably the serving framework's
                model context -- confirm against the base class.
            model_inputs: A list (or tuple) of exactly two 3-D tensors,
                ``[digitized, mask]``, each of shape
                ``(batch_size, n_leads, length)``.

        Returns:
            A NumPy array holding ``self.network.activation(self.network(inputs))``
            copied to CPU. bfloat16 results are upcast to float32, because
            NumPy has no bfloat16 type.

        Raises:
            ValueError: If ``model_inputs`` is not a list/tuple of exactly two
                3-D tensors.
        """
        self._validate_inputs(model_inputs)
        with torch.inference_mode():
            # Both tensors, including the mask, are cast to the model dtype.
            model_inputs = [
                each_input.to(self.device, self.dtype) for each_input in model_inputs
            ]

            self.network.eval()
            logit = self.network(model_inputs)
            output = self.network.activation(logit).cpu()
            # Tensor.numpy() raises TypeError for bfloat16; upcast only then.
            if output.dtype == torch.bfloat16:
                output = output.float()
            output = output.numpy()
        return output