Spaces:
Runtime error
Runtime error
| """Utils""" | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Literal | |
| from loguru import logger | |
def download_model(
    model_name: str,
    model_stage: Literal["staging", "production"],
    model_dir: str | Path = "model",
) -> Path:
    """Download the latest registered MLflow model version for a stage.

    Looks up the latest version of ``model_name`` in ``model_stage`` via
    ``MlflowClient.get_latest_versions``. If a directory named after the last
    segment of the version's source URI already exists under ``model_dir``,
    that directory is returned without downloading. Otherwise the artifacts
    are downloaded into ``model_dir`` and the model's metadata is written to
    ``metadata.json`` inside the downloaded directory.

    Args:
        model_name: Registered model name in the MLflow model registry.
        model_stage: Registry stage to look up.
        model_dir: Local directory to download the artifacts into; created
            if it does not exist.

    Returns:
        Local path of the downloaded (or previously present) model directory.

    Raises:
        ValueError: If the registry returns no version, or more than one
            version, for ``model_name``/``model_stage``.
    """
    # Imported inside the function, presumably so importing this module does
    # not require mlflow to be installed.
    import mlflow.artifacts
    import mlflow.models
    from mlflow.client import MlflowClient

    logger.info(f"Looking for model {model_name}/{model_stage}")
    model_dir = Path(model_dir)

    client = MlflowClient()
    model_versions = client.get_latest_versions(model_name, stages=[model_stage])
    if not model_versions:
        raise ValueError(f"No model version for {model_name}/{model_stage}")
    if len(model_versions) > 1:
        raise ValueError(
            f"Expected one model version for {model_name}/{model_stage}, "
            f"found {len(model_versions)}"
        )

    latest = model_versions[0]
    artifact_uri = str(latest.source)
    logger.info(f"Found version {latest.version} for {model_name}/{model_stage}")

    # Strip trailing slashes before taking the last segment. Otherwise
    # "…/model/" yields "", model_path collapses to model_dir (which exists
    # after any earlier run) and the download would be silently skipped.
    artifact_name = artifact_uri.rstrip("/").rsplit("/", 1)[-1]
    model_path = model_dir / artifact_name
    # NOTE(review): the local cache key is only the last URI segment, so a
    # newer version whose artifacts share that name, or a partially finished
    # earlier download, is reused as-is. Confirm this is acceptable.
    if artifact_name and model_path.exists():
        logger.info(f"Found model in {model_path}, skipping download")
        return model_path

    # download_artifacts requires dst_path to be an existing directory.
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Downloading artifacts {artifact_uri} to {model_dir}")
    downloaded = Path(
        mlflow.artifacts.download_artifacts(artifact_uri, dst_path=str(model_dir))
    )
    logger.info(f"Successfully downloaded {model_name}")

    metadata = mlflow.models.get_model_info(str(downloaded)).metadata
    metadata_path = downloaded / "metadata.json"
    logger.info(f"Saving metadata to {metadata_path}")
    with open(metadata_path, "w", encoding="utf-8") as file:
        json.dump(metadata, file)
    return downloaded