dr-studio / model /code /ECGRecoverPythonModelV3.py
wogh2012's picture
Upload Dash Docker Space
12f8999 verified
import torch
from BasePythonModelV3 import BasePythonModelV3
class ECGRecoverPythonModelV3(BasePythonModelV3):
    """Inference wrapper around the ECG recovery network.

    Only ``predict`` is defined here. ``self.network``, ``self.device`` and
    ``self.dtype`` are read but not set in this class; presumably they are
    initialised by ``BasePythonModelV3`` -- confirm there.
    """

    def predict(self, context, model_inputs: list[torch.Tensor]):
        """Run the recovery network on a ``[digitized, mask]`` tensor pair.

        Args:
            context: Not used by this method; accepted to keep the
                ``predict(context, model_inputs)`` signature (presumably the
                MLflow ``PythonModel`` convention -- confirm in the base class).
            model_inputs: A list of exactly two 3-D tensors,
                ``[digitized, mask]``, each documented as
                ``(batch_size, n_leads, length)``.

        Returns:
            numpy.ndarray: ``self.network.activation`` applied to the network
            logits, moved to CPU. A ``bfloat16`` result is upcast to
            ``float32`` because NumPy cannot represent ``bfloat16``.

        Raises:
            ValueError: If ``model_inputs`` is not a list of two 3-D tensors.
        """
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, which would silently disable this validation. The
        # isinstance test also avoids an AttributeError on ``.dim()`` when an
        # element is not a tensor.
        if not (
            isinstance(model_inputs, list)
            and len(model_inputs) == 2
            and all(
                isinstance(each_input, torch.Tensor) and each_input.dim() == 3
                for each_input in model_inputs
            )
        ):
            raise ValueError(
                "expect list of 2 tensors as input: [digitized, mask], "
                "each tensor shape: (batch_size, n_leads, length)"
            )

        with torch.inference_mode():
            # Both inputs (including the mask) are cast to the model's
            # device and dtype before the forward pass.
            model_inputs = [
                each_input.to(self.device, self.dtype) for each_input in model_inputs
            ]
            self.network.eval()
            logit = self.network(model_inputs)
            output = self.network.activation(logit).cpu()
            # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on it;
            # upcast only that case and leave every other dtype untouched.
            if output.dtype == torch.bfloat16:
                output = output.float()
            return output.numpy()