Spaces:
Running
Running
import json
import logging
from typing import Optional

import mlflow

logger = logging.getLogger(__name__)
class MLflowTracker:
    """Thin wrapper around MLflow for logging RAG interactions and eval results.

    Points MLflow at ``tracking_uri`` / ``experiment_name`` on construction.
    All logging methods are best-effort: MLflow failures are logged and
    swallowed so that an unreachable tracking server never breaks the caller.
    """

    # MLflow rejects over-long param values; param copies are truncated to
    # this many characters (full text still goes into the JSON artifact).
    _MAX_PARAM_LEN = 500

    def __init__(self, tracking_uri: str, experiment_name: str):
        """Configure the MLflow tracking URI and active experiment.

        Initialization errors are downgraded to warnings so the host
        application keeps running without tracking.
        """
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name
        try:
            mlflow.set_tracking_uri(self.tracking_uri)
            mlflow.set_experiment(self.experiment_name)
        except Exception as e:
            logger.warning(f"Failed to initialize MLflow Tracking at {tracking_uri}: {e}")

    def log_interaction(self, query: str, answer: str, chunks: list[dict], latency_ms: int, ragas_scores: dict) -> None:
        """
        Logs to the current active run. Each call is a new child run under the parent experiment.

        Records the (truncated) query as a param, latency and any RAGAS scores
        as metrics, and the full query/answer/chunk-ids payload as a JSON
        artifact under ``interaction_details``.
        """
        import os
        import tempfile

        try:
            with mlflow.start_run(nested=True, run_name="Query_Interaction"):
                # Params are length-limited by MLflow; log a truncated copy so
                # a long query cannot make the whole logging call fail.
                mlflow.log_param("query", query[: self._MAX_PARAM_LEN])
                # Extensive text belongs in artifacts, not params.
                interaction_data = {
                    "query": query,
                    "answer": answer,
                    "chunks_used": [c.get("metadata", {}).get("doc_id", "unknown") for c in chunks],
                }
                # MLflow metrics
                mlflow.log_metric("latency_ms", latency_ms)
                if ragas_scores:
                    for metric_name, score in ragas_scores.items():
                        mlflow.log_metric(f"ragas_{metric_name}", score)
                # Save interaction JSON locally and log as artifact. The temp
                # file is removed even if log_artifact raises (the original
                # code leaked it on failure).
                with tempfile.NamedTemporaryFile("w+", suffix=".json", delete=False) as f:
                    json.dump(interaction_data, f)
                    temp_path = f.name
                try:
                    mlflow.log_artifact(temp_path, "interaction_details")
                finally:
                    os.unlink(temp_path)
        except Exception as e:
            logger.error(f"Failed to log interaction to MLflow: {e}")

    def log_eval_suite(self, results: dict, filepath: str) -> None:
        """
        Logs full eval suite results as a run with metric history.
        Saves eval JSON as artifact.

        ``results`` is expected to hold numeric values under the optional
        "ragas" and "custom" sub-dicts; ``filepath`` is an existing JSON
        report logged verbatim under ``evaluation_reports``.
        """
        try:
            with mlflow.start_run(run_name="Evaluation_Suite"):
                # Log top-level metrics from both score families with a
                # common "suite_" prefix.
                ragas = results.get("ragas", {})
                for k, v in ragas.items():
                    mlflow.log_metric(f"suite_{k}", float(v))
                custom = results.get("custom", {})
                for k, v in custom.items():
                    mlflow.log_metric(f"suite_{k}", float(v))
                # Log the artifact file directly
                mlflow.log_artifact(filepath, "evaluation_reports")
            logger.info("Evaluation Suite saved successfully into MLflow.")
        except Exception as e:
            logger.error(f"Failed to log eval suite to MLflow: {e}")

    def compare_reranker_runs(self, run_id_old: str, run_id_new: str) -> bool:
        """
        Returns True if new run's MAP@10 > old run by > 0.02.
        Queries the MLflow API for run records.

        Missing "map_at_10" metrics default to 0.0; any API failure is
        logged and treated as "not better" (False).
        """
        try:
            client = mlflow.tracking.MlflowClient(self.tracking_uri)
            old_run = client.get_run(run_id_old)
            new_run = client.get_run(run_id_new)
            old_map = old_run.data.metrics.get("map_at_10", 0.0)
            new_map = new_run.data.metrics.get("map_at_10", 0.0)
            return new_map > old_map + 0.02
        except Exception as e:
            logger.error(f"Failed comparing MLflow runs {run_id_old} / {run_id_new}: {e}")
            return False