# rendement_agricole/scripts/mlflow_logging.py
# stephmnt
# Sync from GitHub via hub-sync
# fa3d628 (verified)
"""Helpers MLflow pour journaliser des modeles et des predictions evaluees."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Any
import mlflow
import mlflow.pyfunc
import mlflow.sklearn
import pandas as pd
from mlflow.tracking import MlflowClient
# Text fragment used to recognise MLflow's repeated scikit-learn pickle warning.
SKLEARN_PICKLE_WARNING_PREFIX = (
    "Saving scikit-learn models in the pickle or cloudpickle format requires exercising caution"
)


class _SuppressSklearnPickleWarning(logging.Filter):
    """Logging filter that drops MLflow's pickle/cloudpickle serialization warning."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Return ``False`` only for records whose message contains the warning text."""
        rendered = record.getMessage()
        is_pickle_warning = SKLEARN_PICKLE_WARNING_PREFIX in rendered
        return not is_pickle_warning
def configure_mlflow_sklearn_logging() -> None:
    """Attach the pickle-warning filter to the ``mlflow.sklearn`` logger.

    The filter is added at most once: if an instance is already registered on
    the logger, the call is a no-op.
    """
    sklearn_logger = logging.getLogger("mlflow.sklearn")
    for existing in sklearn_logger.filters:
        if isinstance(existing, _SuppressSklearnPickleWarning):
            return
    sklearn_logger.addFilter(_SuppressSklearnPickleWarning())
# Install the warning filter as a side effect of importing this module.
configure_mlflow_sklearn_logging()
def sanitize_logged_model_name(raw_name: str) -> str:
    """Build a stable MLflow model name from a raw identifier.

    Only the segment after the last ``::`` and then after the last ``__`` is
    kept. Every character that is not alphanumeric, ``_`` or ``-`` is replaced
    by ``_``, and leading/trailing ``_``/``-`` are trimmed.

    Dots are replaced as well: MLflow 3 rejects logged-model names that
    contain ``.`` (along with ``/``, ``:``, ``%`` and quotes), so keeping them
    would make ``log_model(name=...)`` fail.

    Args:
        raw_name: Original run or model name.

    Returns:
        str: Cleaned name, or ``"model"`` when nothing usable remains.
    """
    candidate = str(raw_name).strip()
    # Keep only the most specific (right-most) segment of a qualified name.
    for separator in ("::", "__"):
        if separator in candidate:
            candidate = candidate.rsplit(separator, 1)[-1]

    normalized = "".join(
        char if char.isalnum() or char in "_-" else "_" for char in candidate
    ).strip("_-")
    return normalized or "model"
def log_named_sklearn_model(estimator: Any, *, model_name: str) -> str:
    """Log a scikit-learn estimator to MLflow under a sanitized name.

    Args:
        estimator: scikit-learn estimator to log.
        model_name: Descriptive model name; it is passed through
            ``sanitize_logged_model_name`` before logging.

    Returns:
        str: The name actually used for the MLflow model.
    """
    safe_name = sanitize_logged_model_name(model_name)
    mlflow.sklearn.log_model(estimator, name=safe_name)
    return safe_name
def _registered_model_version_sort_key(version: Any) -> tuple[int, str]:
    """Return a sort key for registry versions that never raises on odd values.

    Args:
        version: A model-version object exposing ``.version``, or the raw
            version value itself.

    Returns:
        tuple[int, str]: ``(numeric_version, text)`` for purely decimal
        versions, ``(-1, text)`` otherwise so non-numeric versions sort first.
    """
    raw_version = str(getattr(version, "version", version))
    # ``isdecimal`` (not ``isdigit``): characters such as "²" are digits but
    # are rejected by ``int()``, which would raise ValueError here.
    if raw_version.isdecimal():
        return (int(raw_version), raw_version)
    return (-1, raw_version)
def resolve_registered_model_version_for_run(
    *,
    registered_model_name: str,
    run_id: str,
    tracking_uri: str | None = None,
) -> Any:
    """Find the registry version of a model that was created by a given run.

    All versions of ``registered_model_name`` are fetched, then narrowed to
    those whose ``run_id`` matches. When several match, the one with the
    greatest version key is returned.

    Args:
        registered_model_name: Name of the registered model to inspect.
        run_id: Identifier of the source run.
        tracking_uri: Optional MLflow tracking URI for the client.

    Returns:
        Any: The version object returned by the MLflow client.

    Raises:
        RuntimeError: If no version of the model is linked to ``run_id``.
    """
    registry_client = MlflowClient(tracking_uri=tracking_uri)
    wanted_run_id = str(run_id)

    matching = []
    for candidate in registry_client.search_model_versions(f"name = '{registered_model_name}'"):
        if str(getattr(candidate, "run_id", "")) == wanted_run_id:
            matching.append(candidate)

    if matching:
        return max(matching, key=_registered_model_version_sort_key)

    raise RuntimeError(
        "Registered model version could not be resolved for "
        f"model={registered_model_name!r} and run_id={run_id!r}."
    )
def log_and_register_sklearn_model(
    estimator: Any,
    *,
    artifact_name: str,
    registered_model_name: str,
    model_metadata: dict[str, Any] | None = None,
    await_registration_for: int = 300,
) -> dict[str, str]:
    """Log an estimator in the active run and register it in the model registry.

    Args:
        estimator: scikit-learn estimator to log.
        artifact_name: Name of the run artifact; sanitized before use.
        registered_model_name: Name of the target registered model.
        model_metadata: Optional metadata forwarded to MLflow.
        await_registration_for: Value forwarded to ``log_model`` to bound the
            wait for registration.

    Returns:
        dict[str, str]: Registry context with the keys ``logged_model_name``,
        ``registered_model_name``, ``registered_model_version``, ``model_uri``,
        ``run_id`` and ``logged_model_uri``.

    Raises:
        RuntimeError: If there is no active MLflow run, or if the registered
            version cannot be resolved for the run.
    """
    current_run = mlflow.active_run()
    if current_run is None:
        raise RuntimeError("An active MLflow run is required before registering a model.")
    source_run_id = current_run.info.run_id

    safe_name = sanitize_logged_model_name(artifact_name)
    log_result = mlflow.sklearn.log_model(
        estimator,
        name=safe_name,
        registered_model_name=registered_model_name,
        metadata=model_metadata,
        await_registration_for=await_registration_for,
    )

    # Look the version up by run id rather than relying on the log_model result.
    registry_entry = resolve_registered_model_version_for_run(
        registered_model_name=registered_model_name,
        run_id=source_run_id,
        tracking_uri=mlflow.get_tracking_uri(),
    )
    version_number = registry_entry.version

    return {
        "logged_model_name": safe_name,
        "registered_model_name": registered_model_name,
        "registered_model_version": str(version_number),
        "model_uri": f"models:/{registered_model_name}/{version_number}",
        "run_id": source_run_id,
        "logged_model_uri": str(getattr(log_result, "model_uri", "")),
    }
class EvaluationPredictionLookupModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that serves precomputed evaluation predictions by key.

    Intended mainly for experiment-tracking consistency when a run does not
    produce a single global fitted estimator, as in local time-series
    evaluations.

    ``load_context`` sets two attributes: ``predictions_df`` (the predictions
    CSV as a DataFrame) and ``specification`` (the parsed JSON specification,
    or ``{}`` when no such artifact is provided).
    """

    def load_context(self, context) -> None:
        """Load the CSV and JSON artifacts exposed to the pyfunc model."""
        self.predictions_df = pd.read_csv(Path(context.artifacts["predictions"]))

        spec_location = context.artifacts.get("specification")
        if not spec_location:
            self.specification = {}
            return
        self.specification = json.loads(Path(spec_location).read_text(encoding="utf-8"))

    def predict(self, context: Any, model_input: pd.DataFrame) -> pd.DataFrame:
        """Return the precomputed predictions matching the requested keys.

        Args:
            context: MLflow pyfunc context, unused here.
            model_input: Table containing at least ``area``, ``crop`` and
                ``year``. Non-DataFrame input is wrapped in a DataFrame.

        Returns:
            pd.DataFrame: The input left-joined with the logged predictions.
            Non-key columns present on both sides get the ``_logged`` suffix
            on the prediction side.

        Raises:
            ValueError: If any of the required key columns is missing.
        """
        del context
        if isinstance(model_input, pd.DataFrame):
            requested = model_input
        else:
            requested = pd.DataFrame(model_input)

        required_columns = ["area", "crop", "year"]
        missing_columns = [name for name in required_columns if name not in requested.columns]
        if missing_columns:
            raise ValueError(
                "EvaluationPredictionLookupModel requires columns "
                f"{required_columns}, missing {missing_columns}."
            )

        logged = self.predictions_df.copy()
        join_keys = list(required_columns)
        if "split" in logged.columns:
            if "split" in requested.columns:
                # Caller specified the split: use it as an extra join key.
                join_keys.append("split")
            else:
                # No split requested: keep one row per key, the one whose
                # split value sorts last.
                ordered = logged.sort_values(join_keys + ["split"])
                logged = ordered.drop_duplicates(subset=join_keys, keep="last")

        # Give both sides of the join the same dtype for the ``year`` key.
        logged["year"] = logged["year"].astype(requested["year"].dtype, copy=False)
        return requested.merge(
            logged,
            on=join_keys,
            how="left",
            suffixes=("", "_logged"),
        )
def log_evaluation_lookup_model(
    *,
    model_name: str,
    predictions_path: Path | str,
    specification_path: Path | str,
    model_metadata: dict[str, Any] | None = None,
) -> str:
    """Log an MLflow pyfunc model backed by precomputed predictions.

    The first row of the predictions CSV, restricted to whichever of ``area``,
    ``crop``, ``year`` and ``split`` it contains, is used as the input example.

    Args:
        model_name: Name of the model to create in MLflow; sanitized before use.
        predictions_path: CSV file of evaluated predictions.
        specification_path: JSON specification file associated with the model.
        model_metadata: Extra metadata; its entries override the defaults.

    Returns:
        str: The name actually used for the MLflow model.
    """
    logged_model_name = sanitize_logged_model_name(model_name)
    predictions_file = Path(predictions_path).resolve()
    specification_file = Path(specification_path).resolve()

    # Only one row is read: it is needed solely to build the input example.
    sample = pd.read_csv(predictions_file, nrows=1)
    key_columns = [
        name for name in ("area", "crop", "year", "split") if name in sample.columns
    ]
    input_example = sample[key_columns] if key_columns else None

    metadata = {
        "model_kind": "evaluation_prediction_lookup",
        "source_predictions_file": predictions_file.name,
        "source_specification_file": specification_file.name,
    }
    if model_metadata:
        metadata.update(model_metadata)

    artifacts = {
        "predictions": str(predictions_file),
        "specification": str(specification_file),
    }
    mlflow.pyfunc.log_model(
        name=logged_model_name,
        python_model=EvaluationPredictionLookupModel(),
        artifacts=artifacts,
        input_example=input_example,
        metadata=metadata,
    )
    return logged_model_name