| """ | |
| Experiment Tracking Module. | |
| MLflow integration for tracking experiments and models. | |
| """ | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import json | |
| from ..utils import get_logger, get_config, LoggerMixin | |
# Module-wide logger plus the project configuration object; config supplies
# the MLflow defaults (tracking URI, experiment name) read by ExperimentTracker.
logger = get_logger(__name__)
config = get_config()
@dataclass
class ExperimentRun:
    """Container for experiment run data.

    A plain record of one tracked run: identity, logged parameters and
    metrics, artifact paths, tags, and timing/status metadata. Used both
    for local (non-MLflow) tracking and as a normalized view of MLflow runs.
    """

    # Unique identifier for this run (MLflow run ID or local timestamp ID).
    run_id: str
    experiment_name: str
    params: Dict[str, Any]
    metrics: Dict[str, float]
    # default_factory gives each instance its own list/dict (never shared).
    artifacts: List[str] = field(default_factory=list)
    tags: Dict[str, str] = field(default_factory=dict)
    # ISO-format timestamps; empty string until set.
    start_time: str = ""
    end_time: str = ""
    status: str = "running"

    def to_dict(self) -> Dict:
        """Return a JSON-serializable dict of all run fields."""
        return {
            "run_id": self.run_id,
            "experiment_name": self.experiment_name,
            "params": self.params,
            "metrics": self.metrics,
            "artifacts": self.artifacts,
            "tags": self.tags,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "status": self.status
        }
class ExperimentTracker(LoggerMixin):
    """
    MLflow-based experiment tracker.

    Features:
    - Automatic experiment creation
    - Parameter and metric logging
    - Model artifact management
    - Run comparison

    When MLflow is not installed (or ``use_mlflow=False``), runs are
    tracked in memory as ``ExperimentRun`` objects and can be persisted
    to JSON with :meth:`save_local_runs`.
    """

    def __init__(
        self,
        tracking_uri: Optional[str] = None,
        experiment_name: Optional[str] = None,
        use_mlflow: bool = True
    ):
        """
        Initialize experiment tracker.

        Args:
            tracking_uri: MLflow tracking server URI (defaults to project config)
            experiment_name: Default experiment name (defaults to project config)
            use_mlflow: Whether to use MLflow (fallback to local JSON)
        """
        self.tracking_uri = tracking_uri or config.mlflow.tracking_uri
        self.experiment_name = experiment_name or config.mlflow.experiment_name
        self.use_mlflow = use_mlflow
        # Lazily-imported mlflow module; stays None until _init_mlflow succeeds.
        self._mlflow = None
        # Either an mlflow ActiveRun or a local ExperimentRun while a run is open.
        self._active_run = None
        self._local_runs: List[ExperimentRun] = []

    def _init_mlflow(self):
        """Import and configure MLflow on first use; fall back to local tracking."""
        # Already initialized, or MLflow explicitly/permanently disabled.
        if self._mlflow is not None or not self.use_mlflow:
            return
        try:
            import mlflow
            self._mlflow = mlflow
            mlflow.set_tracking_uri(self.tracking_uri)
            mlflow.set_experiment(self.experiment_name)
            self.logger.info(f"MLflow initialized: {self.tracking_uri}")
        except ImportError:
            # MLflow missing: switch to local tracking for the tracker's lifetime.
            self.logger.warning("MLflow not installed, using local tracking")
            self.use_mlflow = False

    def start_run(
        self,
        run_name: Optional[str] = None,
        tags: Optional[Dict[str, str]] = None
    ) -> str:
        """
        Start a new experiment run.

        Args:
            run_name: Optional name for the run
            tags: Optional tags for the run

        Returns:
            Run ID (MLflow run ID, or a timestamp-based ID for local tracking)
        """
        self._init_mlflow()
        if self.use_mlflow and self._mlflow:
            run = self._mlflow.start_run(run_name=run_name, tags=tags)
            run_id = run.info.run_id
            self._active_run = run
        else:
            # Local tracking: microsecond timestamp is unique enough per process.
            run_id = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            self._active_run = ExperimentRun(
                run_id=run_id,
                experiment_name=self.experiment_name,
                params={},
                metrics={},
                tags=tags or {},
                start_time=datetime.now().isoformat()
            )
        self.logger.info(f"Started run: {run_id}")
        return run_id

    def log_params(self, params: Dict[str, Any]):
        """
        Log parameters.

        Args:
            params: Dict of parameter names to values
        """
        if self.use_mlflow and self._mlflow:
            self._mlflow.log_params(params)
        elif isinstance(self._active_run, ExperimentRun):
            self._active_run.params.update(params)
        self.logger.debug(f"Logged {len(params)} parameters")

    def log_param(self, key: str, value: Any):
        """Log single parameter."""
        self.log_params({key: value})

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """
        Log metrics.

        Args:
            metrics: Dict of metric names to values
            step: Optional step number for tracking over time
                  (ignored by the local fallback, which keeps latest values only)
        """
        if self.use_mlflow and self._mlflow:
            self._mlflow.log_metrics(metrics, step=step)
        elif isinstance(self._active_run, ExperimentRun):
            self._active_run.metrics.update(metrics)
        self.logger.debug(f"Logged {len(metrics)} metrics")

    def log_metric(self, key: str, value: float, step: Optional[int] = None):
        """Log single metric."""
        self.log_metrics({key: value}, step=step)

    def log_artifact(
        self,
        local_path: str,
        artifact_path: Optional[str] = None
    ):
        """
        Log an artifact file.

        Args:
            local_path: Path to local file
            artifact_path: Optional subdirectory in artifact store
                           (ignored by the local fallback)
        """
        if self.use_mlflow and self._mlflow:
            self._mlflow.log_artifact(local_path, artifact_path)
        elif isinstance(self._active_run, ExperimentRun):
            self._active_run.artifacts.append(local_path)
        self.logger.debug(f"Logged artifact: {local_path}")

    def log_model(
        self,
        model: Any,
        artifact_path: str = "model",
        model_type: str = "pytorch"
    ):
        """
        Log a model artifact.

        Args:
            model: Model object to log
            artifact_path: Path in artifact store
            model_type: Type of model ("pytorch", "sklearn", "transformers",
                        or anything else for a generic pyfunc model)
        """
        if not self.use_mlflow or not self._mlflow:
            self.logger.warning("Model logging requires MLflow")
            return
        # Importing a flavor submodule attaches it as an attribute of the
        # already-imported mlflow package; go through self._mlflow so we use
        # the same module handle everywhere instead of the implicit binding
        # that `import mlflow.pytorch` creates in local scope.
        if model_type == "pytorch":
            import mlflow.pytorch  # noqa: F401
            self._mlflow.pytorch.log_model(model, artifact_path)
        elif model_type == "sklearn":
            import mlflow.sklearn  # noqa: F401
            self._mlflow.sklearn.log_model(model, artifact_path)
        elif model_type == "transformers":
            import mlflow.transformers  # noqa: F401
            self._mlflow.transformers.log_model(model, artifact_path)
        else:
            # Generic Python model
            import mlflow.pyfunc  # noqa: F401
            self._mlflow.pyfunc.log_model(artifact_path, python_model=model)
        self.logger.info(f"Logged {model_type} model: {artifact_path}")

    def set_tags(self, tags: Dict[str, str]):
        """
        Set run tags.

        Args:
            tags: Dict of tag names to values
        """
        if self.use_mlflow and self._mlflow:
            self._mlflow.set_tags(tags)
        elif isinstance(self._active_run, ExperimentRun):
            self._active_run.tags.update(tags)

    def end_run(self, status: str = "FINISHED"):
        """
        End current run.

        Args:
            status: Run status ("FINISHED", "FAILED", "KILLED")
        """
        if self.use_mlflow and self._mlflow:
            self._mlflow.end_run(status)
        else:
            if isinstance(self._active_run, ExperimentRun):
                self._active_run.end_time = datetime.now().isoformat()
                # Local runs store status lowercase ("finished", "failed", ...).
                self._active_run.status = status.lower()
                self._local_runs.append(self._active_run)
            self._active_run = None
        self.logger.info(f"Ended run with status: {status}")

    def get_run(self, run_id: str) -> Optional[ExperimentRun]:
        """
        Get run by ID.

        Args:
            run_id: Run ID

        Returns:
            ExperimentRun, or None if not found / lookup failed
        """
        if self.use_mlflow and self._mlflow:
            try:
                run = self._mlflow.get_run(run_id)
                # Normalize the MLflow run into our local record type.
                # NOTE: artifacts are not enumerated here.
                return ExperimentRun(
                    run_id=run.info.run_id,
                    experiment_name=self.experiment_name,
                    params=run.data.params,
                    metrics=run.data.metrics,
                    tags=run.data.tags,
                    start_time=str(run.info.start_time),
                    end_time=str(run.info.end_time),
                    status=run.info.status
                )
            except Exception as e:
                # Unknown ID / server errors are non-fatal: report and return None.
                self.logger.error(f"Failed to get run: {e}")
                return None
        for run in self._local_runs:
            if run.run_id == run_id:
                return run
        return None

    def list_runs(
        self,
        max_results: int = 100
    ) -> List[ExperimentRun]:
        """
        List recent runs.

        Args:
            max_results: Maximum number of runs to return

        Returns:
            List of ExperimentRun objects
        """
        if self.use_mlflow and self._mlflow:
            try:
                # search_runs returns a pandas DataFrame with flattened
                # "params.*" / "metrics.*" columns; unflatten them here.
                runs = self._mlflow.search_runs(
                    experiment_names=[self.experiment_name],
                    max_results=max_results
                )
                result = []
                for _, row in runs.iterrows():
                    result.append(ExperimentRun(
                        run_id=row['run_id'],
                        experiment_name=self.experiment_name,
                        params={k.replace('params.', ''): v
                                for k, v in row.items() if k.startswith('params.')},
                        metrics={k.replace('metrics.', ''): v
                                 for k, v in row.items() if k.startswith('metrics.')},
                        status=row.get('status', 'unknown')
                    ))
                return result
            except Exception as e:
                self.logger.error(f"Failed to list runs: {e}")
                return []
        # Local fallback: most recent runs, oldest-first within the window.
        return self._local_runs[-max_results:]

    def compare_runs(
        self,
        run_ids: List[str],
        metric_keys: Optional[List[str]] = None
    ) -> Dict[str, Dict[str, float]]:
        """
        Compare metrics across runs.

        Args:
            run_ids: List of run IDs to compare
            metric_keys: Metrics to compare (None = all)

        Returns:
            Dict mapping run_id to metrics; unknown run IDs are omitted.
        """
        comparison = {}
        for run_id in run_ids:
            run = self.get_run(run_id)
            if run:
                if metric_keys:
                    comparison[run_id] = {
                        k: v for k, v in run.metrics.items()
                        if k in metric_keys
                    }
                else:
                    comparison[run_id] = run.metrics
        return comparison

    def save_local_runs(self, path: Path):
        """Save local run history to a JSON file, creating parent dirs as needed."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, 'w') as f:
            json.dump([r.to_dict() for r in self._local_runs], f, indent=2)
        self.logger.info(f"Saved {len(self._local_runs)} runs to {path}")

    def load_local_runs(self, path: Path):
        """Load local run history from a JSON file (no-op if the file is missing)."""
        path = Path(path)
        if path.exists():
            with open(path) as f:
                data = json.load(f)
            self._local_runs = [
                ExperimentRun(**run) for run in data
            ]
            self.logger.info(f"Loaded {len(self._local_runs)} runs from {path}")
        else:
            # Best-effort load: missing history is not an error, but note it.
            self.logger.debug(f"No run history found at {path}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Experiment Tracker Test")
    parser.add_argument("--test", action="store_true", help="Run test mode")
    args = parser.parse_args()

    if args.test:
        print("Experiment Tracker Test\n" + "=" * 50)

        # Use the local (non-MLflow) backend so the demo has no external deps.
        demo_tracker = ExperimentTracker(
            experiment_name="test_experiment",
            use_mlflow=False,
        )

        # Open a run and record some representative params/metrics.
        new_run_id = demo_tracker.start_run(run_name="test_run")
        print(f"Started run: {new_run_id}")

        demo_params = {
            "model": "all-mpnet-base-v2",
            "embedding_dim": 768,
            "top_k": 10,
        }
        demo_tracker.log_params(demo_params)

        demo_metrics = {
            "ndcg@5": 0.78,
            "mrr": 0.82,
            "latency_p50": 45.2,
        }
        demo_tracker.log_metrics(demo_metrics)

        demo_tracker.end_run()

        # Dump everything recorded so far.
        completed = demo_tracker.list_runs()
        print(f"\nTotal runs: {len(completed)}")
        for finished in completed:
            print(f"\nRun: {finished.run_id}")
            print(f" Params: {finished.params}")
            print(f" Metrics: {finished.metrics}")