rendement_agricole / scripts /train_simulation_model.py
stephmnt's picture
Sync from GitHub via hub-sync
fa3d628 verified
"""Point d'entree CLI pour la brique de simulation locale P2/P3."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import sys
import mlflow
from mlflow.tracking import MlflowClient
# Resolve the project root (the parent of the directory containing this file)
# and put it on sys.path, so that the `scripts.*` imports that follow resolve
# when this file is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from scripts.mlflow_logging import log_and_register_sklearn_model
from scripts.mlflow_config import (
SIMULATION_RUNTIME_EXPERIMENT_NAME,
experiment_artifact_location,
normalize_tracking_uri,
)
from scripts.pipeline_utils import ensure_paths_exist, relative_to_project
from scripts.prediction_adjustment import (
SIMULATION_METADATA_PATH,
SIMULATION_MODEL_PATH,
load_or_train_simulation_model,
)
from scripts.runtime_model_specs import (
DEFAULT_MLFLOW_TRACKING_URI,
SIMULATION_RUNTIME_MODEL_SPEC,
)
# Artifact paths whose presence is verified (via ensure_paths_exist) after the
# simulation model has been saved.
SIMULATION_OUTPUTS = [SIMULATION_MODEL_PATH, SIMULATION_METADATA_PATH]

# Module-local alias for the experiment name imported from scripts.mlflow_config.
SIMULATION_MLFLOW_EXPERIMENT_NAME = SIMULATION_RUNTIME_EXPERIMENT_NAME
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line interface of the script and parse its arguments.

    Args:
        argv: Argument strings to parse. When ``None`` (the default), argparse
            reads ``sys.argv[1:]``, which matches the previous behavior of this
            function. Passing an explicit list allows programmatic use.

    Returns:
        argparse.Namespace: Parsed options exposing ``force_retrain``,
        ``sample_size``, ``no_save`` and ``tracking_uri``.
    """
    parser = argparse.ArgumentParser(
        description="Load or retrain the local simulation model used for the P2/P3 adjustment.",
    )
    parser.add_argument(
        "--force-retrain",
        action="store_true",
        help="Retrain the simulation model even if artifacts already exist.",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=200_000,
        help="Maximum number of rows sampled during training.",
    )
    parser.add_argument(
        "--no-save",
        action="store_true",
        help="Train in memory without rewriting the model artifacts.",
    )
    parser.add_argument(
        "--tracking-uri",
        default=DEFAULT_MLFLOW_TRACKING_URI,
        help="Tracking URI MLflow utilise pour journaliser et enregistrer le modele.",
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
def _ensure_simulation_mlflow_experiment(tracking_uri: str) -> None:
    """Configure MLflow tracking and activate the simulation experiment.

    The tracking URI is normalized, set globally on ``mlflow``, and used to
    build a client. If no experiment named
    ``SIMULATION_MLFLOW_EXPERIMENT_NAME`` is found, it is created with the
    artifact location returned by ``experiment_artifact_location``. The
    experiment is then set as the active one.

    Args:
        tracking_uri: MLflow tracking URI (normalized before use).
    """
    experiment_name = SIMULATION_MLFLOW_EXPERIMENT_NAME
    resolved_uri = normalize_tracking_uri(tracking_uri)

    mlflow.set_tracking_uri(resolved_uri)
    tracking_client = MlflowClient(tracking_uri=resolved_uri)

    experiment_missing = tracking_client.get_experiment_by_name(experiment_name) is None
    if experiment_missing:
        artifact_root = experiment_artifact_location(
            experiment_name,
            tracking_uri=resolved_uri,
        )
        tracking_client.create_experiment(
            experiment_name,
            artifact_location=artifact_root,
        )

    mlflow.set_experiment(experiment_name)
def _register_simulation_runtime_model(
    *,
    loaded_model,
    tracking_uri: str,
) -> dict[str, str]:
    """Log the local model to MLflow and register it as a registered model.

    A run is opened in the simulation experiment; descriptive parameters and
    every non-``None`` metric found in ``loaded_model.metadata["metrics"]``
    are logged, then the pipeline is handed to
    ``log_and_register_sklearn_model``.

    Args:
        loaded_model: Object exposing ``metadata`` (a mapping) and ``pipeline``.
        tracking_uri: MLflow tracking URI used for the run and the registry.

    Returns:
        dict[str, str]: Whatever ``log_and_register_sklearn_model`` returns.
    """
    _ensure_simulation_mlflow_experiment(tracking_uri)
    logged_metrics = loaded_model.metadata.get("metrics", {})
    run_label = f"{SIMULATION_MLFLOW_EXPERIMENT_NAME}__runtime_model"

    with mlflow.start_run(run_name=run_label):
        # Parameters are logged one by one, in this order.
        run_params = {
            "runtime_model_role": SIMULATION_RUNTIME_MODEL_SPEC.role,
            "registered_model_name": SIMULATION_RUNTIME_MODEL_SPEC.registered_model_name,
            "training_entrypoint": "scripts/train_simulation_model.py",
            "model_name": loaded_model.metadata.get("model_name"),
            "dataset_source": loaded_model.metadata.get("dataset_source"),
            "sample_size": loaded_model.metadata.get("sample_size"),
        }
        for param_key, param_value in run_params.items():
            mlflow.log_param(param_key, param_value)

        for name, value in logged_metrics.items():
            if value is None:
                # Metrics without a value are skipped.
                continue
            mlflow.log_metric(name, float(value))

        # Registration happens while the run is still active.
        registration = log_and_register_sklearn_model(
            loaded_model.pipeline,
            artifact_name=SIMULATION_RUNTIME_MODEL_SPEC.registered_model_name,
            registered_model_name=SIMULATION_RUNTIME_MODEL_SPEC.registered_model_name,
            model_metadata={
                "runtime_model_role": SIMULATION_RUNTIME_MODEL_SPEC.role,
                "training_entrypoint": "scripts/train_simulation_model.py",
            },
        )
        return registration
def train_simulation_model(
    *,
    force_retrain: bool = False,
    save_artifact: bool = True,
    sample_size: int = 200_000,
    tracking_uri: str = DEFAULT_MLFLOW_TRACKING_URI,
) -> dict[str, object]:
    """Load or retrain the local simulation model and register it in MLflow.

    Args:
        force_retrain: Force retraining even if the artifacts already exist.
        save_artifact: Write the artifacts to disk when ``True``.
        sample_size: Maximum number of rows sampled for training.
        tracking_uri: MLflow tracking URI used for the registry.

    Returns:
        dict[str, object]: Summary of the dataset used, the registration
        details, the metrics and the output paths.
    """
    resolved_uri = normalize_tracking_uri(tracking_uri)

    # Decided before the load/train call below, which receives save_artifact
    # and may therefore create these files itself.
    if force_retrain:
        artifact_source = "retrained"
    elif SIMULATION_MODEL_PATH.exists() and SIMULATION_METADATA_PATH.exists():
        artifact_source = "reused_existing"
    else:
        artifact_source = "retrained"

    loaded_model, simulation_df = load_or_train_simulation_model(
        force_retrain=force_retrain,
        save_artifact=save_artifact,
        sample_size=sample_size,
    )
    registration = _register_simulation_runtime_model(
        loaded_model=loaded_model,
        tracking_uri=resolved_uri,
    )

    # Enrich the model metadata with the registry coordinates.
    metadata_patch = {"runtime_model_role": SIMULATION_RUNTIME_MODEL_SPEC.role}
    for metadata_key, registration_key in (
        ("registered_model_name", "registered_model_name"),
        ("registered_model_version", "registered_model_version"),
        ("registered_model_run_id", "run_id"),
        ("model_uri", "model_uri"),
    ):
        metadata_patch[metadata_key] = registration[registration_key]
    loaded_model.metadata.update(metadata_patch)

    output_paths: list[str] = []
    if save_artifact:
        # Rewrite the metadata file so it includes the registry fields.
        serialized_metadata = json.dumps(loaded_model.metadata, indent=2, ensure_ascii=True)
        SIMULATION_METADATA_PATH.write_text(serialized_metadata, encoding="utf-8")
        checked_paths = ensure_paths_exist(SIMULATION_OUTPUTS, label="simulation model outputs")
        output_paths = [relative_to_project(checked) for checked in checked_paths]

    metrics = loaded_model.metadata.get("metrics", {})
    sampled_rows = loaded_model.metadata.get("sample_size")
    print(
        f"[simulation] Model ready (sample_size={sampled_rows}, "
        f"test_rmse={metrics.get('test_rmse')}, test_r2={metrics.get('test_r2')})"
    )

    summary: dict[str, object] = {
        "dataset_rows": int(len(simulation_df)),
        "sample_size": sampled_rows,
        "artifact_source": artifact_source,
    }
    for metadata_key, value in metadata_patch.items():
        if metadata_key != "runtime_model_role":
            summary[metadata_key] = value
    summary["metrics"] = metrics
    summary["outputs"] = output_paths
    return summary
def main() -> None:
    """Run the simulation training workflow from the command line."""
    cli_options = parse_args()
    # --no-save is the negation of the save_artifact keyword.
    call_kwargs = {
        "force_retrain": cli_options.force_retrain,
        "save_artifact": not cli_options.no_save,
        "sample_size": cli_options.sample_size,
        "tracking_uri": cli_options.tracking_uri,
    }
    train_simulation_model(**call_kwargs)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()