| | import sys |
| | from pathlib import Path |
| | import subprocess |
| | import torch |
| | from .. import logger |
| |
|
| | from ..utils.base_model import BaseModel |
| |
|
# Location of the vendored third-party "example" package, expressed relative
# to this file's directory (the ".." segments are kept unresolved).
example_path = Path(__file__).parent.joinpath(
    "..", "..", "third_party", "example"
)
# Make the third-party package importable by appending it to the module
# search path.
sys.path.append(str(example_path))
| |
|
| | |
| |
|
# Module-level torch device: CUDA when torch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| |
|
class Example(BaseModel):
    """Wrap an example feature extractor behind the ``BaseModel`` interface.

    ``_init`` loads a checkpoint from ``example_path / "checkpoints"`` into
    ``self.net``; ``_forward`` calls ``self.net`` on ``data["image"]`` and
    returns the keypoints, scores and descriptors it produces.

    NOTE(review): ``self.net`` is still bound to the builtin ``callable``
    placeholder, which has no ``load_state_dict`` method, so ``_init`` cannot
    complete until a real network is constructed in its place.
    """

    # Only "model_name" is read inside this class. The remaining entries are
    # presumably consumed by BaseModel or by the wrapped network -- verify.
    default_conf = {
        "name": "example",
        "keypoint_threshold": 0.1,
        "max_keypoints": 2000,
        "model_name": "model.pth",
    }
    required_inputs = ["image"]

    def _init(self, conf):
        """Create the network and load its weights from the checkpoint.

        Args:
            conf: Mapping that provides ``"model_name"``, the checkpoint file
                name looked up under ``example_path / "checkpoints"``.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        model_path = example_path / "checkpoints" / f'{conf["model_name"]}'
        if not model_path.exists():
            message = f"No model found at {model_path}"
            logger.info(message)
            # Stop here with the resolved path instead of continuing into
            # torch.load, which would fail on the same missing file anyway.
            raise FileNotFoundError(message)

        # TODO: replace this placeholder with the real network instance; the
        # builtin `callable` does not provide `load_state_dict`.
        self.net = callable

        # torch.load unpickles the file, so only use checkpoints from a
        # trusted source. Weights are mapped to the CPU on load.
        state_dict = torch.load(model_path, map_location="cpu")
        # The checkpoint is expected to store the weights under "model_state".
        self.net.load_state_dict(state_dict["model_state"])
        logger.info("Load example model done.")

    def _forward(self, data):
        """Run the network on the input image.

        Args:
            data: Mapping that contains the ``"image"`` entry passed to the
                network. NOTE(review): expected layout, colour mode and value
                range are not visible here -- confirm against the network.

        Returns:
            dict: ``"keypoints"``, ``"scores"`` and ``"descriptors"`` exactly
            as returned by ``self.net``.
        """
        image = data["image"]

        # The network is expected to return a 3-tuple in this order.
        keypoints, scores, descriptors = self.net(image)

        return {
            "keypoints": keypoints,
            "scores": scores,
            "descriptors": descriptors,
        }
| |
|