Spaces:
Sleeping
Sleeping
"""
Model Export Script

Export a trained model to ONNX and/or TorchScript format for deployment.

Usage:
    python scripts/export_model.py --model outputs/checkpoints/best_doctamper.pth --format onnx
"""
| import argparse | |
| import sys | |
| from pathlib import Path | |
| # Add src to path | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| import torch | |
| from src.config import get_config | |
| from src.models import get_model | |
| from src.utils import export_to_onnx, export_to_torchscript | |
def parse_args(argv=None):
    """Parse command-line options for the export script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]`` as before; passing
            a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``model``, ``format``, ``output`` and
        ``config`` attributes.

    Raises:
        SystemExit: if the arguments are invalid (argparse behaviour), e.g.
            ``--model`` is missing or ``--format`` is not one of the choices.
    """
    parser = argparse.ArgumentParser(description="Export model for deployment")
    parser.add_argument('--model', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--format', type=str, default='onnx',
                        choices=['onnx', 'torchscript', 'both'],
                        help='Export format')
    parser.add_argument('--output', type=str, default='outputs/exported',
                        help='Output directory')
    parser.add_argument('--config', type=str, default='config.yaml',
                        help='Path to config file')
    return parser.parse_args(argv)
def _extract_state_dict(checkpoint):
    """Return the model weights contained in a loaded checkpoint object.

    Handles a training checkpoint dict with a ``'model_state_dict'`` entry,
    a whole pickled ``nn.Module``, or (otherwise) treats the object itself
    as the state dict.
    """
    if isinstance(checkpoint, torch.nn.Module):
        # A full module was saved with torch.save(model); take its weights.
        return checkpoint.state_dict()
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


def _resolve_input_size(image_size):
    """Normalise a configured image size to a ``(height, width)`` tuple.

    An int (or int-like) value means a square input; a two-element
    list/tuple is taken as ``(height, width)``.

    Raises:
        ValueError: if a sequence does not have exactly two elements.
    """
    if isinstance(image_size, (list, tuple)):
        if len(image_size) != 2:
            raise ValueError(
                f"data.image_size must be an int or [height, width], got {image_size!r}"
            )
        return int(image_size[0]), int(image_size[1])
    size = int(image_size)
    return size, size


def main():
    """Load a trained checkpoint and export it in the requested format(s).

    Reads options via ``parse_args``, builds the model from the config,
    loads the checkpoint weights, and writes ``model.onnx`` and/or
    ``model.pt`` into the output directory.
    """
    args = parse_args()

    # Load config
    config = get_config(args.config)

    print("\n" + "=" * 60)
    print("Model Export")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Format: {args.format}")
    print("=" * 60)

    # Load model
    # NOTE(review): tracing/exporting on CUDA can bake the device into the
    # exported graph; consider exporting on CPU for portable artifacts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(config).to(device)
    # SECURITY: torch.load unpickles the file -- only export trusted checkpoints.
    # NOTE(review): newer torch releases default to weights_only=True, which may
    # reject checkpoints that store non-tensor objects; verify with your version.
    checkpoint = torch.load(args.model, map_location=device)
    model.load_state_dict(_extract_state_dict(checkpoint))
    model.eval()
    print("Model loaded")

    # Get image size (int for square inputs, or [height, width])
    input_size = _resolve_input_size(config.get('data.image_size', 384))

    # Create the output directory only after the model loaded successfully,
    # so a bad checkpoint path does not leave an empty directory behind.
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Export
    if args.format in ('onnx', 'both'):
        onnx_path = output_dir / 'model.onnx'
        export_to_onnx(model, str(onnx_path), input_size=input_size)

    if args.format in ('torchscript', 'both'):
        ts_path = output_dir / 'model.pt'
        export_to_torchscript(model, str(ts_path), input_size=input_size)

    print("\n" + "=" * 60)
    print("Export Complete!")
    print(f"Output: {output_dir}")
    print("=" * 60)
if __name__ == "__main__":
    # Run the exporter only when executed as a script, not when imported.
    main()