import os
from typing import Optional

import torch


def ensure_export_dir(path: str) -> str:
    """Create ``path`` and any missing parent directories.

    Args:
        path: Directory to create. Existing directories are left untouched.

    Returns:
        The same ``path``, so the call can be used inline.
    """
    os.makedirs(path, exist_ok=True)
    return path


def get_export_dir() -> str:
    """Return the default export directory, creating it if necessary.

    The location comes from the ``EXPORT_DIR`` environment variable. When the
    variable is unset, it falls back to ``"models/exports"``, resolved
    relative to the current working directory.
    """
    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    return ensure_export_dir(export_dir)


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain the file ``path``, if any."""
    parent = os.path.dirname(path)
    # A bare filename ("model.pt") has dirname "", and os.makedirs("")
    # raises FileNotFoundError. Such a file goes in the CWD, which exists.
    if parent:
        os.makedirs(parent, exist_ok=True)


def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
    """Trace ``model`` with ``example_inputs`` and save it as TorchScript.

    Side effect: ``model.eval()`` is called and the model is left in eval
    mode afterwards.

    Args:
        model: Module to export.
        example_inputs: Inputs passed to ``torch.jit.trace``. Only the ops
            executed for these inputs are recorded, so data-dependent
            control flow is not captured.
        out_path: Destination file. Missing parent directories are created.

    Returns:
        ``out_path``.
    """
    model.eval()
    _ensure_parent_dir(out_path)
    traced = torch.jit.trace(model, example_inputs)
    torch.jit.save(traced, out_path)
    return out_path


def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
    """Export ``model`` to ONNX format.

    The graph has one input named ``"input"`` and one output named
    ``"output"``. Dimension 0 of each is marked dynamic (``"batch"``), so
    the exported model accepts any batch size. Parameters are embedded in
    the file, and constant folding is enabled.

    Side effect: ``model.eval()`` is called and the model is left in eval
    mode afterwards.

    Args:
        model: Module to export.
        example_inputs: Example inputs used to run the model during export.
        out_path: Destination file. Missing parent directories are created.
        opset: ONNX opset version to target.

    Returns:
        ``out_path``.
    """
    model.eval()
    _ensure_parent_dir(out_path)
    torch.onnx.export(
        model,
        example_inputs,
        out_path,
        export_params=True,
        opset_version=opset,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )
    return out_path


def save_checkpoint(model: torch.nn.Module, path: str, **kwargs) -> str:
    """Save ``model``'s state dict plus arbitrary metadata with ``torch.save``.

    The saved object is a dict with key ``"state_dict"`` and every entry of
    ``kwargs``. ``kwargs`` are merged after ``"state_dict"``, so passing a
    ``state_dict=`` keyword replaces the model's weights in the file.

    Args:
        model: Module whose ``state_dict()`` is saved.
        path: Destination file. Missing parent directories are created. A
            bare filename writes to the current working directory.
        **kwargs: Extra metadata to store alongside the weights.

    Returns:
        ``path``.
    """
    _ensure_parent_dir(path)
    checkpoint = {
        "state_dict": model.state_dict(),
        **kwargs,
    }
    torch.save(checkpoint, path)
    return path