Spaces:
Sleeping
Sleeping
| import logging | |
| import mlflow | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Literal | |
| from mlflow.tracking import MlflowClient | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
def _format_timestamp(ts: int) -> str:
    """Render an MLflow timestamp as a ``YYYY-MM-DD HH:MM:SS`` string.

    Args:
        ts: Milliseconds since the Unix epoch.

    Returns:
        The formatted timestamp. The conversion uses the naive
        ``datetime.fromtimestamp``, i.e. the local timezone of the process.
    """
    seconds_since_epoch = ts / 1000.0
    return datetime.fromtimestamp(seconds_since_epoch).strftime("%Y-%m-%d %H:%M:%S")
def set_tracking_uri(uri: str) -> Dict:
    """Point MLflow at a tracking URI and report the resulting system info.

    Args:
        uri: Tracking URI to activate. Must be non-empty.

    Returns:
        Whatever ``get_system_info()`` returns once the URI has been applied,
        or ``{"error": True, "message": ...}`` when ``uri`` is empty or an
        exception is raised while applying it.
    """
    if uri:
        try:
            logger.info(f"Setting MLflow tracking URI to {uri}")
            mlflow.set_tracking_uri(uri)
            # The follow-up query doubles as a connectivity check.
            system_info = get_system_info()
        except Exception as exc:
            return {"error": True, "message": f"Failed to set URI: {str(exc)}"}
        return system_info
    return {"error": True, "message": "URI cannot be empty"}
def get_system_info() -> Dict:
    """Collect version, URI and object-count information for the MLflow setup.

    Returns:
        A dictionary with the keys ``mlflow_version``, ``tracking_uri``,
        ``registry_uri``, ``artifact_uri``, ``python_version``,
        ``server_time``, ``experiment_count`` and ``model_count``; or
        ``{"error": True, "message": ...}`` if any lookup raises.
    """
    # Imported locally so the block does not depend on the module's import list.
    import platform

    try:
        client = MlflowClient()
        return {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            # NOTE(review): mlflow.get_artifact_uri() resolves against the
            # active run and appears to start one when none is active —
            # confirm this side effect is intended for an info query.
            "artifact_uri": mlflow.get_artifact_uri(),
            # Report the interpreter version (previously this repeated the
            # MLflow version by mistake).
            "python_version": platform.python_version(),
            # NOTE(review): this is the local clock of this process, not a
            # value fetched from the tracking server.
            "server_time": _format_timestamp(int(datetime.now().timestamp() * 1000)),
            "experiment_count": len(mlflow.search_experiments()),
            "model_count": len(client.search_registered_models()),
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch system info: {str(e)}"}
def list_experiments(name_contains: Optional[str] = "", max_results: Optional[int] = 100) -> Dict:
    """List experiments in the MLflow tracking server, with optional filtering.

    Each returned experiment includes a run count.

    Args:
        name_contains: Optional filter to only include experiments whose names
            contain this string (case-insensitive, surrounding whitespace
            ignored).
        max_results: Maximum number of results to return (default: 100).
            None means no limit after filtering. A negative value will result
            in an empty list.

    Returns:
        A dictionary containing the total count of returned experiments and a
        list of their details.
        Format: {"total_experiments": count, "experiments": [exp_details, ...]}
        ``run_count`` covers active and deleted runs, is capped at 50000, and
        is the string "Error getting count" if the count lookup fails.
        Returns {"error": True, "message": ...} on failure.
    """
    logger.info(f"Fetching experiments (filter: '{name_contains}', max_results: {max_results})")
    run_count_cap = 50000  # Practical limit for counting

    try:
        client = MlflowClient()
        # NOTE(review): called with the client's defaults, so the candidate
        # set is whatever a single default search returns — confirm that is
        # sufficient for large servers.
        experiments = client.search_experiments()

        needle = name_contains.strip().lower() if name_contains else ""
        if needle:
            experiments = [exp for exp in experiments if needle in exp.name.lower()]

        # Apply the max_results limit (None means "no limit").
        if max_results is not None:
            experiments = [] if max_results < 0 else experiments[:max_results]

        experiments_info = []
        for exp in experiments:
            creation_time_str = None
            if getattr(exp, "creation_time", None) is not None:
                creation_time_str = _format_timestamp(exp.creation_time)

            # exp.tags is already a {key: value} mapping; copy it.
            tags_dict = dict(exp.tags) if getattr(exp, "tags", None) else {}

            exp_detail = {
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "creation_time": creation_time_str,
                "tags": tags_dict,
            }

            # A single capped search yields the count directly; an experiment
            # without runs simply returns an empty list, so no separate
            # existence probe is needed.
            try:
                runs_for_count = client.search_runs(
                    experiment_ids=[exp.experiment_id],
                    max_results=run_count_cap,
                    run_view_type=mlflow.entities.ViewType.ALL,
                )
                run_count_val = len(runs_for_count)
            except Exception as e_runs:
                logger.warning(f"Error getting run count for experiment '{exp.name}' (ID: {exp.experiment_id}): {str(e_runs)}")
                # Best effort: keep listing the experiment without a count.
                run_count_val = "Error getting count"

            exp_detail["run_count"] = run_count_val
            experiments_info.append(exp_detail)

        return {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }
    except Exception as e:
        error_msg = f"Error listing experiments: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return {"error": True, "message": error_msg}
def create_experiment(name: str, tags: Optional[Dict[str, str]] = None) -> Dict:
    """Create a new MLflow experiment with the given name and optional tags.

    Args:
        name: Name of the experiment. Must be non-empty.
        tags: Optional tags to attach; an empty dict is used when omitted.

    Returns:
        ``{"experiment_id": ..., "message": "Created experiment"}`` on
        success, or ``{"error": True, "message": ...}`` when the name is empty
        or MLflow raises.
    """
    if not name:
        return {"error": True, "message": "Experiment name cannot be empty"}

    try:
        experiment_tags = tags if tags else {}
        new_experiment_id = mlflow.create_experiment(name=name, tags=experiment_tags)
    except Exception as exc:
        return {"error": True, "message": f"Failed to create experiment: {str(exc)}"}

    return {"experiment_id": new_experiment_id, "message": "Created experiment"}
def search_runs(
    experiment_id: str,
    filter_string: str,
    order_string: Optional[str] = None,
    max_results: int = 100
) -> Dict:
    """Search runs in a given experiment, with filtering and ordering.

    Args:
        experiment_id: The ID of the experiment to search runs in.
        filter_string: A filter query string used to search for runs.
            It follows the MLflow search filter syntax.
            Examples:
            - "metrics.accuracy > 0.95"
            - "params.learning_rate = '0.001'"
            - "tags.environment = 'production'"
            - "attributes.status = 'FINISHED'"
            - "metrics.loss < 0.2 AND params.optimizer = 'Adam'"
            If an empty string (or None) is provided, no filtering is applied
            by this string.
        order_string: An optional string to define the order of the results.
            It should be a single string composed of a metric, parameter, or
            attribute followed by 'ASC' (ascending) or 'DESC' (descending).
            Examples:
            - "metrics.validation_loss ASC"
            - "params.num_epochs DESC"
            - "attributes.start_time DESC"
            If None or blank, no explicit ordering is passed to MLflow.
        max_results: Maximum number of runs to return (default: 100). Must be
            a positive integer.

    Returns:
        A dictionary containing a list of runs matching the criteria or an
        error message.
        Format: {"runs": [run_details, ...]} or {"error": True, "message": ...}
        Runs whose data cannot be processed are logged and skipped.
    """
    # Validate inputs before touching the tracking server.
    if not experiment_id:
        return {"error": True, "message": "Experiment ID cannot be empty"}
    if max_results <= 0:
        return {"error": True, "message": "max_results must be a positive integer"}

    # mlflow.search_runs expects a string filter; treat None as "no filter".
    current_filter_string = filter_string if filter_string is not None else ""

    try:
        logger.info(f"Searching runs in experiment '{experiment_id}' with filter '{current_filter_string}', order by '{order_string}', max_results '{max_results}'")
        order_by_list = [order_string] if order_string and order_string.strip() else None
        found_runs = mlflow.search_runs(
            experiment_ids=[str(experiment_id)],  # Ensure experiment_id is a string
            filter_string=current_filter_string,
            max_results=max_results,
            order_by=order_by_list,
            output_format="list",  # Run objects instead of a DataFrame
        )
    except Exception as e_search:
        logger.error(f"MLflow search_runs API call failed for experiment_id '{experiment_id}': {str(e_search)}", exc_info=True)
        return {"error": True, "message": f"MLflow search_runs API call failed: {str(e_search)}"}

    if not found_runs:
        logger.info(f"No runs found for experiment_id '{experiment_id}' with the given criteria.")
        return {"runs": []}

    processed_runs_info = []
    for run_obj in found_runs:
        run_id_for_log = run_obj.info.run_id if run_obj.info else "N/A"
        try:
            start_time_ms = run_obj.info.start_time
            end_time_ms = run_obj.info.end_time
            processed_runs_info.append({
                "run_id": run_obj.info.run_id,
                "status": run_obj.info.status,
                "start_time": _format_timestamp(start_time_ms) if start_time_ms is not None else None,
                "end_time": _format_timestamp(end_time_ms) if end_time_ms is not None else None,
                "params": dict(run_obj.data.params),
                "metrics": dict(run_obj.data.metrics),
                "tags": dict(run_obj.data.tags),
            })
        except Exception as e_process_run:
            # Best effort: one malformed run must not fail the whole search.
            logger.warning(
                f"Failed to process data for run_id '{run_id_for_log}' in experiment '{experiment_id}'. Error: {str(e_process_run)}. Skipping this run.",
                exc_info=True
            )
            continue

    return {"runs": processed_runs_info}
def list_registered_models() -> Dict:
    """List all registered models returned by the MLflow model registry.

    Returns:
        ``{"models": [...]}`` where each entry holds the model's ``name``,
        formatted ``creation_timestamp`` / ``last_updated_timestamp`` (None
        when the registry reports no value), ``description``, ``tags`` and the
        version numbers in ``latest_versions``; or
        ``{"error": True, "message": ...}`` if the registry call fails.
    """
    try:
        logger.info("Listing registered models")
        client = MlflowClient()

        model_summaries = []
        for model in client.search_registered_models():
            # MLflow documents RegisteredModel.tags as a {key: value} dict;
            # iterating it and reading `.key` fails on the plain string keys.
            # Accept either a mapping or an iterable of tag objects.
            raw_tags = getattr(model, "tags", None) or {}
            if isinstance(raw_tags, dict):
                tags = dict(raw_tags)
            else:
                tags = {tag.key: tag.value for tag in raw_tags}

            created_ms = model.creation_timestamp
            updated_ms = model.last_updated_timestamp
            model_summaries.append({
                "name": model.name,
                "creation_timestamp": _format_timestamp(created_ms) if created_ms is not None else None,
                "last_updated_timestamp": _format_timestamp(updated_ms) if updated_ms is not None else None,
                "description": model.description or "",
                "tags": tags,
                # latest_versions may be absent on some registry backends.
                "latest_versions": [mv.version for mv in (model.latest_versions or [])],
            })

        return {"models": model_summaries}
    except Exception as e:
        return {"error": True, "message": f"Failed to list registered models: {str(e)}"}
def get_model_info(model_name: str) -> Dict:
    """Get detailed information about a registered model.

    Args:
        model_name: Name of the registered model. Must be non-empty.

    Returns:
        ``{"model": {...}}`` with the model's name, formatted timestamps,
        description, tags and a ``versions`` list built from the model's
        ``latest_versions``. Each version carries its stage, timestamps and a
        ``run`` summary (status, start/end time, metrics); ``run`` stays ``{}``
        when the version has no run ID. Returns
        ``{"error": True, "message": ...}`` when the name is empty or a
        registry/tracking call fails.
    """
    if not model_name:
        return {"error": True, "message": "Model name cannot be empty"}

    def _format_optional(ts):
        # Timestamps can be missing; format only real values.
        return _format_timestamp(ts) if ts is not None else None

    try:
        logger.info(f"Fetching info for model '{model_name}'")
        client = MlflowClient()
        model = client.get_registered_model(name=model_name)

        # MLflow documents RegisteredModel.tags as a {key: value} dict;
        # iterating it and reading `.key` fails on the plain string keys.
        # Accept either a mapping or an iterable of tag objects.
        raw_tags = getattr(model, "tags", None) or {}
        if isinstance(raw_tags, dict):
            tags = dict(raw_tags)
        else:
            tags = {tag.key: tag.value for tag in raw_tags}

        model_info = {
            "name": model.name,
            "creation_timestamp": _format_optional(model.creation_timestamp),
            "last_updated_timestamp": _format_optional(model.last_updated_timestamp),
            "description": model.description or "",
            "tags": tags,
            "versions": []
        }

        for mv in (model.latest_versions or []):
            version_dict = {
                "version": mv.version,
                "current_stage": mv.current_stage,
                "creation_timestamp": _format_optional(mv.creation_timestamp),
                "last_updated_timestamp": _format_optional(mv.last_updated_timestamp),
                "run": {}
            }
            # A version registered without a source run has nothing to look up.
            if mv.run_id:
                run = client.get_run(mv.run_id)
                version_dict["run"] = {
                    "status": run.info.status,
                    "start_time": _format_optional(run.info.start_time),
                    "end_time": _format_timestamp(run.info.end_time) if run.info.end_time else None,
                    "metrics": run.data.metrics
                }
            model_info["versions"].append(version_dict)

        return {"model": model_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch model info: {str(e)}"}
def register_model(
    run_id: str,
    model_name: str,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None
) -> Dict:
    """Register a model from a run, with optional description and tags.

    The model is registered from the ``runs:/<run_id>/model`` URI, and the
    description is then written onto the newly created model version.

    Args:
        run_id: ID of the run whose ``model`` artifact is registered.
        model_name: Name to register the model under.
        description: Optional description; a fixed provenance sentence is
            appended to it.
        tags: Optional tags. They are merged over the default
            ``registered_by`` / ``registration_timestamp`` tags, so
            caller-supplied keys win.

    Returns:
        ``{"model_name", "version", "description", "tags", "message"}`` on
        success, or ``{"error": True, "message": ...}`` when an argument is
        empty or MLflow raises.
    """
    if not all([run_id, model_name]):
        return {"error": True, "message": "Run ID and model name must be non-empty"}

    # Join only the non-empty parts so a missing description does not leave a
    # stray leading space in front of the provenance sentence.
    provenance_note = "Model registered by LLM through MCP service."
    final_description = " ".join(part for part in (description, provenance_note) if part)

    final_tags = {
        "registered_by": "mcp-llm-service",
        "registration_timestamp": datetime.now().isoformat()
    }
    if tags:
        final_tags.update(tags)

    try:
        logger.info(f"Registering model '{model_name}' from run '{run_id}/model' with description and tags.")
        model_uri = f"runs:/{run_id}/model"
        result = mlflow.register_model(
            model_uri=model_uri,
            name=model_name,
            tags=final_tags
        )
        # The description is written in a second call, on the version that
        # was just created.
        client = MlflowClient()
        client.update_model_version(
            name=model_name,
            version=result.version,
            description=final_description
        )
        return {
            "model_name": model_name,
            "version": result.version,
            "description": final_description,
            "tags": final_tags,
            "message": "Registered successfully"
        }
    except Exception as e:
        return {"error": True, "message": f"Registration failed: {str(e)}"}