File size: 3,127 Bytes
ff0e79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""

Model export utilities (ONNX, quantization)

"""

import torch
import torch.onnx
from pathlib import Path
from typing import Tuple


def export_to_onnx(model: torch.nn.Module,
                   save_path: str,
                   input_size: Tuple[int, int] = (384, 384),
                   opset_version: int = 14):
    """Export a PyTorch model to ONNX format.

    Args:
        model: PyTorch model to export (switched to eval mode in place).
        save_path: Destination path for the ``.onnx`` file.
        input_size: Spatial size (H, W) of the dummy input used for tracing.
        opset_version: ONNX opset version to target.
    """
    model.eval()

    # Trace with a single dummy batch placed on the same device as the
    # model's parameters.
    device = next(model.parameters()).device
    example = torch.randn(1, 3, input_size[0], input_size[1]).to(device)

    # NOTE(review): 'features' is declared as a second output but is absent
    # from dynamic_axes, so its batch dimension stays fixed in the exported
    # graph — confirm this is intended.
    torch.onnx.export(
        model,
        example,
        save_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output', 'features'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
    )
    print(f"Model exported to ONNX: {save_path}")

    # Best-effort structural check; skipped when the onnx package is absent.
    try:
        import onnx
    except ImportError:
        print("onnx package not installed, skipping verification")
    else:
        onnx.checker.check_model(onnx.load(save_path))
        print("ONNX model verified successfully")

def export_to_torchscript(model: torch.nn.Module,
                          save_path: str,
                          input_size: Tuple[int, int] = (384, 384)):
    """Export a PyTorch model to TorchScript via tracing.

    Args:
        model: PyTorch model to export (switched to eval mode in place).
        save_path: Destination path for the serialized TorchScript module.
        input_size: Spatial size (H, W) of the dummy input used for tracing.

    Note:
        Uses ``torch.jit.trace`` (not scripting), so data-dependent control
        flow in the model is not captured — only the traced path is saved.
    """
    model.eval()

    # Build the example input on the device the parameters live on.
    device = next(model.parameters()).device
    example = torch.randn(1, 3, input_size[0], input_size[1]).to(device)

    # Trace and serialize in one go.
    torch.jit.trace(model, example).save(save_path)

    print(f"Model exported to TorchScript: {save_path}")


def quantize_model(model: torch.nn.Module,
                   save_path: str,
                   quantization_type: str = 'dynamic'):
    """Quantize a model to int8 for faster CPU inference.

    Args:
        model: PyTorch model to quantize. NOTE: it is moved to CPU in
            place, so the caller's reference is moved as well.
        save_path: Path where the quantized model's state_dict is saved.
        quantization_type: Only 'dynamic' is currently supported.

    Returns:
        The dynamically quantized model.

    Raises:
        ValueError: If ``quantization_type`` is not 'dynamic'.
    """
    model.eval()
    # Quantized kernels only run on CPU, so move the model there first.
    model = model.cpu()

    if quantization_type != 'dynamic':
        raise ValueError(f"Unsupported quantization type: {quantization_type}")

    # Dynamic quantization: weights are stored as int8 and activations are
    # quantized on the fly. Only nn.Linear is targeted — dynamic
    # quantization does not support nn.Conv2d (such entries are silently
    # ignored by quantize_dynamic), so listing conv layers would be
    # misleading.
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8,
    )

    # NOTE: reloading this state_dict requires re-applying quantize_dynamic
    # to an identically constructed model first.
    torch.save(quantized_model.state_dict(), save_path)

    print(f"Quantized model saved to: {save_path}")

    return quantized_model