| | """Script to register model in MLflow model registry.""" |
| |
|
import argparse
import logging
from pathlib import Path
from typing import Optional
| |
|
# Module-wide logger, named after this module so its records can be traced.
logger = logging.getLogger(__name__)
# Root logging is configured at import time with INFO as the threshold.
logging.basicConfig(level=logging.INFO)
| |
|
| |
|
def register_model(
    model_path: str,
    model_name: str,
    run_id: Optional[str] = None,
    tracking_uri: Optional[str] = None,
    tags: Optional[dict] = None,
) -> None:
    """Register a model in the MLflow model registry.

    The model URI handed to MLflow is resolved as follows:

    * ``model_path`` already starts with ``runs:/`` -> used unchanged;
    * otherwise, if ``run_id`` is given -> ``runs:/<run_id>/<model_path>``;
    * otherwise -> ``model_path`` is passed through unchanged.

    Args:
        model_path: A ``runs:/`` URI, an artifact path relative to ``run_id``,
            or any other path/URI to pass to MLflow as-is.
        model_name: Name for the model in the registry.
        run_id: MLflow run ID; only consulted when ``model_path`` is not
            already a ``runs:/`` URI.
        tracking_uri: MLflow tracking URI. When falsy, the tracking URI is
            not changed by this function.
        tags: Tags forwarded to ``mlflow.register_model``; ``None`` is sent
            as an empty dict.

    Raises:
        ImportError: If ``mlflow`` cannot be imported. The underlying import
            error is attached as ``__cause__``.
        Exception: Whatever ``mlflow.register_model`` raises is logged and
            re-raised unchanged.
    """
    # mlflow is imported lazily so the module can be loaded without it.
    try:
        import mlflow
        # Kept from the original implementation; not referenced below.
        import mlflow.pytorch  # noqa: F401
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "mlflow not installed. Install with: pip install mlflow"
        ) from err

    if tracking_uri:
        mlflow.set_tracking_uri(tracking_uri)

    # Resolve the URI that identifies the model artifacts.
    if model_path.startswith("runs:/"):
        model_uri = model_path
    elif run_id:
        model_uri = f"runs:/{run_id}/{model_path}"
    else:
        model_uri = model_path

    logger.info("Registering model: %s", model_name)
    logger.info("Model URI: %s", model_uri)

    # Only the registry call sits inside the try; success is logged in else.
    try:
        mlflow.register_model(
            model_uri=model_uri,
            name=model_name,
            tags=tags or {},
        )
    except Exception as e:
        logger.error("Failed to register model: %s", e)
        raise
    else:
        logger.info("Model '%s' registered successfully!", model_name)
| |
|
| |
|
if __name__ == "__main__":
    # Command-line entry point: gather the registration options and forward
    # them to register_model().
    cli = argparse.ArgumentParser(description="Register model in MLflow")

    # (flag, extra keyword arguments) for each string-valued option, in the
    # order they are added to the parser.
    option_specs = (
        (
            "--model-path",
            {
                "required": True,
                "help": "Path to model or MLflow run URI (runs:/<run_id>/model)",
            },
        ),
        (
            "--model-name",
            {"required": True, "help": "Name for model in registry"},
        ),
        (
            "--run-id",
            {
                "default": None,
                "help": "MLflow run ID (if model_path is not a URI)",
            },
        ),
        (
            "--tracking-uri",
            {"default": None, "help": "MLflow tracking URI"},
        ),
    )
    for flag, extra in option_specs:
        cli.add_argument(flag, type=str, **extra)

    options = cli.parse_args()

    register_model(
        model_path=options.model_path,
        model_name=options.model_name,
        run_id=options.run_id,
        tracking_uri=options.tracking_uri,
    )
| |
|
| |
|