Spaces:
Sleeping
Sleeping
| """Point d'entree CLI pour la brique de simulation locale P2/P3.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import sys | |
| import mlflow | |
| from mlflow.tracking import MlflowClient | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| from scripts.mlflow_logging import log_and_register_sklearn_model | |
| from scripts.mlflow_config import ( | |
| SIMULATION_RUNTIME_EXPERIMENT_NAME, | |
| experiment_artifact_location, | |
| normalize_tracking_uri, | |
| ) | |
| from scripts.pipeline_utils import ensure_paths_exist, relative_to_project | |
| from scripts.prediction_adjustment import ( | |
| SIMULATION_METADATA_PATH, | |
| SIMULATION_MODEL_PATH, | |
| load_or_train_simulation_model, | |
| ) | |
| from scripts.runtime_model_specs import ( | |
| DEFAULT_MLFLOW_TRACKING_URI, | |
| SIMULATION_RUNTIME_MODEL_SPEC, | |
| ) | |
| SIMULATION_OUTPUTS = [ | |
| SIMULATION_MODEL_PATH, | |
| SIMULATION_METADATA_PATH, | |
| ] | |
| SIMULATION_MLFLOW_EXPERIMENT_NAME = SIMULATION_RUNTIME_EXPERIMENT_NAME | |
| def parse_args() -> argparse.Namespace: | |
| """Construit l'interface en ligne de commande du script.""" | |
| parser = argparse.ArgumentParser( | |
| description="Load or retrain the local simulation model used for the P2/P3 adjustment.", | |
| ) | |
| parser.add_argument( | |
| "--force-retrain", | |
| action="store_true", | |
| help="Retrain the simulation model even if artifacts already exist.", | |
| ) | |
| parser.add_argument( | |
| "--sample-size", | |
| type=int, | |
| default=200_000, | |
| help="Maximum number of rows sampled during training.", | |
| ) | |
| parser.add_argument( | |
| "--no-save", | |
| action="store_true", | |
| help="Train in memory without rewriting the model artifacts.", | |
| ) | |
| parser.add_argument( | |
| "--tracking-uri", | |
| default=DEFAULT_MLFLOW_TRACKING_URI, | |
| help="Tracking URI MLflow utilise pour journaliser et enregistrer le modele.", | |
| ) | |
| return parser.parse_args() | |
| def _ensure_simulation_mlflow_experiment(tracking_uri: str) -> None: | |
| """Initialise l'experiment MLflow utilise par la brique de simulation.""" | |
| tracking_uri = normalize_tracking_uri(tracking_uri) | |
| mlflow.set_tracking_uri(tracking_uri) | |
| client = MlflowClient(tracking_uri=tracking_uri) | |
| experiment = client.get_experiment_by_name(SIMULATION_MLFLOW_EXPERIMENT_NAME) | |
| if experiment is None: | |
| client.create_experiment( | |
| SIMULATION_MLFLOW_EXPERIMENT_NAME, | |
| artifact_location=experiment_artifact_location( | |
| SIMULATION_MLFLOW_EXPERIMENT_NAME, | |
| tracking_uri=tracking_uri, | |
| ), | |
| ) | |
| mlflow.set_experiment(SIMULATION_MLFLOW_EXPERIMENT_NAME) | |
| def _register_simulation_runtime_model( | |
| *, | |
| loaded_model, | |
| tracking_uri: str, | |
| ) -> dict[str, str]: | |
| """Journalise et enregistre le modele local comme registered model MLflow.""" | |
| _ensure_simulation_mlflow_experiment(tracking_uri) | |
| metrics = loaded_model.metadata.get("metrics", {}) | |
| with mlflow.start_run(run_name=f"{SIMULATION_MLFLOW_EXPERIMENT_NAME}__runtime_model"): | |
| mlflow.log_param("runtime_model_role", SIMULATION_RUNTIME_MODEL_SPEC.role) | |
| mlflow.log_param("registered_model_name", SIMULATION_RUNTIME_MODEL_SPEC.registered_model_name) | |
| mlflow.log_param("training_entrypoint", "scripts/train_simulation_model.py") | |
| mlflow.log_param("model_name", loaded_model.metadata.get("model_name")) | |
| mlflow.log_param("dataset_source", loaded_model.metadata.get("dataset_source")) | |
| mlflow.log_param("sample_size", loaded_model.metadata.get("sample_size")) | |
| for metric_name, metric_value in metrics.items(): | |
| if metric_value is not None: | |
| mlflow.log_metric(metric_name, float(metric_value)) | |
| return log_and_register_sklearn_model( | |
| loaded_model.pipeline, | |
| artifact_name=SIMULATION_RUNTIME_MODEL_SPEC.registered_model_name, | |
| registered_model_name=SIMULATION_RUNTIME_MODEL_SPEC.registered_model_name, | |
| model_metadata={ | |
| "runtime_model_role": SIMULATION_RUNTIME_MODEL_SPEC.role, | |
| "training_entrypoint": "scripts/train_simulation_model.py", | |
| }, | |
| ) | |
| def train_simulation_model( | |
| *, | |
| force_retrain: bool = False, | |
| save_artifact: bool = True, | |
| sample_size: int = 200_000, | |
| tracking_uri: str = DEFAULT_MLFLOW_TRACKING_URI, | |
| ) -> dict[str, object]: | |
| """Charge ou reentraine le modele local de simulation. | |
| Args: | |
| force_retrain: Force le reentrainement meme si les artefacts existent. | |
| save_artifact: Ecrit les artefacts sur disque si `True`. | |
| sample_size: Nombre maximal de lignes echantillonnees pour l'entrainement. | |
| tracking_uri: Tracking URI MLflow utilise pour le registry. | |
| Returns: | |
| dict[str, object]: Resume du dataset utilise, des metriques et des sorties. | |
| """ | |
| tracking_uri = normalize_tracking_uri(tracking_uri) | |
| reused_existing_artifact = ( | |
| not force_retrain | |
| and SIMULATION_MODEL_PATH.exists() | |
| and SIMULATION_METADATA_PATH.exists() | |
| ) | |
| loaded_model, simulation_df = load_or_train_simulation_model( | |
| force_retrain=force_retrain, | |
| save_artifact=save_artifact, | |
| sample_size=sample_size, | |
| ) | |
| registration = _register_simulation_runtime_model( | |
| loaded_model=loaded_model, | |
| tracking_uri=tracking_uri, | |
| ) | |
| loaded_model.metadata.update( | |
| { | |
| "runtime_model_role": SIMULATION_RUNTIME_MODEL_SPEC.role, | |
| "registered_model_name": registration["registered_model_name"], | |
| "registered_model_version": registration["registered_model_version"], | |
| "registered_model_run_id": registration["run_id"], | |
| "model_uri": registration["model_uri"], | |
| } | |
| ) | |
| output_paths: list[str] = [] | |
| if save_artifact: | |
| SIMULATION_METADATA_PATH.write_text( | |
| json.dumps(loaded_model.metadata, indent=2, ensure_ascii=True), | |
| encoding="utf-8", | |
| ) | |
| resolved_outputs = ensure_paths_exist(SIMULATION_OUTPUTS, label="simulation model outputs") | |
| output_paths = [relative_to_project(path) for path in resolved_outputs] | |
| metrics = loaded_model.metadata.get("metrics", {}) | |
| print( | |
| "[simulation] Model ready " | |
| f"(sample_size={loaded_model.metadata.get('sample_size')}, " | |
| f"test_rmse={metrics.get('test_rmse')}, test_r2={metrics.get('test_r2')})" | |
| ) | |
| return { | |
| "dataset_rows": int(len(simulation_df)), | |
| "sample_size": loaded_model.metadata.get("sample_size"), | |
| "artifact_source": "reused_existing" if reused_existing_artifact else "retrained", | |
| "registered_model_name": registration["registered_model_name"], | |
| "registered_model_version": registration["registered_model_version"], | |
| "registered_model_run_id": registration["run_id"], | |
| "model_uri": registration["model_uri"], | |
| "metrics": metrics, | |
| "outputs": output_paths, | |
| } | |
| def main() -> None: | |
| """Execute le script de simulation depuis la CLI.""" | |
| args = parse_args() | |
| train_simulation_model( | |
| force_retrain=args.force_retrain, | |
| save_artifact=not args.no_save, | |
| sample_size=args.sample_size, | |
| tracking_uri=args.tracking_uri, | |
| ) | |
| if __name__ == "__main__": | |
| main() | |