File size: 4,017 Bytes
bbe01fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import json
import logging
import os
import tempfile
from typing import Optional

import mlflow

# Module-level logger; handlers and level are configured by the application.
logger = logging.getLogger(__name__)


class MLflowTracker:
    """Thin wrapper around MLflow for logging RAG query interactions and
    evaluation-suite results, and for comparing reranker runs.

    All logging methods swallow exceptions (logging them via `logger`) so
    that a tracking-server outage never breaks the main application path.
    """

    # MLflow rejects param values beyond a backend-dependent length cap
    # (historically 500 chars on the default store); queries are truncated
    # to this length before being logged as a param. The full text still
    # goes into the JSON artifact.
    _MAX_PARAM_LEN = 500

    def __init__(self, tracking_uri: str, experiment_name: str):
        """Point MLflow at `tracking_uri` and select `experiment_name`.

        Initialization failures are logged as warnings rather than raised.
        """
        self.tracking_uri = tracking_uri
        self.experiment_name = experiment_name

        try:
            mlflow.set_tracking_uri(self.tracking_uri)
            mlflow.set_experiment(self.experiment_name)
        except Exception as e:
            logger.warning(f"Failed to initialize MLflow Tracking at {tracking_uri}: {e}")

    def log_interaction(self, query: str, answer: str, chunks: list[dict], latency_ms: int, ragas_scores: dict) -> None:
        """
        Logs to the current active run. Each call is a new child run under the parent experiment.

        Args:
            query: User query; logged (truncated) as a param and in full as an artifact.
            answer: Generated answer; stored only in the JSON artifact.
            chunks: Retrieved chunks; their `metadata.doc_id` values are recorded.
            latency_ms: End-to-end latency, logged as a metric.
            ragas_scores: Optional metric-name -> score mapping, each logged
                as `ragas_<name>`.
        """
        try:
            with mlflow.start_run(nested=True, run_name="Query_Interaction"):
                # Truncate defensively: an oversized param value would raise
                # and abort logging of the whole interaction.
                mlflow.log_param("query", query[: self._MAX_PARAM_LEN])

                # Long-form text belongs in an artifact, not params
                # (params are length-limited).
                interaction_data = {
                    "query": query,
                    "answer": answer,
                    "chunks_used": [
                        c.get("metadata", {}).get("doc_id", "unknown") for c in chunks
                    ],
                }

                mlflow.log_metric("latency_ms", latency_ms)

                if ragas_scores:
                    for metric_name, score in ragas_scores.items():
                        mlflow.log_metric(f"ragas_{metric_name}", score)

                # Write the JSON to a closed temp file (delete=False so the
                # path remains valid after the `with` exits, incl. Windows),
                # then hand it to MLflow.
                with tempfile.NamedTemporaryFile("w+", suffix=".json", delete=False) as f:
                    json.dump(interaction_data, f)
                    temp_path = f.name

                try:
                    mlflow.log_artifact(temp_path, "interaction_details")
                finally:
                    # Always remove the temp file — the original code leaked
                    # it whenever log_artifact raised.
                    os.unlink(temp_path)

        except Exception as e:
            logger.error(f"Failed to log interaction to MLflow: {e}")

    def log_eval_suite(self, results: dict, filepath: str) -> None:
        """
        Logs full eval suite results as a run with metric history.
        Saves eval JSON as artifact.

        Args:
            results: Dict with optional "ragas" and "custom" sub-dicts of
                metric-name -> numeric value; each logged as `suite_<name>`.
            filepath: Path to the eval-report JSON, logged as an artifact.
        """
        try:
            with mlflow.start_run(run_name="Evaluation_Suite"):
                # Flatten both metric groups under a common "suite_" prefix.
                for group in ("ragas", "custom"):
                    for k, v in results.get(group, {}).items():
                        mlflow.log_metric(f"suite_{k}", float(v))

                # Log the report file directly as an artifact.
                mlflow.log_artifact(filepath, "evaluation_reports")

                logger.info("Evaluation Suite saved successfully into MLflow.")
        except Exception as e:
            logger.error(f"Failed to log eval suite to MLflow: {e}")

    def compare_reranker_runs(self, run_id_old: str, run_id_new: str) -> bool:
        """
        Returns True if new run's MAP@10 > old run by > 0.02.
        Queries the MLflow API for run records.

        Missing `map_at_10` metrics default to 0.0; any API failure is
        logged and reported as False (no improvement).
        """
        try:
            client = mlflow.tracking.MlflowClient(tracking_uri=self.tracking_uri)

            old_run = client.get_run(run_id_old)
            new_run = client.get_run(run_id_new)

            old_map = old_run.data.metrics.get("map_at_10", 0.0)
            new_map = new_run.data.metrics.get("map_at_10", 0.0)

            # Require a strict improvement margin of 0.02 to promote.
            return new_map > old_map + 0.02
        except Exception as e:
            logger.error(f"Failed comparing MLflow runs {run_id_old} / {run_id_new}: {e}")
            return False