Spaces:
Paused
Paused
| import os | |
| from typing import Optional | |
| import torch | |
def ensure_export_dir(path: str) -> str:
    """Make sure *path* exists as a directory and hand it back unchanged.

    Any missing parent directories are created as well; a directory that
    already exists is left as-is (no error is raised).
    """
    target = path
    os.makedirs(name=target, exist_ok=True)
    return target
def get_export_dir() -> str:
    """Return the export directory, creating it (and its parents) if needed.

    The location is taken from the ``EXPORT_DIR`` environment variable and
    falls back to ``"models/exports"`` (a path relative to the current working
    directory) when the variable is unset or set to an empty string.

    Returns:
        The export directory path, as returned by ``ensure_export_dir``.
    """
    # os.getenv(name, default) only falls back when the variable is *unset*;
    # an exported-but-empty EXPORT_DIR would yield "" and os.makedirs("")
    # raises FileNotFoundError. Treat empty the same as unset.
    export_dir = os.getenv("EXPORT_DIR") or "models/exports"
    return ensure_export_dir(export_dir)
def export_torchscript(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str) -> str:
    """Export *model* to a TorchScript file at *out_path* via tracing.

    The model is switched to eval mode with ``model.eval()`` and is left in
    that mode afterwards. ``torch.jit.trace`` records the operations executed
    for ``example_inputs``, so data-dependent control flow is not captured.

    Args:
        model: Module to export.
        example_inputs: Sample input used to run the trace.
        out_path: Destination file. Missing parent directories are created.

    Returns:
        ``out_path``, unchanged.
    """
    # Create the destination directory on demand. os.path.dirname() is ""
    # for a bare filename, and os.makedirs("") would raise FileNotFoundError.
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    model.eval()
    traced = torch.jit.trace(model, example_inputs)
    torch.jit.save(traced, out_path)
    return out_path
| def export_onnx(model: torch.nn.Module, example_inputs: torch.Tensor, out_path: str, opset: int = 17) -> str: | |
| """Export model to ONNX format.""" | |
| model.eval() | |
| torch.onnx.export( | |
| model, | |
| example_inputs, | |
| out_path, | |
| export_params=True, | |
| opset_version=opset, | |
| do_constant_folding=True, | |
| input_names=["input"], | |
| output_names=["output"], | |
| dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}, | |
| ) | |
| return out_path | |
def save_checkpoint(model: torch.nn.Module, path: str, **kwargs) -> str:
    """Save the model's ``state_dict`` together with arbitrary metadata.

    The file written by ``torch.save`` contains the dict
    ``{"state_dict": model.state_dict(), **kwargs}``. Because ``kwargs`` are
    merged after the state dict, passing a keyword named ``state_dict``
    replaces the model weights in the saved checkpoint.

    Args:
        model: Module whose ``state_dict()`` is saved.
        path: Destination file. Missing parent directories are created.
        **kwargs: Extra entries stored alongside the weights.

    Returns:
        ``path``, unchanged.
    """
    # os.path.dirname("ckpt.pt") == "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    checkpoint = {
        "state_dict": model.state_dict(),
        **kwargs,
    }
    torch.save(checkpoint, path)
    return path