| import torch | |
| import logging | |
| import os | |
| from ..utils.utils import convert_to_torchscript, convert_to_onnx | |
| from ..models.melanoma_classifier import MelanomaClassifier | |
| def convert_checkpoint_to_torchscript_and_onnx(checkpoint_path, | |
| model_class, | |
| output_dir, | |
| input_size=224): | |
| """ | |
| Convert a saved checkpoint (.pth file) to TorchScript and ONNX | |
| """ | |
| if not os.path.exists(output_dir): | |
| os.makedirs(output_dir) | |
| logging.info(f"Loading checkpoint from {checkpoint_path}") | |
| checkpoint = torch.load(checkpoint_path, map_location='cpu') | |
| if isinstance(checkpoint, dict) and 'args' in checkpoint: | |
| chkpt_args = checkpoint["args"] | |
| num_classes = chkpt_args.num_classes | |
| model_name = chkpt_args.model | |
| pretrained = False | |
| in_22k = chkpt_args.in_22k | |
| if checkpoint['args'].num_groups > 0: | |
| num_classes = chkpt_args.num_classes * chkpt_args.num_groups | |
| model = model_class( | |
| model_name=model_name, | |
| num_classes=num_classes, | |
| pretrained=pretrained, | |
| in_22k=in_22k | |
| ) | |
| else: | |
| model = model_class(num_classes=2) | |
| print(checkpoint.keys()) | |
| if 'model_state_dict' in checkpoint: | |
| state_dict = checkpoint['model_state_dict'] | |
| else: | |
| state_dict = checkpoint['model'] | |
| model.load_state_dict(state_dict) | |
| print(model) | |
| model.eval() | |
| input_tensor = torch.randn(1, 3, input_size, input_size) | |
| logging.info(f"Converting to TorchScript...") | |
| convert_to_torchscript(model, input_tensor, os.path.join(output_dir, "model_torchscript.pt"), True) | |
| logging.info(f"Converting to ONNX...") | |
| convert_to_onnx(model, input_tensor, os.path.join(output_dir, "model_onnx.onnx")) | |
| logging.info(f"Models exported to {output_dir}") | |
| checkpoint_path = r"C:\lumen_melanoma_classification\melanoma-classification\weights\best_model_domain_discriminative.pth" | |
| convert_checkpoint_to_torchscript_and_onnx( | |
| checkpoint_path=checkpoint_path, | |
| model_class=MelanomaClassifier, | |
| input_size=224, | |
| output_dir=r"C:\lumen_melanoma_classification\melanoma-classification\weights\12_best_model_domain_discriminative" | |
| ) |