File size: 2,616 Bytes
51fdac5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""

Model Export Script



Export trained model to ONNX format for deployment.



Usage:

    python scripts/export_model.py --model outputs/checkpoints/best_doctamper.pth --format onnx

"""

import argparse
import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch

from src.config import get_config
from src.models import get_model
from src.utils import export_to_onnx, export_to_torchscript


def parse_args():
    """Parse command-line options for the export script.

    Returns:
        argparse.Namespace with attributes: model (required checkpoint path),
        format ('onnx' | 'torchscript' | 'both'), output (destination
        directory), and config (path to the YAML config file).
    """
    parser = argparse.ArgumentParser(description="Export model for deployment")

    # (flag, keyword options) pairs, registered in the order shown in --help.
    option_specs = [
        ('--model', dict(type=str, required=True,
                         help='Path to model checkpoint')),
        ('--format', dict(type=str, default='onnx',
                          choices=['onnx', 'torchscript', 'both'],
                          help='Export format')),
        ('--output', dict(type=str, default='outputs/exported',
                          help='Output directory')),
        ('--config', dict(type=str, default='config.yaml',
                          help='Path to config file')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def main():
    """Load a checkpoint and export it in the requested deployment format(s).

    Reads CLI arguments, builds the model from the project config, restores
    the checkpoint weights, and writes ONNX and/or TorchScript artifacts to
    the output directory.
    """
    args = parse_args()

    # Config drives both model construction and the export input resolution.
    config = get_config(args.config)

    banner = "=" * 60
    print("\n" + banner)
    print("Model Export")
    print(banner)
    print(f"Model: {args.model}")
    print(f"Format: {args.format}")
    print(banner)

    # Make sure the destination exists before any artifact is written.
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Build the network and restore trained weights. Full training
    # checkpoints keep the weights under 'model_state_dict'; a bare
    # state dict loads directly.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(config).to(device)

    checkpoint = torch.load(args.model, map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    print("Model loaded")

    # Export at the training resolution (square input); 384 is the fallback.
    image_size = config.get('data.image_size', 384)

    if args.format in ('onnx', 'both'):
        export_to_onnx(model, str(output_dir / 'model.onnx'),
                       input_size=(image_size, image_size))

    if args.format in ('torchscript', 'both'):
        export_to_torchscript(model, str(output_dir / 'model.pt'),
                              input_size=(image_size, image_size))

    print("\n" + banner)
    print("Export Complete!")
    print(f"Output: {output_dir}")
    print(banner)


# Run the exporter only when invoked as a script, not on import.
if __name__ == '__main__':
    main()