tuntun committed on
Commit
8528842
·
1 Parent(s): 0165571
Files changed (3) hide show
  1. app.py +125 -0
  2. mcp_mlflow_tools.py +380 -0
  3. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import gradio as gr
3
+ from mcp_mlflow_tools import (
4
+ set_tracking_uri,
5
+ get_system_info,
6
+ list_experiments,
7
+ create_experiment,
8
+ register_model,
9
+ search_runs,
10
+ list_registered_models,
11
+ get_model_info
12
+ )
13
+
14
def create_interface():
    """Build the Gradio Blocks UI for the MLflow MCP service.

    Each button is wired directly to one of the functions imported from
    ``mcp_mlflow_tools``; the order of ``inputs`` therefore has to match the
    positional parameter order of the bound function.

    Returns:
        The assembled ``gr.Blocks`` application (not launched).
    """
    with gr.Blocks(title="MLflow MCP Service") as app:
        gr.Markdown("# MLflow MCP Service")
        gr.Markdown("A service that exposes MLflow functionality through a web interface and API endpoints.")

        with gr.Tab("Tracking & System Info"):
            with gr.Group():
                gr.Markdown("## Set Tracking URI")
                uri_input = gr.Textbox(label="MLflow Tracking URI")
                uri_output = gr.JSON(label="Result")
                uri_button = gr.Button("Set URI")
                uri_button.click(
                    fn=set_tracking_uri,
                    inputs=uri_input,
                    outputs=uri_output,
                )

            with gr.Group():
                gr.Markdown("## Get System Info")
                sys_info_output = gr.JSON(label="System Information")
                sys_info_button = gr.Button("Get Info")
                sys_info_button.click(
                    fn=get_system_info,
                    inputs=[],
                    outputs=sys_info_output,
                )

        with gr.Tab("Experiment Management"):
            with gr.Group():
                gr.Markdown("## List Experiments")
                exp_list_output = gr.JSON(label="Experiments")
                exp_list_button = gr.Button("List Experiments")
                exp_list_button.click(
                    fn=list_experiments,
                    inputs=[],
                    outputs=exp_list_output,
                )

            with gr.Group():
                gr.Markdown("## Create Experiment")
                exp_name_input = gr.Textbox(label="Experiment Name")
                exp_tags_input = gr.Textbox(label="Tags (JSON format)", placeholder='{"key": "value"}')
                exp_create_output = gr.JSON(label="Result")
                exp_create_button = gr.Button("Create Experiment")

                def create_exp_with_tags(name, tags_str):
                    """Create an experiment, parsing the tags from a JSON object string.

                    Args:
                        name: Name of the experiment to create.
                        tags_str: Optional JSON object such as '{"key": "value"}'.

                    Returns:
                        The result of ``create_experiment`` or an error dictionary.
                    """
                    tags_text = tags_str.strip() if tags_str else ""
                    try:
                        tags = json.loads(tags_text) if tags_text else None
                    except json.JSONDecodeError:
                        return {"error": True, "message": "Invalid JSON format for tags"}
                    # Valid JSON that is not an object (e.g. a list or a number)
                    # cannot be used as a tag mapping.
                    if tags is not None and not isinstance(tags, dict):
                        return {"error": True, "message": "Tags must be a JSON object"}
                    return create_experiment(name, tags)

                exp_create_button.click(
                    fn=create_exp_with_tags,
                    inputs=[exp_name_input, exp_tags_input],
                    outputs=exp_create_output,
                )

        with gr.Tab("Model Registry"):
            with gr.Group():
                gr.Markdown("## Register Model")
                reg_run_id = gr.Textbox(label="Run ID")
                reg_model_name = gr.Textbox(label="Model Name")
                reg_description = gr.Textbox(label="Description (optional)")
                reg_output = gr.JSON(label="Result")
                reg_button = gr.Button("Register Model")
                # register_model takes (run_id, model_name, description, ...) and
                # has no artifact-path parameter, so the inputs are passed in
                # exactly that order.
                reg_button.click(
                    fn=register_model,
                    inputs=[reg_run_id, reg_model_name, reg_description],
                    outputs=reg_output,
                )

            with gr.Group():
                gr.Markdown("## List Registered Models")
                list_models_output = gr.JSON(label="Models")
                list_models_button = gr.Button("List Models")
                list_models_button.click(
                    fn=list_registered_models,
                    inputs=[],
                    outputs=list_models_output,
                )

            with gr.Group():
                gr.Markdown("## Get Model Info")
                model_info_name = gr.Textbox(label="Model Name")
                model_info_output = gr.JSON(label="Model Information")
                model_info_button = gr.Button("Get Info")
                model_info_button.click(
                    fn=get_model_info,
                    inputs=model_info_name,
                    outputs=model_info_output,
                )

        with gr.Tab("Run Search"):
            with gr.Group():
                gr.Markdown("## Search Runs")
                search_exp_id = gr.Textbox(label="Experiment ID")
                search_filter = gr.Textbox(label="Filter String")
                search_order = gr.Textbox(label="Order By (optional)", placeholder="metrics.rmse ASC")
                search_max_results = gr.Number(label="Max Results", value=100, precision=0)
                search_output = gr.JSON(label="Search Results")
                search_button = gr.Button("Search")
                # search_runs takes (experiment_id, filter_string, order_string,
                # max_results); the order field must sit in third position so
                # the number reaches max_results rather than order_string.
                search_button.click(
                    fn=search_runs,
                    inputs=[search_exp_id, search_filter, search_order, search_max_results],
                    outputs=search_output,
                )

    return app
122
+
123
if __name__ == "__main__":
    # Build the UI and serve it with Gradio's MCP server option enabled.
    create_interface().launch(mcp_server=True)
mcp_mlflow_tools.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import mlflow
3
+ from datetime import datetime
4
+ from typing import Dict, List, Optional, Literal
5
+ from mlflow.tracking import MlflowClient
6
+
7
+ # Configure logging
8
+ logging.basicConfig(level=logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+
11
def _format_timestamp(ts: int) -> str:
    """Render an epoch-milliseconds timestamp as ``YYYY-MM-DD HH:MM:SS``.

    ``datetime.fromtimestamp`` is called without a time zone, so the result is
    expressed in the local time zone of the machine running this code.
    """
    seconds_since_epoch = ts / 1000.0
    return datetime.fromtimestamp(seconds_since_epoch).strftime("%Y-%m-%d %H:%M:%S")
15
+
16
def set_tracking_uri(uri: str) -> Dict:
    """Point MLflow at a tracking URI and report the resulting system info.

    Args:
        uri: Tracking URI to use; must be non-empty.

    Returns:
        The result of ``get_system_info()`` once the URI has been set, or
        ``{"error": True, "message": ...}`` when the URI is empty or setting
        it raised.
    """
    if uri:
        try:
            logger.info(f"Setting MLflow tracking URI to {uri}")
            mlflow.set_tracking_uri(uri)
            # Fetching the system info doubles as the connection check.
            return get_system_info()
        except Exception as exc:
            return {"error": True, "message": f"Failed to set URI: {str(exc)}"}
    return {"error": True, "message": "URI cannot be empty"}
27
+
28
def get_system_info() -> Dict:
    """Collect version, URI and count information about the MLflow setup.

    Returns:
        A dictionary with the keys ``mlflow_version``, ``tracking_uri``,
        ``registry_uri``, ``artifact_uri``, ``python_version``,
        ``server_time``, ``experiment_count`` and ``model_count``, or
        ``{"error": True, "message": ...}`` if any lookup fails.

        ``artifact_uri`` is ``None`` when no run is active, and
        ``server_time`` is the local clock of the process running this code.
    """
    # Imported locally so this function does not depend on a module-level import.
    import platform

    try:
        client = MlflowClient()
        # mlflow.get_artifact_uri() resolves the URI of the *active* run and
        # starts a new run when there is none; only query it when a run is
        # already active so that asking for system info never creates runs.
        artifact_uri = mlflow.get_artifact_uri() if mlflow.active_run() is not None else None
        return {
            "mlflow_version": mlflow.__version__,
            "tracking_uri": mlflow.get_tracking_uri(),
            "registry_uri": mlflow.get_registry_uri(),
            "artifact_uri": artifact_uri,
            # Report the interpreter version, not the MLflow version.
            "python_version": platform.python_version(),
            "server_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "experiment_count": len(mlflow.search_experiments()),
            "model_count": len(client.search_registered_models()),
        }
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch system info: {str(e)}"}
44
+
45
def _count_experiment_runs(client, exp):
    """Return the number of runs (active and deleted) in *exp*.

    A single query with ``max_results=50000`` is issued, so the reported count
    is capped at 50000. If the query fails, a warning is logged and the string
    ``"Error getting count"`` is returned instead of a number.
    """
    try:
        runs = client.search_runs(
            experiment_ids=[exp.experiment_id],
            max_results=50000,  # Practical limit for counting
            run_view_type=mlflow.entities.ViewType.ALL,
        )
        return len(runs)
    except Exception as e_runs:
        logger.warning(f"Error getting run count for experiment '{exp.name}' (ID: {exp.experiment_id}): {str(e_runs)}")
        return "Error getting count"


def list_experiments(name_contains: Optional[str] = "", max_results: Optional[int] = 100) -> Dict:
    """
    List all experiments in the MLflow tracking server, with optional filtering.
    Includes run count for each experiment.

    Args:
        name_contains: Optional filter to only include experiments whose names contain this string (case-insensitive).
        max_results: Maximum number of results to return (default: 100). None means no limit after filtering.
                     A negative value will result in an empty list.

    Returns:
        A dictionary containing the total count of returned experiments and a list of their details.
        Format: {"total_experiments": count, "experiments": [exp_details, ...]}
        Each entry has experiment_id, name, artifact_location, lifecycle_stage, creation_time,
        tags and run_count (an int, or the string "Error getting count" if counting failed).
        Returns {"error": True, "message": ...} on failure.
    """
    logger.info(f"Fetching experiments (filter: '{name_contains}', max_results: {max_results})")

    try:
        client = MlflowClient()
        experiments = client.search_experiments()

        # Case-insensitive substring filter on the experiment name.
        name_filter = name_contains.strip().lower() if name_contains else ""
        if name_filter:
            experiments = [exp for exp in experiments if name_filter in exp.name.lower()]

        # Apply the limit after filtering; None means "no limit".
        if max_results is not None:
            experiments = experiments[:max_results] if max_results >= 0 else []

        experiments_info = []
        for exp in experiments:
            creation_time = getattr(exp, "creation_time", None)
            experiments_info.append({
                "experiment_id": exp.experiment_id,
                "name": exp.name,
                "artifact_location": exp.artifact_location,
                "lifecycle_stage": exp.lifecycle_stage,
                "creation_time": _format_timestamp(creation_time) if creation_time is not None else None,
                "tags": dict(getattr(exp, "tags", None) or {}),
                "run_count": _count_experiment_runs(client, exp),
            })

        return {
            "total_experiments": len(experiments_info),
            "experiments": experiments_info,
        }

    except Exception as e:
        error_msg = f"Error listing experiments: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return {"error": True, "message": error_msg}
142
+
143
def create_experiment(name: str, tags: Optional[Dict[str, str]] = None) -> Dict:
    """Create an MLflow experiment called *name*.

    Args:
        name: Name of the new experiment; must be non-empty.
        tags: Optional tag mapping; an empty dict is sent when omitted.

    Returns:
        ``{"experiment_id": ..., "message": "Created experiment"}`` on
        success, otherwise ``{"error": True, "message": ...}``.
    """
    if name:
        try:
            new_experiment_id = mlflow.create_experiment(name=name, tags=tags or {})
        except Exception as exc:
            return {"error": True, "message": f"Failed to create experiment: {str(exc)}"}
        return {"experiment_id": new_experiment_id, "message": "Created experiment"}
    return {"error": True, "message": "Experiment name cannot be empty"}
156
+
157
def search_runs(
    experiment_id: str,
    filter_string: str,
    order_string: Optional[str] = None,
    max_results: int = 100
) -> Dict:
    """
    Search runs in a given experiment, with filtering and ordering.

    Args:
        experiment_id: The ID of the experiment to search runs in.
        filter_string: A filter query string used to search for runs.
                       It follows the MLflow search filter syntax.
                       Examples:
                       - "metrics.accuracy > 0.95"
                       - "params.learning_rate = '0.001'"
                       - "tags.environment = 'production'"
                       - "attributes.status = 'FINISHED'"
                       - "metrics.loss < 0.2 AND params.optimizer = 'Adam'"
                       If an empty string is provided, no filtering is applied by this string.
                       Multiple conditions can be combined using 'AND' or 'OR'.
        order_string: An optional string to define the order of the results.
                      It should be a single string composed of a metric, parameter, or attribute
                      followed by 'ASC' (ascending) or 'DESC' (descending).
                      Examples:
                      - "metrics.validation_loss ASC"
                      - "params.num_epochs DESC"
                      - "attributes.start_time DESC"
                      If None or an empty string, results are ordered by MLflow's default (usually start_time DESC).
        max_results: Maximum number of runs to return (default: 100). Must be a positive integer;
                     a float with an integral value (e.g. 100.0) is accepted and converted.

    Returns:
        A dictionary containing a list of runs matching the criteria or an error message.
        Format: {"runs": [run_details, ...]} or {"error": True, "message": ...}
    """
    # Validate experiment_id. Compare the string form so that the integer 0
    # (a valid experiment ID) is not mistaken for "empty".
    if experiment_id is None or not str(experiment_id).strip():
        return {"error": True, "message": "Experiment ID cannot be empty"}

    # Validate max_results without letting None or non-numeric input raise.
    if isinstance(max_results, float) and max_results.is_integer():
        max_results = int(max_results)
    if not isinstance(max_results, int) or max_results <= 0:
        return {"error": True, "message": "max_results must be a positive integer"}

    # Ensure filter_string is not None, default to empty if it is (for mlflow.search_runs)
    current_filter_string = filter_string if filter_string is not None else ""

    try:
        logger.info(f"Searching runs in experiment '{experiment_id}' with filter '{current_filter_string}', order by '{order_string}', max_results '{max_results}'")

        order_by_list = [order_string] if order_string and order_string.strip() else None

        found_runs = mlflow.search_runs(
            experiment_ids=[str(experiment_id)],  # Ensure experiment_id is a string
            filter_string=current_filter_string,
            max_results=max_results,
            order_by=order_by_list,
            output_format="list"  # Get a list of Run objects instead of DataFrame
        )
    except Exception as e_search:
        logger.error(f"MLflow search_runs API call failed for experiment_id '{experiment_id}': {str(e_search)}", exc_info=True)
        return {"error": True, "message": f"MLflow search_runs API call failed: {str(e_search)}"}

    if not found_runs:
        logger.info(f"No runs found for experiment_id '{experiment_id}' with the given criteria.")
        return {"runs": []}

    processed_runs_info = []
    for run_obj in found_runs:
        run_id_for_log = run_obj.info.run_id if run_obj.info else "N/A"
        try:
            start_time_ms = run_obj.info.start_time
            end_time_ms = run_obj.info.end_time

            processed_runs_info.append({
                "run_id": run_obj.info.run_id,
                "status": run_obj.info.status,
                "start_time": _format_timestamp(start_time_ms) if start_time_ms is not None else None,
                "end_time": _format_timestamp(end_time_ms) if end_time_ms is not None else None,
                "params": dict(run_obj.data.params),
                "metrics": dict(run_obj.data.metrics),
                "tags": dict(run_obj.data.tags)
            })
        except Exception as e_process_run:
            # Best effort: one malformed run must not hide the others.
            logger.warning(
                f"Failed to process data for run_id '{run_id_for_log}' in experiment '{experiment_id}'. Error: {str(e_process_run)}. Skipping this run.",
                exc_info=True
            )

    return {"runs": processed_runs_info}
249
+
250
def list_registered_models() -> Dict:
    """List all registered models with their basic metadata.

    Returns:
        ``{"models": [...]}`` where each entry has ``name``,
        ``creation_timestamp``, ``last_updated_timestamp`` (formatted strings,
        or ``None`` when the registry reports no value), ``description``,
        ``tags`` and ``latest_versions`` (version identifiers), or
        ``{"error": True, "message": ...}`` on failure.
    """
    try:
        logger.info("Listing registered models")
        client = MlflowClient()

        models_info = []
        for model in client.search_registered_models():
            # RegisteredModel.tags is a plain {key: value} mapping; iterating
            # it yields key strings, so it must not be treated as a sequence
            # of tag objects. A sequence of objects with .key/.value is still
            # handled in case one is ever returned.
            raw_tags = getattr(model, "tags", None) or {}
            if isinstance(raw_tags, dict):
                tags = dict(raw_tags)
            else:
                tags = {tag.key: tag.value for tag in raw_tags}

            created = model.creation_timestamp
            updated = model.last_updated_timestamp
            models_info.append({
                "name": model.name,
                "creation_timestamp": _format_timestamp(created) if created is not None else None,
                "last_updated_timestamp": _format_timestamp(updated) if updated is not None else None,
                "description": model.description or "",
                "tags": tags,
                "latest_versions": [mv.version for mv in (model.latest_versions or [])],
            })

        return {"models": models_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to list registered models: {str(e)}"}
272
+
273
def get_model_info(model_name: str) -> Dict:
    """Get detailed information about a registered model.

    Args:
        model_name: Name of the registered model; must be non-empty.

    Returns:
        ``{"model": {...}}`` with the model's name, timestamps, description,
        tags and one entry per latest version (version, stage, timestamps and
        a summary of the source run), or ``{"error": True, "message": ...}``
        on failure. A version without a ``run_id`` gets an empty ``run`` dict.
    """
    if not model_name:
        return {"error": True, "message": "Model name cannot be empty"}

    def _fmt(ts):
        # Timestamps may be missing; only format real values.
        return _format_timestamp(ts) if ts is not None else None

    try:
        logger.info(f"Fetching info for model '{model_name}'")
        client = MlflowClient()
        model = client.get_registered_model(name=model_name)

        # RegisteredModel.tags is a plain {key: value} mapping; iterating it
        # yields key strings, so it must not be treated as a sequence of tag
        # objects. A sequence of objects with .key/.value is still handled.
        raw_tags = getattr(model, "tags", None) or {}
        if isinstance(raw_tags, dict):
            tags = dict(raw_tags)
        else:
            tags = {tag.key: tag.value for tag in raw_tags}

        model_info = {
            "name": model.name,
            "creation_timestamp": _fmt(model.creation_timestamp),
            "last_updated_timestamp": _fmt(model.last_updated_timestamp),
            "description": model.description or "",
            "tags": tags,
            "versions": []
        }

        for mv in (model.latest_versions or []):
            version_dict = {
                "version": mv.version,
                "current_stage": mv.current_stage,
                "creation_timestamp": _fmt(mv.creation_timestamp),
                "last_updated_timestamp": _fmt(mv.last_updated_timestamp),
                "run": {}
            }

            # Only look up the source run when the version records one.
            if mv.run_id:
                run = client.get_run(mv.run_id)
                version_dict["run"] = {
                    "status": run.info.status,
                    "start_time": _fmt(run.info.start_time),
                    "end_time": _format_timestamp(run.info.end_time) if run.info.end_time else None,
                    "metrics": dict(run.data.metrics)
                }
            model_info["versions"].append(version_dict)

        return {"model": model_info}
    except Exception as e:
        return {"error": True, "message": f"Failed to fetch model info: {str(e)}"}
314
+
315
def register_model(
    run_id: str,
    model_name: str,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None
) -> Dict:
    """
    Register a model from a run, with optional description and tags.
    Assumes artifact path is 'model'.

    Args:
        run_id: ID of the run that logged the model; must be non-empty.
        model_name: Name to register the model under; must be non-empty.
        description: Optional free-text description. A fixed note stating that
                     the model was registered through this service is always appended.
        tags: Optional tags; they are merged over the default
              "registered_by" / "registration_timestamp" tags.

    Returns:
        A dictionary with model_name, version, description, tags and message on success,
        or {"error": True, "message": ...} on failure.
    """
    if not all([run_id, model_name]):
        return {"error": True, "message": "Run ID and model name must be non-empty"}

    # Prepare description and tags. The description is optional, so None or a
    # blank string must not break the concatenation with the service note.
    service_note = "Model registered by LLM through MCP service."
    user_description = (description or "").strip()
    final_description = f"{user_description} {service_note}" if user_description else service_note

    final_tags = {
        "registered_by": "mcp-llm-service",
        "registration_timestamp": datetime.now().isoformat()
    }
    if tags:
        final_tags.update(tags)

    try:
        logger.info(f"Registering model '{model_name}' from run '{run_id}/model' with description and tags.")
        model_uri = f"runs:/{run_id}/model"
        result = mlflow.register_model(
            model_uri=model_uri,
            name=model_name,
            tags=final_tags
        )

        # The description is set in a second call, on the version just created.
        client = MlflowClient()
        client.update_model_version(
            name=model_name,
            version=result.version,
            description=final_description
        )

        return {
            "model_name": model_name,
            "version": result.version,
            "description": final_description,
            "tags": final_tags,
            "message": "Registered successfully"
        }
    except Exception as e:
        return {"error": True, "message": f"Registration failed: {str(e)}"}
363
+
364
+
365
+
366
+
367
+
368
+
369
if __name__ == "__main__":
    # Manual smoke test against a local tracking server. Kept behind the
    # __main__ guard so that importing this module (as app.py does) neither
    # changes the tracking URI nor registers a model as a side effect.
    set_tracking_uri("http://127.0.0.1:5000")
    print(register_model(
        "a217f7261dbf4b0f8dbc575dad0e2f67",
        "RandomForestSecondBestModel",
        "This is a test model."
    ))
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ mlflow>=2.22.0
2
+ gradio[mcp]>=5.32.0