Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from typing import Optional | |
| import onnxruntime as rt | |
| from huggingface_hub import hf_hub_download | |
def download_onnx(
    repo_id: str,
    filename: str = "model.onnx",
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> Path:
    """Download an ONNX file from the Hugging Face Hub and return its local path.

    Args:
        repo_id: Hub repository id, e.g. ``"user/model"``.
        filename: File to fetch from the repo; ``".onnx"`` is appended when
            the name does not already end with it.
        revision: Optional branch, tag, or commit to download from.
        token: Optional Hub access token (for private or gated repos).

    Returns:
        Absolute path to the downloaded file in the local Hub cache.
    """
    if not filename.endswith(".onnx"):
        filename += ".onnx"
    model_path = hf_hub_download(
        repo_id=repo_id, filename=filename, revision=revision, token=token
    )
    # Use absolute() rather than resolve(). The Hub cache exposes files as
    # symlinks inside a snapshot directory that point at hash-named blobs.
    # Resolving the symlink would return the extension-less blob path. ONNX
    # Runtime would then look up external-data files relative to the blobs
    # directory instead of the snapshot directory, where they actually live.
    return Path(model_path).absolute()
def create_session(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
    providers: Optional[list] = None,
) -> rt.InferenceSession:
    """Download ``model.onnx`` from a Hub repo and open an ONNX Runtime session.

    Args:
        repo_id: Hub repository id holding ``model.onnx``.
        revision: Optional branch, tag, or commit to download from.
        token: Optional Hub access token (for private or gated repos).
        providers: Optional explicit execution-provider list, passed as-is to
            ``InferenceSession``. When omitted, CUDA is preferred with CPU as
            fallback. The list is restricted to providers this onnxruntime
            build reports as available.

    Returns:
        An ``onnxruntime.InferenceSession`` for the downloaded model.

    Raises:
        FileNotFoundError: If no model file exists at the downloaded path, or
            at ``<path>/model.onnx`` when that path is a directory.
    """
    model_path = download_onnx(repo_id, revision=revision, token=token)
    if not model_path.is_file():
        # Defensive fallback: treat the returned path as a directory that
        # contains the model file.
        model_path = model_path.joinpath("model.onnx")
    if not model_path.is_file():
        raise FileNotFoundError(f"Model not found: {model_path}")

    if providers is None:
        available = rt.get_available_providers()
        preferred = [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
        # Only request providers compiled into this onnxruntime build. Asking
        # for CUDA on a CPU-only install just triggers a provider warning.
        # If neither preferred provider is present, use whatever is available.
        providers = [p for p in preferred if p[0] in available] or available

    return rt.InferenceSession(str(model_path), providers=providers)