File size: 2,349 Bytes
dae5c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import logging
import os
from ..utils.utils import convert_to_torchscript, convert_to_onnx
from ..models.melanoma_classifier import MelanomaClassifier

def _build_model_from_checkpoint(checkpoint, model_class):
    """Instantiate ``model_class`` from the settings stored in ``checkpoint``."""
    if not (isinstance(checkpoint, dict) and 'args' in checkpoint):
        # No saved training arguments: fall back to the two-class default.
        return model_class(num_classes=2)

    chkpt_args = checkpoint['args']
    num_classes = chkpt_args.num_classes
    if chkpt_args.num_groups > 0:
        # With groups enabled the head size is num_classes * num_groups.
        num_classes *= chkpt_args.num_groups

    return model_class(
        model_name=chkpt_args.model,
        num_classes=num_classes,
        # Weights are loaded from the checkpoint right after construction.
        pretrained=False,
        in_22k=chkpt_args.in_22k,
    )


def _extract_state_dict(checkpoint):
    """Return the model weights held by ``checkpoint``."""
    if not isinstance(checkpoint, dict):
        raise TypeError(
            "Expected the checkpoint to be a dict, got "
            f"{type(checkpoint).__name__}"
        )

    logging.debug("Checkpoint keys: %s", list(checkpoint.keys()))
    for key in ('model_state_dict', 'model'):
        if key in checkpoint:
            return checkpoint[key]

    # A checkpoint saved as a bare state dict holds nothing but tensors.
    if checkpoint and all(
        isinstance(value, torch.Tensor) for value in checkpoint.values()
    ):
        return checkpoint

    raise KeyError(
        "Checkpoint has neither 'model_state_dict' nor 'model'; "
        f"available keys: {list(checkpoint.keys())}"
    )


def convert_checkpoint_to_torchscript_and_onnx(checkpoint_path,
                                               model_class,
                                               output_dir,
                                               input_size=224):
    """
    Convert a saved checkpoint (.pth file) to TorchScript and ONNX.

    The checkpoint is loaded on the CPU, a model is built from it, and the
    model is handed to ``convert_to_torchscript`` and ``convert_to_onnx``
    with the target paths ``<output_dir>/model_torchscript.pt`` and
    ``<output_dir>/model_onnx.onnx``.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        model_class: Callable returning the model. It is called with
            ``model_name``, ``num_classes``, ``pretrained`` and ``in_22k``
            when the checkpoint stores ``'args'``, otherwise with
            ``num_classes=2`` only.
        output_dir: Directory for the exported files; created if missing.
        input_size: Height and width of the random ``(1, 3, H, W)`` example
            input given to both converters.

    Raises:
        TypeError: If the loaded checkpoint is not a dict.
        KeyError: If the checkpoint is a dict without ``'model_state_dict'``
            or ``'model'`` and is not itself a bare state dict.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    logging.info("Loading checkpoint from %s", checkpoint_path)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    model = _build_model_from_checkpoint(checkpoint, model_class)
    model.load_state_dict(_extract_state_dict(checkpoint))
    logging.debug("Loaded model:\n%s", model)
    # Switch to inference mode before handing the model to the converters.
    model.eval()

    input_tensor = torch.randn(1, 3, input_size, input_size)

    logging.info("Converting to TorchScript...")
    # The trailing positional True is forwarded unchanged; its meaning is
    # defined by convert_to_torchscript, which is not visible here.
    convert_to_torchscript(
        model,
        input_tensor,
        os.path.join(output_dir, "model_torchscript.pt"),
        True,
    )
    logging.info("Converting to ONNX...")
    convert_to_onnx(
        model,
        input_tensor,
        os.path.join(output_dir, "model_onnx.onnx"),
    )

    logging.info("Models exported to %s", output_dir)
   
# Script section: export the checkpoint at the path below using the
# MelanomaClassifier architecture. Both paths are absolute Windows paths
# specific to one machine, and the call runs whenever this module is executed
# or imported.
checkpoint_path = r"C:\lumen_melanoma_classification\melanoma-classification\weights\best_model_domain_discriminative.pth"
convert_checkpoint_to_torchscript_and_onnx(
    checkpoint_path=checkpoint_path,
    model_class=MelanomaClassifier,
    output_dir=r"C:\lumen_melanoma_classification\melanoma-classification\weights\12_best_model_domain_discriminative",
    input_size=224,
)