File size: 1,803 Bytes
4344b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# src/tracking/mlflow_tracker.py — MLflow integration (optional)

from __future__ import annotations

from typing import Any, Dict, Optional

from src.tracking.experiments import ExperimentTracker


class MLflowTracker(ExperimentTracker):
    """MLflow-backed tracker (stub — enable when mlflow is available).

    Every public method checks whether the ``mlflow`` package could be
    imported when the tracker was constructed. If it could, the call is
    forwarded to the module-level MLflow API; otherwise it is delegated to
    the ``ExperimentTracker`` base implementation.
    """

    def __init__(self, experiment_name: str = "uvm_tb_generator", tracking_uri: Optional[str] = None):
        """Create the tracker and probe for an importable ``mlflow``.

        Args:
            experiment_name: Forwarded to the base-class initializer and
                later used as the MLflow experiment name.
            tracking_uri: Optional MLflow tracking URI. When falsy, the
                tracking URI already configured in MLflow is left untouched.
        """
        super().__init__(experiment_name)
        self.tracking_uri = tracking_uri
        # Probe once so that later calls do not repeat a failing import.
        self._mlflow_available = self._check_mlflow()

    @staticmethod
    def _check_mlflow() -> bool:
        """Return True when ``import mlflow`` succeeds, False on ImportError."""
        try:
            import mlflow  # noqa: F401
        except ImportError:
            return False
        return True

    def start_run(self, params: Optional[Dict[str, Any]] = None) -> str:
        """Start a run and return its identifier.

        Args:
            params: Optional parameters to log at the start of the run.

        Returns:
            The MLflow run id when mlflow is available, otherwise whatever
            the base-class ``start_run`` returns.
        """
        if not self._mlflow_available:
            return super().start_run(params)

        import mlflow

        # The tracking URI has to be configured *before* the experiment is
        # selected: set_experiment resolves (or creates) the experiment on
        # whichever tracking store is active at the time of the call.
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)
        mlflow.set_experiment(self.experiment_name)

        # start_run returns the run it just activated, so take the id from
        # that object instead of looking the active run up a second time.
        run = mlflow.start_run()
        if params:
            mlflow.log_params(params)
        return run.info.run_id

    def log_metric(self, key: str, value: float) -> None:
        """Log a single metric value under ``key``."""
        if self._mlflow_available:
            import mlflow
            mlflow.log_metric(key, value)
        else:
            super().log_metric(key, value)

    def log_artifact(self, path: str) -> None:
        """Log the file at ``path`` as an artifact."""
        if self._mlflow_available:
            import mlflow
            mlflow.log_artifact(path)
        else:
            super().log_artifact(path)

    def finish_run(self) -> None:
        """End the current run."""
        if self._mlflow_available:
            import mlflow
            mlflow.end_run()
        else:
            super().finish_run()