Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| from ..utils.load_model import load_model | |
| class StitchNetwork: | |
| def __init__(self, model_path, device="cuda"): | |
| kwargs = { | |
| "module_name": "StitchingNetwork", | |
| } | |
| self.model, self.model_type = load_model(model_path, device=device, **kwargs) | |
| self.device = device | |
| def __call__(self, kp_source, kp_driving): | |
| if self.model_type == "onnx": | |
| pred = self.model.run(None, {"kp_source": kp_source, "kp_driving": kp_driving})[0] | |
| elif self.model_type == "tensorrt": | |
| self.model.setup({"kp_source": kp_source, "kp_driving": kp_driving}) | |
| self.model.infer() | |
| pred = self.model.buffer["out"][0].copy() | |
| elif self.model_type == 'pytorch': | |
| with torch.no_grad(): | |
| pred = self.model( | |
| torch.from_numpy(kp_source).to(self.device), | |
| torch.from_numpy(kp_driving).to(self.device) | |
| ).cpu().numpy() | |
| else: | |
| raise ValueError(f"Unsupported model type: {self.model_type}") | |
| return pred | |