File size: 1,575 Bytes
4716563
 
 
 
 
 
 
24ea486
4716563
 
 
 
24ea486
 
 
 
 
 
4716563
24ea486
4716563
 
 
 
 
 
 
24ea486
4716563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24ea486
 
 
 
 
 
 
 
 
 
 
 
 
 
4716563
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from typing import Optional

import torch


def ensure_export_dir(path: str) -> str:
    """Make sure ``path`` exists as a directory and hand it back.

    Missing intermediate directories are created as well; a directory that
    is already present is left untouched.

    Args:
        path: Directory to create.

    Returns:
        The same ``path`` that was passed in.
    """
    os.makedirs(name=path, exist_ok=True)
    return path


def get_export_dir() -> str:
    """Resolve the export directory and make sure it exists.

    The location is read from the ``EXPORT_DIR`` environment variable and
    falls back to ``models/exports`` when the variable is unset.

    Returns:
        The resolved directory path, as returned by ``ensure_export_dir``.
    """
    return ensure_export_dir(os.environ.get("EXPORT_DIR", "models/exports"))


def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
    """Trace ``model`` and save it as a TorchScript file.

    The model is switched to eval mode (and left in eval mode) before it is
    traced with ``example_inputs``. The parent directory of ``out_path`` is
    created if it does not exist yet, so a fresh export location works
    without extra setup.

    Args:
        model: Module to trace.
        example_inputs: Example input passed to ``torch.jit.trace``.
        out_path: Destination file for the traced module.

    Returns:
        ``out_path``, unchanged.
    """
    model.eval()
    traced = torch.jit.trace(model, example_inputs)

    # A bare filename has an empty dirname; os.makedirs("") would raise.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    torch.jit.save(traced, out_path)
    return out_path


def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
    """Export ``model`` to an ONNX file.

    The model is switched to eval mode (and left in eval mode) before the
    export. The parent directory of ``out_path`` is created if it does not
    exist yet. The exported graph names its input ``"input"`` and its output
    ``"output"``, and marks axis 0 of both as a dynamic ``"batch"`` axis.

    Args:
        model: Module to export.
        example_inputs: Example input passed to ``torch.onnx.export``.
        out_path: Destination file for the ONNX model.
        opset: ONNX opset version to target.

    Returns:
        ``out_path``, unchanged.
    """
    model.eval()

    # A bare filename has an empty dirname; os.makedirs("") would raise.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    torch.onnx.export(
        model,
        example_inputs,
        out_path,
        export_params=True,
        opset_version=opset,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )
    return out_path


def save_checkpoint(model: torch.nn.Module, path: str, **kwargs) -> str:
    """Save a checkpoint holding the model's ``state_dict`` plus metadata.

    The saved object is a dict with the key ``"state_dict"`` and one entry
    per extra keyword argument. Keyword arguments are merged last, so a
    ``state_dict=...`` keyword replaces the model's own state dict.

    The parent directory of ``path`` is created when it is missing. A bare
    filename (no directory component) is written to the current working
    directory.

    Args:
        model: Module whose ``state_dict()`` is stored.
        path: Destination file for the checkpoint.
        **kwargs: Extra entries to store alongside the state dict.

    Returns:
        ``path``, unchanged.
    """
    # os.path.dirname("ckpt.pt") is "", and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    checkpoint = {
        "state_dict": model.state_dict(),
        **kwargs,
    }
    torch.save(checkpoint, path)
    return path