"""Synchronise champion MLflow models from the remote registry to the local filesystem.""" import logging import os from pathlib import Path import shutil import mlflow from mlflow.tracking import MlflowClient logger = logging.getLogger(__name__) LANGUAGES = ("python", "java", "pharo") def _get_mlflow_client() -> MlflowClient: """Return an MLflow client configured from environment variables. If ``MLFLOW_TRACKING_URI`` is defined, it is passed to :func:`mlflow.set_tracking_uri`. Authentication (for example on DagsHub) is handled by MLflow itself via the standard environment variables ``MLFLOW_TRACKING_USERNAME`` and ``MLFLOW_TRACKING_PASSWORD``. """ tracking_uri = os.getenv("MLFLOW_TRACKING_URI") if tracking_uri: mlflow.set_tracking_uri(tracking_uri) return MlflowClient() def _find_champion_version_for_language( client: MlflowClient, lang: str, ): """Return the champion model version for the given language, if any. The function searches all registered models and looks for models whose name starts with ``"-"`` (for example ``"python-transformer"``). For each matching model it tries to resolve the alias ``"-champion"`` using :meth:`MlflowClient.get_model_version_by_alias`. Args: client: Initialised MLflow client. lang: Language identifier, such as ``"python"``, ``"java"`` or ``"pharo"``. Returns: The matching :class:`mlflow.entities.model_registry.ModelVersion` if a champion is found, otherwise ``None``. """ alias_name = f"{lang}-champion" prefix = f"{lang}-" # Get all registered models and filter by language prefix. for rm in client.search_registered_models(): model_name = rm.name if not model_name.startswith(prefix): continue try: mv = client.get_model_version_by_alias( name=model_name, alias=alias_name, ) logger.info( "Found champion model for %s: %s (version %s)", lang, model_name, mv.version, ) return mv except Exception: # noqa: BLE001 logger.info("Alias not defined for model %s, trying next one.", model_name) continue logger.warning("No champion model found for %s.", lang) return None def sync_best_models_to_disk( models_root: str | Path = "models", api_subdir: str = "api", ) -> None: """Download champion models from MLflow and write them to disk. For each language in :data:`LANGUAGES`, this function looks up the model version with alias ``"-champion"`` and downloads its artifacts. After download, the directory structure is normalised so that the final layout is: .. code-block:: text models/ / python/ / ... java/ / ... pharo/ / ... For transformer models logged via ``mlflow.transformers``, the inner ``model/`` directory is flattened so that the Hugging Face files (``config.json``, ``model.safetensors``, ``tokenizer.json``, and so on) live directly under ``/``. Args: models_root: Base directory under which models are written. Can be a string or :class:`pathlib.Path`. Defaults to ``"models"``. api_subdir: Optional subdirectory appended under ``models_root`` (for example ``"api"``). If empty, models are stored directly under ``models_root``. Raises: OSError: If creating directories, moving files, or removing directories fails at the OS level. """ client = _get_mlflow_client() root = Path(models_root) if api_subdir: root = root / api_subdir root.mkdir(parents=True, exist_ok=True) logger.info("Syncing best models to: %s", root.resolve()) for lang in LANGUAGES: mv = _find_champion_version_for_language(client, lang) if mv is None: continue model_name = mv.name try: lang_from_name, model_type = model_name.split("-", 1) except ValueError: logger.error("Unexpected model name format: %s", model_name) continue if lang_from_name != lang: logger.warning( "Language mismatch for model %s: expected %s, got %s", model_name, lang, lang_from_name, ) dest_dir = root / lang / model_type if dest_dir.exists(): shutil.rmtree(dest_dir) dest_dir.mkdir(parents=True, exist_ok=True) logger.info( "Downloading model '%s' version %s to %s...", model_name, mv.version, dest_dir.resolve(), ) try: # Download the artifact (for example ".../java_transformer_model"). downloaded_path = Path( mlflow.artifacts.download_artifacts( artifact_uri=mv.source, dst_path=str(dest_dir), ), ) # For transformer models logged with mlflow.transformers, artifacts # are stored under an inner "model/" directory. model_subdir = downloaded_path / "model" if model_subdir.is_dir(): # Move the contents of "model" directly into dest_dir. for item in model_subdir.iterdir(): shutil.move(str(item), dest_dir / item.name) # Remove the wrapper directory (with MLmodel, conda.yaml, etc.). if downloaded_path != dest_dir: shutil.rmtree(downloaded_path) except Exception as e: logger.error( "Failed to download/reshape model '%s' version %s: %s", model_name, mv.version, e, ) if __name__ == "__main__": logging.basicConfig(level=logging.INFO) sync_best_models_to_disk()