# recomendation/utils/export.py
# Ali Mohsin
# Next level fix
# 24ea486
import os
from typing import Optional
import torch
def ensure_export_dir(path: str) -> str:
    """Make sure *path* exists as a directory and hand the same path back.

    Any missing intermediate directories are created; a directory that
    already exists is left as-is (no error).
    """
    target = path
    os.makedirs(target, mode=0o777, exist_ok=True)
    return target
def get_export_dir() -> str:
    """Return the default export directory, creating it if necessary.

    The location is read from the ``EXPORT_DIR`` environment variable and
    falls back to ``"models/exports"`` when the variable is unset or empty.

    Returns:
        The export directory path, which exists on return.
    """
    # Using ``or`` instead of a getenv default also covers EXPORT_DIR="",
    # because os.makedirs("") would raise FileNotFoundError.
    export_dir = os.getenv("EXPORT_DIR") or "models/exports"
    return ensure_export_dir(export_dir)
def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
    """Export a model to TorchScript via tracing and save it to disk.

    The model is put into eval mode before tracing and is left in eval mode
    afterwards. Tracing records only the operations executed for
    ``example_inputs``.

    Args:
        model: Module to trace.
        example_inputs: Example input forwarded to ``torch.jit.trace``.
        out_path: Destination file. Missing parent directories are created.

    Returns:
        ``out_path``, unchanged.
    """
    model.eval()
    traced = torch.jit.trace(model, example_inputs)
    # Create the destination directory so a fresh nested path works. A bare
    # filename has dirname "", which os.makedirs would reject, so skip it.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.jit.save(traced, out_path)
    return out_path
def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
    """Export a model to ONNX format and write it to disk.

    The model is put into eval mode (and left there) before export. The graph
    has one input named ``"input"`` and one output named ``"output"``. Dim 0
    of both is marked dynamic (``"batch"``), so the exported model is not
    tied to the example batch size.

    Args:
        model: Module to export.
        example_inputs: Example input forwarded to ``torch.onnx.export``.
        out_path: Destination ``.onnx`` file. Missing parent directories are
            created.
        opset: ONNX opset version to target.

    Returns:
        ``out_path``, unchanged.
    """
    model.eval()
    # Create the destination directory so a fresh nested path works. A bare
    # filename has dirname "", which os.makedirs would reject, so skip it.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.onnx.export(
        model,
        example_inputs,
        out_path,
        export_params=True,
        opset_version=opset,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )
    return out_path
def save_checkpoint(model: torch.nn.Module, path: str, **kwargs) -> str:
    """Save a model checkpoint together with arbitrary metadata.

    The saved object is a dict holding ``model.state_dict()`` under the
    ``"state_dict"`` key, merged with any extra keyword arguments. The kwargs
    are merged last, so a ``state_dict=`` kwarg would replace the model's
    state dict.

    Args:
        model: Module whose ``state_dict()`` is saved.
        path: Destination file. Missing parent directories are created.
        **kwargs: Extra entries stored alongside the state dict.

    Returns:
        ``path``, unchanged.
    """
    # For a bare filename such as "ckpt.pt", dirname is "" and
    # os.makedirs("") raises FileNotFoundError, so only create real parents.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    checkpoint = {"state_dict": model.state_dict(), **kwargs}
    torch.save(checkpoint, path)
    return path