|
|
"""Synchronise champion MLflow models from the remote registry to the local filesystem.""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |
|
|
import shutil |
|
|
|
|
|
import mlflow |
|
|
from mlflow.tracking import MlflowClient |
|
|
|
|
|
# Logger named after this module; every function below reports through it.
logger = logging.getLogger(__name__)

# Language identifiers for which a champion model is looked up and synced.
LANGUAGES: tuple[str, ...] = ("python", "java", "pharo")
|
|
|
|
|
|
|
|
def _get_mlflow_client() -> MlflowClient:
    """Build an MLflow client, honouring ``MLFLOW_TRACKING_URI`` when set.

    When the ``MLFLOW_TRACKING_URI`` environment variable holds a non-empty
    value, it is forwarded to :func:`mlflow.set_tracking_uri` before the
    client is created. Otherwise the client is created with whatever
    tracking configuration MLflow already has.

    Credentials are not handled here: authentication (for example on DagsHub)
    is left to MLflow itself, via the standard environment variables
    ``MLFLOW_TRACKING_USERNAME`` and ``MLFLOW_TRACKING_PASSWORD``.

    Returns:
        A new :class:`mlflow.tracking.MlflowClient` instance.
    """
    uri_from_env = os.environ.get("MLFLOW_TRACKING_URI")

    # Unset or empty: leave MLflow's current tracking configuration untouched.
    if not uri_from_env:
        return MlflowClient()

    mlflow.set_tracking_uri(uri_from_env)
    return MlflowClient()
|
|
|
|
|
|
|
|
def _find_champion_version_for_language(
    client: MlflowClient,
    lang: str,
):
    """Return the champion model version for the given language, if any.

    The function walks through *all* registered models (following the
    registry's pagination tokens, not just the first page of results) and
    considers models whose name starts with ``"<lang>-"`` (for example
    ``"python-transformer"``). For each matching model it tries to resolve the
    alias ``"<lang>-champion"`` using
    :meth:`MlflowClient.get_model_version_by_alias`. The first model for which
    the alias resolves wins.

    Args:
        client: Initialised MLflow client.
        lang: Language identifier, such as ``"python"``, ``"java"`` or
            ``"pharo"``.

    Returns:
        The matching :class:`mlflow.entities.model_registry.ModelVersion` if a
        champion is found, otherwise ``None``.

    Raises:
        Exception: Whatever ``client.search_registered_models`` raises is
            propagated. Failures of the per-model alias lookup are logged and
            treated as "alias not defined".
    """
    alias_name = f"{lang}-champion"
    prefix = f"{lang}-"

    page_token = None
    while True:
        # search_registered_models returns one page at a time; the token of
        # the returned page (if any) points at the next page.
        page = client.search_registered_models(page_token=page_token)

        for rm in page:
            model_name = rm.name
            if not model_name.startswith(prefix):
                continue

            # Only the lookup itself is best-effort; keep the try body minimal
            # so unrelated errors are not mistaken for a missing alias.
            try:
                mv = client.get_model_version_by_alias(
                    name=model_name,
                    alias=alias_name,
                )
            except Exception as exc:
                logger.info("Alias not defined for model %s, trying next one.", model_name)
                # Keep the real cause available: the lookup can also fail for
                # reasons other than a missing alias (auth, connectivity, ...).
                logger.debug("Alias lookup for model %s failed with: %r", model_name, exc)
                continue

            logger.info(
                "Found champion model for %s: %s (version %s)",
                lang,
                model_name,
                mv.version,
            )
            return mv

        # Plain lists (no ``token`` attribute) and last pages end the walk.
        page_token = getattr(page, "token", None)
        if not page_token:
            break

    logger.warning("No champion model found for %s.", lang)
    return None
|
|
|
|
|
|
|
|
def _download_and_install_model(source_uri: str, dest_dir: Path) -> None:
    """Download the artifacts at ``source_uri`` and install them in ``dest_dir``.

    The artifacts are first downloaded into a staging directory that sits next
    to ``dest_dir``. Only once the download has succeeded is any existing
    ``dest_dir`` removed and replaced, so a failed download leaves the
    previously synced model untouched.

    The files installed into ``dest_dir`` are chosen as follows:

    * if the downloaded directory contains an inner ``model/`` directory, the
      contents of that inner directory;
    * otherwise, if the download is a directory, all of its contents;
    * otherwise (a single downloaded file), that file.

    Args:
        source_uri: Artifact URI passed to
            :func:`mlflow.artifacts.download_artifacts`.
        dest_dir: Final directory for the model files.

    Raises:
        Exception: Any error raised by the download or by the filesystem
            operations is propagated to the caller.
    """
    staging_dir = dest_dir.parent / f".{dest_dir.name}.download"
    # A staging directory can be left behind by an interrupted earlier run.
    if staging_dir.exists():
        shutil.rmtree(staging_dir)
    staging_dir.mkdir(parents=True)

    try:
        downloaded_path = Path(
            mlflow.artifacts.download_artifacts(
                artifact_uri=source_uri,
                dst_path=str(staging_dir),
            ),
        )

        inner_model_dir = downloaded_path / "model"
        if inner_model_dir.is_dir():
            # Flatten the inner ``model/`` directory.
            payload_items = list(inner_model_dir.iterdir())
        elif downloaded_path.is_dir():
            # No inner ``model/`` directory: keep everything that was
            # downloaded instead of discarding it.
            payload_items = list(downloaded_path.iterdir())
        else:
            payload_items = [downloaded_path]

        # Swap the new content in only after the download has succeeded.
        if dest_dir.exists():
            shutil.rmtree(dest_dir)
        dest_dir.mkdir(parents=True)
        for item in payload_items:
            shutil.move(str(item), str(dest_dir / item.name))
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)


def sync_best_models_to_disk(
    models_root: str | Path = "models",
    api_subdir: str = "api",
) -> None:
    """Download champion models from MLflow and write them to disk.

    For each language in :data:`LANGUAGES`, this function looks up the model
    version with alias ``"<lang>-champion"`` and downloads its artifacts. After
    download, the directory structure is normalised so that the final layout is:

    .. code-block:: text

        models/
          <api_subdir>/
            python/
              <model_type>/
                ...
            java/
              <model_type>/
                ...
            pharo/
              <model_type>/
                ...

    ``<model_type>`` is the part of the registered model name after the first
    ``"-"`` (``"python-transformer"`` is written to ``python/transformer``).

    For transformer models logged via ``mlflow.transformers``, the inner
    ``model/`` directory is flattened so that the Hugging Face files
    (``config.json``, ``model.safetensors``, ``tokenizer.json``, and so on)
    live directly under ``<model_type>/``. Downloads without an inner
    ``model/`` directory are installed as downloaded.

    Languages without a champion, and model names that do not have the form
    ``"<lang>-<model_type>"``, are skipped. A failure while downloading or
    installing one model is logged (with traceback) and does not stop the
    remaining languages; in that case a previously synced copy of that model,
    if any, is left in place.

    Args:
        models_root: Base directory under which models are written. Can be a
            string or :class:`pathlib.Path`. Defaults to ``"models"``.
        api_subdir: Optional subdirectory appended under ``models_root`` (for
            example ``"api"``). If empty, models are stored directly under
            ``models_root``.

    Raises:
        OSError: If the root output directory cannot be created.
    """
    client = _get_mlflow_client()

    root = Path(models_root)
    if api_subdir:
        root = root / api_subdir
    root.mkdir(parents=True, exist_ok=True)
    logger.info("Syncing best models to: %s", root.resolve())

    for lang in LANGUAGES:
        mv = _find_champion_version_for_language(client, lang)
        if mv is None:
            continue

        model_name = mv.name
        lang_from_name, separator, model_type = model_name.partition("-")
        # An empty model type would make the destination the language
        # directory itself, so treat it like a missing separator.
        if not separator or not model_type:
            logger.error("Unexpected model name format: %s", model_name)
            continue

        if lang_from_name != lang:
            logger.warning(
                "Language mismatch for model %s: expected %s, got %s",
                model_name,
                lang,
                lang_from_name,
            )

        dest_dir = root / lang / model_type

        logger.info(
            "Downloading model '%s' version %s to %s...",
            model_name,
            mv.version,
            dest_dir.resolve(),
        )

        # Best-effort per model: one failing download must not abort the sync
        # of the remaining languages.
        try:
            _download_and_install_model(mv.source, dest_dir)
        except Exception as e:
            logger.exception(
                "Failed to download/reshape model '%s' version %s: %s",
                model_name,
                mv.version,
                e,
            )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: switch on INFO-level logging so the progress
    # messages emitted during the sync are shown, then run a sync with the
    # default output location.
    logging.basicConfig(level=logging.INFO)
    sync_best_models_to_disk()
|
|
|