Spaces:
Paused
Paused
File size: 1,575 Bytes
4716563 24ea486 4716563 24ea486 4716563 24ea486 4716563 24ea486 4716563 24ea486 4716563 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
from typing import Optional
import torch
def ensure_export_dir(path: str) -> str:
    """Create the directory *path* (and any missing parents) if absent.

    Idempotent: an existing directory is left untouched. Returns *path*
    unchanged so calls can be used fluently.
    """
    os.makedirs(path, exist_ok=True)
    return path
def get_export_dir() -> str:
    """Resolve the export directory and make sure it exists on disk.

    The location comes from the EXPORT_DIR environment variable, falling
    back to "models/exports" (relative to the working directory) when unset.
    """
    configured = os.environ.get("EXPORT_DIR", "models/exports")
    return ensure_export_dir(configured)
def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
    """Serialize *model* as a TorchScript artifact at *out_path*.

    The model is put in eval mode, traced with *example_inputs* (so only the
    code paths exercised by that input are captured), and saved to disk.
    Returns *out_path* for chaining.
    """
    model.eval()
    scripted = torch.jit.trace(model, example_inputs)
    torch.jit.save(scripted, out_path)
    return out_path
def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str:
"""Export model to ONNX format."""
model.eval()
torch.onnx.export(
model,
example_inputs,
out_path,
export_params=True,
opset_version=opset,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
return out_path
def save_checkpoint(model: torch.nn.Module, path: str, **kwargs) -> str:
    """Save *model*'s state_dict plus arbitrary metadata to *path*.

    Any extra keyword arguments (epoch, optimizer state, metrics, ...) are
    stored alongside the weights under their keyword names. Returns *path*.

    Fix: the original called os.makedirs(os.path.dirname(path)), which raises
    FileNotFoundError when *path* is a bare filename (dirname is "").
    """
    parent = os.path.dirname(path)
    if parent:  # dirname is "" for bare filenames; os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)
    checkpoint = {
        "state_dict": model.state_dict(),
        **kwargs,
    }
    torch.save(checkpoint, path)
    return path
|