| """ | |
| AutoTrain Gradio MCP Server - All-in-One | |
| This single Gradio app: | |
| 1. Provides a web interface for managing AutoTrain jobs | |
| 2. Automatically exposes MCP tools at /gradio_api/mcp/sse | |
| 3. Handles all AutoTrain operations directly (no FastAPI needed) | |
| """ | |
import os
import json
import time
import uuid
import threading
import socket
from datetime import datetime
from typing import List, Dict, Any

import gradio as gr
import pandas as pd
import wandb

from autotrain.project import AutoTrainProject
from autotrain.params import (
    LLMTrainingParams,
    TextClassificationParams,
    ImageClassificationParams,
)

# Simple JSON-based storage (replace with SQLite if needed)
RUNS_FILE = "training_runs.json"
WANDB_PROJECT = os.environ.get("WANDB_PROJECT", "autotrain-mcp")


def load_runs() -> List[Dict[str, Any]]:
    """Load training runs from JSON file."""
    if os.path.exists(RUNS_FILE):
        try:
            with open(RUNS_FILE, "r") as f:
                return json.load(f)
        except (json.JSONDecodeError, IOError):
            return []
    return []


def save_runs(runs: List[Dict[str, Any]]) -> None:
    """Save training runs to JSON file."""
    with open(RUNS_FILE, "w") as f:
        json.dump(runs, f, indent=2)
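
# Each run record stored in RUNS_FILE is a plain dict. A representative
# (hypothetical) example of the shape the functions below read and write:
#
#   {
#       "run_id": "8f14e45f-...",
#       "project_name": "imdb-demo",
#       "task": "text-classification",
#       "base_model": "distilbert-base-uncased",
#       "dataset_path": "imdb",
#       "status": "running",   # pending | running | completed | failed | cancelled
#       "created_at": "2024-01-01T00:00:00",
#       "wandb_url": "https://wandb.ai/...",   # set once W&B init succeeds
#       "config": {"epochs": 1, "batch_size": 8,
#                  "learning_rate": 2e-5, "backend": "local"},
#   }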


def get_status_emoji(status: str) -> str:
    """Get emoji for training status."""
    emoji_map = {
        "pending": "⏳",
        "running": "🏃",
        "completed": "✅",
        "failed": "❌",
        "cancelled": "⏹️",
    }
    return emoji_map.get(status.lower(), "❓")


def create_autotrain_params(
    task: str,
    base_model: str,
    project_name: str,
    dataset_path: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    **kwargs,
):
    """Create AutoTrain parameter object based on task type."""
    common_params = {
        "model": base_model,
        "project_name": project_name,
        "data_path": dataset_path,
        "train_split": kwargs.get("train_split", "train"),
        "valid_split": kwargs.get("valid_split"),
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": learning_rate,
        "log": "wandb",
        # Required defaults
        "warmup_ratio": 0.1,
        "gradient_accumulation": 1,
        "optimizer": "adamw_torch",
        "scheduler": "linear",
        "weight_decay": 0.01,
        "max_grad_norm": 1.0,
        "seed": 42,
        "logging_steps": 10,
        "auto_find_batch_size": False,
        "mixed_precision": "no",
        "save_total_limit": 1,
        "eval_strategy": "epoch",
    }
    if task == "text-classification":
        return TextClassificationParams(
            **common_params,
            text_column=kwargs.get("text_column", "text"),
            target_column=kwargs.get("target_column", "label"),
            max_seq_length=kwargs.get("max_seq_length", 128),
            early_stopping_patience=3,
            early_stopping_threshold=0.01,
        )
    elif task.startswith("llm-"):
        trainer_map = {
            "llm-sft": "sft",
            "llm-dpo": "dpo",
            "llm-orpo": "orpo",
            "llm-reward": "reward",
        }
        return LLMTrainingParams(
            **{
                k: v
                for k, v in common_params.items()
                if k not in ["early_stopping_patience", "early_stopping_threshold"]
            },
            text_column=kwargs.get("text_column", "messages"),
            block_size=kwargs.get("block_size", 2048),
            peft=kwargs.get("use_peft", True),
            quantization=kwargs.get("quantization", "int4"),
            trainer=trainer_map[task],
            chat_template="tokenizer",
            # LLM-specific defaults
            add_eos_token=True,
            model_max_length=2048,
            padding="right",
            use_flash_attention_2=False,
            disable_gradient_checkpointing=False,
            target_modules="all-linear",
            merge_adapter=False,
            lora_r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            model_ref=None,
            dpo_beta=0.1,
            max_prompt_length=512,
            max_completion_length=1024,
            prompt_text_column="prompt",
            rejected_text_column="rejected",
            unsloth=False,
            distributed_backend="accelerate",
        )
    elif task == "image-classification":
        return ImageClassificationParams(
            **common_params,
            image_column=kwargs.get("image_column", "image"),
            target_column=kwargs.get("target_column", "label"),
        )
    else:
        raise ValueError(f"Unsupported task type: {task}")
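
# A minimal usage sketch (hypothetical values) mirroring what
# start_training_job() below passes in:
#
#   params = create_autotrain_params(
#       task="text-classification",
#       base_model="distilbert-base-uncased",
#       project_name="imdb-demo",
#       dataset_path="imdb",
#       epochs=1,
#       batch_size=8,
#       learning_rate=2e-5,
#   )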


def run_training_background(run_id: str, params: Any, backend: str):
    """Run training job in background thread."""
    runs = load_runs()
    # Update status to running
    for run in runs:
        if run["run_id"] == run_id:
            run["status"] = "running"
            run["started_at"] = datetime.utcnow().isoformat()
            break
    save_runs(runs)
    try:
        # Initialize W&B
        wandb_run = wandb.init(
            project=WANDB_PROJECT,
            name=f"{params.project_name}-{int(time.time())}",
            tags=["autotrain", "mcp"],
            config={
                "base_model": params.model,
                "dataset": params.data_path,
                "epochs": params.epochs,
                "batch_size": params.batch_size,
                "learning_rate": params.lr,
                "backend": backend,
            },
        )
        wandb_url = (
            wandb_run.url if wandb_run.url else f"https://wandb.ai/{WANDB_PROJECT}"
        )
        # Update with W&B URL
        runs = load_runs()
        for run in runs:
            if run["run_id"] == run_id:
                run["wandb_url"] = wandb_url
                break
        save_runs(runs)
        # Create and start AutoTrain project
        project = AutoTrainProject(params=params, backend=backend, process=True)
        job_id = project.create()
        print(f"Training started for run {run_id} with job ID: {job_id}")
        # For demo purposes, simulate training completion after a short delay
        time.sleep(10)  # In a real implementation, monitor the actual training
        # Update status to completed
        runs = load_runs()
        for run in runs:
            if run["run_id"] == run_id:
                run["status"] = "completed"
                run["completed_at"] = datetime.utcnow().isoformat()
                break
        save_runs(runs)
        wandb.finish()
    except Exception as e:
        print(f"Training failed for run {run_id}: {str(e)}")
        # Update status to failed
        runs = load_runs()
        for run in runs:
            if run["run_id"] == run_id:
                run["status"] = "failed"
                run["error_message"] = str(e)
                run["completed_at"] = datetime.utcnow().isoformat()
                break
        save_runs(runs)
        if wandb.run:
            wandb.finish()
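
# NOTE: the load/save round-trips above are not atomic; if several background
# threads update runs concurrently, a write can clobber another thread's
# update. A module-level lock (an assumption, not part of the original
# design) would serialize them:
#
#   _RUNS_LOCK = threading.Lock()
#
#   def update_run(run_id: str, **fields):
#       with _RUNS_LOCK:
#           runs = load_runs()
#           for run in runs:
#               if run["run_id"] == run_id:
#                   run.update(fields)
#                   break
#           save_runs(runs)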


# MCP Tool Functions (these automatically become MCP tools)
def start_training_job(
    task: str = "text-classification",
    project_name: str = "test-project",
    base_model: str = "distilbert-base-uncased",
    dataset_path: str = "imdb",
    epochs: str = "1",
    batch_size: str = "8",
    learning_rate: str = "2e-5",
    backend: str = "local",
) -> str:
    """
    Start a new AutoTrain training job.

    Args:
        task: Type of training task (text-classification, llm-sft,
            llm-dpo, llm-orpo, image-classification)
        project_name: Name for the training project
        base_model: Base model from Hugging Face Hub
            (e.g., distilbert-base-uncased)
        dataset_path: Dataset path or HF dataset name (e.g., imdb)
        epochs: Number of training epochs (default: 1)
        batch_size: Training batch size (default: 8)
        learning_rate: Learning rate for training (default: 2e-5)
        backend: Training backend to use (default: local)

    Returns:
        Status message with run ID and details
    """
    try:
        # Convert string parameters
        epochs_int = int(epochs)
        batch_size_int = int(batch_size)
        learning_rate_float = float(learning_rate)
        # Generate run ID
        run_id = str(uuid.uuid4())
        # Create run record
        run_data = {
            "run_id": run_id,
            "project_name": project_name,
            "task": task,
            "base_model": base_model,
            "dataset_path": dataset_path,
            "status": "pending",
            "created_at": datetime.utcnow().isoformat(),
            "updated_at": datetime.utcnow().isoformat(),
            "config": {
                "task": task,
                "epochs": epochs_int,
                "batch_size": batch_size_int,
                "learning_rate": learning_rate_float,
                "backend": backend,
            },
        }
        # Save to storage
        runs = load_runs()
        runs.append(run_data)
        save_runs(runs)
        # Create AutoTrain parameters
        params = create_autotrain_params(
            task=task,
            base_model=base_model,
            project_name=project_name,
            dataset_path=dataset_path,
            epochs=epochs_int,
            batch_size=batch_size_int,
            learning_rate=learning_rate_float,
        )
        # Start training in background
        thread = threading.Thread(
            target=run_training_background, args=(run_id, params, backend)
        )
        thread.daemon = True
        thread.start()
        return f"""✅ Training job submitted successfully!

Run ID: {run_id}
Project: {project_name}
Task: {task}
Model: {base_model}
Dataset: {dataset_path}

Configuration:
• Epochs: {epochs}
• Batch Size: {batch_size}
• Learning Rate: {learning_rate}
• Backend: {backend}

📊 Monitor progress:
• Gradio UI: http://localhost:7860
• W&B tracking will be available once training starts

💡 Use get_training_runs() to check status"""
    except Exception as e:
        return f"❌ Error submitting job: {str(e)}"


def get_training_runs(limit: str = "20", status: str = "") -> str:
    """
    Get list of training runs with their status and details.

    Args:
        limit: Maximum number of runs to return (default: 20)
        status: Filter by run status (pending, running, completed,
            failed, cancelled)

    Returns:
        Formatted list of training runs with status and links
    """
    try:
        runs = load_runs()
        # Filter by status if provided
        if status:
            runs = [run for run in runs if run.get("status") == status]
        # Apply limit
        runs = runs[-int(limit):]
        if not runs:
            return "No training runs found. Start a new training job to see it here!"
        runs_text = f"📊 Training Runs (showing {len(runs)}):\n\n"
        for run in reversed(runs):  # Show newest first
            status_emoji = get_status_emoji(run["status"])
            # Format run display with line break
            run_display = (
                f"{status_emoji} **{run['project_name']}** ({run['run_id'][:8]}...)"
            )
            runs_text += f"{run_display}\n"
            runs_text += f"   Task: {run['task']}\n"
            runs_text += f"   Model: {run['base_model']}\n"
            runs_text += f"   Status: {run['status'].title()}\n"
            runs_text += f"   Created: {run['created_at']}\n"
            if run.get("wandb_url"):
                runs_text += f"   📈 W&B: {run['wandb_url']}\n"
            if run.get("error_message"):
                runs_text += f"   ❌ Error: {run['error_message']}\n"
            runs_text += "\n"
        return runs_text
    except Exception as e:
        return f"❌ Error fetching runs: {str(e)}"


def get_run_details(run_id: str) -> str:
    """
    Get detailed information about a specific training run.

    Args:
        run_id: ID of the training run (can be a partial ID)

    Returns:
        Detailed run information including config and status
    """
    try:
        runs = load_runs()
        # Find run by full or partial ID
        found_run = None
        for run in runs:
            if run["run_id"] == run_id or run["run_id"].startswith(run_id):
                found_run = run
                break
        if not found_run:
            return f"❌ Training run {run_id} not found"
        run = found_run
        details_text = f"""📋 Training Run Details

**Run ID:** {run["run_id"]}
**Project:** {run["project_name"]}
**Task:** {run["task"]}
**Model:** {run["base_model"]}
**Dataset:** {run["dataset_path"]}
**Status:** {run["status"].title()}

**Timestamps:**
• Created: {run["created_at"]}
• Updated: {run.get("updated_at", "N/A")}"""
        if run.get("started_at"):
            details_text += f"\n• Started: {run['started_at']}"
        if run.get("completed_at"):
            details_text += f"\n• Completed: {run['completed_at']}"
        if run.get("wandb_url"):
            details_text += f"\n\n📊 **W&B Dashboard:** {run['wandb_url']}"
        if run.get("error_message"):
            details_text += f"\n\n❌ **Error:** {run['error_message']}"
        if run.get("config"):
            config = run["config"]
            details_text += "\n\n⚙️ **Training Configuration:**"
            details_text += f"\n• Epochs: {config.get('epochs')}"
            details_text += f"\n• Batch Size: {config.get('batch_size')}"
            details_text += f"\n• Learning Rate: {config.get('learning_rate')}"
            details_text += f"\n• Backend: {config.get('backend')}"
        return details_text
    except Exception as e:
        return f"❌ Error fetching run details: {str(e)}"


def get_task_recommendations(
    task: str = "text-classification", dataset_size: str = "medium"
) -> str:
    """
    Get training recommendations for a specific task type.

    Args:
        task: Task type (text-classification, llm-sft, image-classification)
        dataset_size: Size of dataset (small, medium, large)

    Returns:
        Recommended models, parameters, and best practices
    """
    recommendations = {
        "text-classification": {
            "models": ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"],
            "params": {"batch_size": 16, "learning_rate": 2e-5, "epochs": 3},
            "backends": ["local", "spaces-t4-small"],
            "notes": [
                "Good for sentiment analysis",
                "Works well with IMDB, AG News datasets",
            ],
        },
        "llm-sft": {
            "models": [
                "microsoft/DialoGPT-medium",
                "HuggingFaceTB/SmolLM2-1.7B-Instruct",
            ],
            "params": {"batch_size": 1, "learning_rate": 1e-5, "epochs": 3},
            "backends": ["spaces-t4-medium", "spaces-a10g-large"],
            "notes": ["Use PEFT for efficiency", "Ensure proper chat formatting"],
        },
        "image-classification": {
            "models": ["google/vit-base-patch16-224", "microsoft/resnet-50"],
            "params": {"batch_size": 32, "learning_rate": 2e-5, "epochs": 5},
            "backends": ["local", "spaces-t4-small"],
            "notes": ["Ensure images are preprocessed", "Works with CIFAR, ImageNet"],
        },
    }
    rec = recommendations.get(
        task,
        {
            "models": [],
            "params": {},
            "backends": ["local"],
            "notes": ["No specific recommendations available"],
        },
    )
    rec_text = f"""🎯 Training Recommendations for {task.title()} \
({dataset_size} dataset)

**Recommended Models:**
{chr(10).join(f"• {model}" for model in rec["models"])}

**Recommended Parameters:**
{chr(10).join(f"• {k}: {v}" for k, v in rec["params"].items())}

**Backend Suggestions:**
{chr(10).join(f"• {backend}" for backend in rec["backends"])}

**Best Practices:**
{chr(10).join(f"• {note}" for note in rec["notes"])}"""
    return rec_text


def get_system_status(random_string: str = "") -> str:
    """
    Get AutoTrain system status and capabilities.

    Args:
        random_string: Unused placeholder so the tool schema has an input

    Returns:
        System status, available tasks, backends, and statistics
    """
    try:
        runs = load_runs()
        # Calculate stats
        total_runs = len(runs)
        running_runs = len([r for r in runs if r.get("status") == "running"])
        completed_runs = len([r for r in runs if r.get("status") == "completed"])
        failed_runs = len([r for r in runs if r.get("status") == "failed"])
        available_tasks = [
            "text-classification",
            "llm-sft",
            "llm-dpo",
            "llm-orpo",
            "image-classification",
        ]
        available_backends = [
            "local",
            "spaces-t4-small",
            "spaces-t4-medium",
            "spaces-a10g-large",
            "spaces-a10g-small",
            "spaces-a100-large",
            "spaces-l4x1",
            "spaces-l4x4",
        ]
        status_text = f"""🤖 AutoTrain Gradio MCP Server - System Status

**Server Status:** Running
**Total Runs:** {total_runs}
**Active Runs:** {running_runs}
**Completed Runs:** {completed_runs}
**Failed Runs:** {failed_runs}

**Available Tasks:** {len(available_tasks)}
{chr(10).join(f"  • {task}" for task in available_tasks)}

**Available Backends:** {len(available_backends)}
{chr(10).join(f"  • {backend}" for backend in available_backends[:10])}
{
    f"  ... and {len(available_backends) - 10} more"
    if len(available_backends) > 10
    else ""
}

💡 **Access Points:**
• Gradio UI: http://localhost:7860
• MCP Server: http://localhost:7860/gradio_api/mcp/sse
• MCP Schema: http://localhost:7860/gradio_api/mcp/schema

🛠️ **W&B Integration:**
• Project: {WANDB_PROJECT}
• Set WANDB_PROJECT environment variable to customize"""
        return status_text
    except Exception as e:
        return f"❌ Error getting system status: {str(e)}"


def refresh_data(random_string: str = "") -> str:
    """Refresh data for UI updates."""
    return "Data refreshed successfully"


def load_initial_data(random_string: str = "") -> str:
    """Load initial data for the application."""
    return "Initial data loaded successfully"


# Web UI Functions
def fetch_runs_for_ui():
    """Fetch runs for the web interface table."""
    try:
        runs = load_runs()
        if not runs:
            return pd.DataFrame(
                {
                    "Status": [],
                    "Project": [],
                    "Task": [],
                    "Model": [],
                    "Created": [],
                    "W&B Link": [],
                    "Run ID": [],
                }
            )
        data = []
        for run in reversed(runs):  # Newest first
            wandb_link = ""
            if run.get("wandb_url"):
                wandb_link = (
                    f'<a href="{run["wandb_url"]}" target="_blank">View W&B</a>'
                )
            data.append(
                {
                    "Status": f"{get_status_emoji(run['status'])} {run['status'].title()}",
                    "Project": run["project_name"],
                    "Task": run["task"].replace("-", " ").title(),
                    "Model": run["base_model"],
                    "Created": run["created_at"][:16].replace("T", " "),
                    "W&B Link": wandb_link,
                    "Run ID": run["run_id"][:8] + "...",
                }
            )
        return pd.DataFrame(data)
    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to fetch runs: {str(e)}"]})


def submit_training_job_ui(
    task,
    project_name,
    base_model,
    dataset_path,
    epochs,
    batch_size,
    learning_rate,
    backend,
):
    """Submit training job from web UI."""
    if not all([task, project_name, base_model, dataset_path]):
        return "❌ Please fill in all required fields", fetch_runs_for_ui()
    result = start_training_job(
        task=task,
        project_name=project_name,
        base_model=base_model,
        dataset_path=dataset_path,
        epochs=str(epochs),
        batch_size=str(batch_size),
        learning_rate=str(learning_rate),
        backend=backend,
    )
    return result, fetch_runs_for_ui()


# Create Gradio Interface
with gr.Blocks(
    title="AutoTrain Gradio MCP Server",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    """,
) as app:
    gr.Markdown("""
    # 🚀 AutoTrain Gradio MCP Server

    **All-in-One Solution:** Web UI + MCP Server + AutoTrain Integration

    • **Web Interface**: Manage training jobs through this UI
    • **MCP Server**: AI assistants can use tools at `http://localhost:7860/gradio_api/mcp/sse`
    • **Direct Integration**: No FastAPI needed - everything runs in Gradio
    """)
    with gr.Tabs():
        # Dashboard Tab
        with gr.Tab("📊 Dashboard"):
            with gr.Row():
                with gr.Column(scale=3):
                    gr.Markdown("## Training Runs")
                    refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
                    runs_table = gr.Dataframe(
                        value=fetch_runs_for_ui(), interactive=False
                    )
                with gr.Column(scale=1):
                    gr.Markdown("## Quick Stats")
                    stats = gr.Textbox(
                        value=get_system_status(), interactive=False, lines=15
                    )
        # Start Training Tab
        with gr.Tab("🚀 Start Training"):
            gr.Markdown("## Submit New Training Job")
            with gr.Row():
                with gr.Column():
                    task_dropdown = gr.Dropdown(
                        choices=[
                            "text-classification",
                            "llm-sft",
                            "llm-dpo",
                            "llm-orpo",
                            "image-classification",
                        ],
                        label="Task Type",
                        value="text-classification",
                    )
                    project_name = gr.Textbox(
                        label="Project Name", placeholder="my-training-project"
                    )
                    base_model = gr.Textbox(
                        label="Base Model", placeholder="distilbert-base-uncased"
                    )
                    dataset_path = gr.Textbox(label="Dataset Path", placeholder="imdb")
                with gr.Column():
                    epochs = gr.Slider(1, 20, value=3, step=1, label="Epochs")
                    batch_size = gr.Slider(1, 128, value=16, step=1, label="Batch Size")
                    learning_rate = gr.Number(value=2e-5, label="Learning Rate")
                    backend = gr.Dropdown(
                        choices=["local", "spaces-t4-small", "spaces-a10g-large"],
                        label="Backend",
                        value="local",
                    )
            submit_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
            submit_output = gr.Textbox(label="Status", interactive=False, lines=10)
        # MCP Info Tab
        with gr.Tab("🔌 MCP Integration"):
            gr.Markdown(f"""
            ## MCP Server Information

            This Gradio app automatically serves as an MCP server.

            **MCP Endpoint:** `http://localhost:7860/gradio_api/mcp/sse`
            **MCP Schema:** `http://localhost:7860/gradio_api/mcp/schema`

            ### Available MCP Tools:
            - `start_training_job` - Submit new training jobs
            - `get_training_runs` - List all runs with status
            - `get_run_details` - Get detailed run information
            - `get_task_recommendations` - Get training recommendations
            - `get_system_status` - Check system status

            ### Claude Desktop Configuration:
            ```json
            {{
                "mcpServers": {{
                    "autotrain": {{
                        "url": "http://localhost:7860/gradio_api/mcp/sse"
                    }}
                }}
            }}
            ```

            ### Current Stats:
            Total Runs: {len(load_runs())}
            W&B Project: {WANDB_PROJECT}
            """)
        # MCP Tools Tab
        with gr.Tab("🔧 MCP Tools"):
            gr.Markdown("## MCP Tool Testing Interface")
            gr.Markdown("These tools are exposed via MCP for Claude Desktop")
            gr.Interface(
                fn=get_system_status,
                inputs=[],
                outputs=gr.Textbox(label="System Status"),
                title="get_system_status",
                description="Get AutoTrain system status and capabilities",
            )
            gr.Interface(
                fn=get_training_runs,
                inputs=[
                    gr.Textbox(label="limit", value="20"),
                    gr.Textbox(label="status", value=""),
                ],
                outputs=gr.Textbox(label="Training Runs"),
                title="get_training_runs",
                description="Get list of training runs with status",
            )
            gr.Interface(
                fn=start_training_job,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="project_name", value="test-project"),
                    gr.Textbox(label="base_model", value="distilbert-base-uncased"),
                    gr.Textbox(label="dataset_path", value="imdb"),
                    gr.Textbox(label="epochs", value="1"),
                    gr.Textbox(label="batch_size", value="8"),
                    gr.Textbox(label="learning_rate", value="2e-5"),
                    gr.Textbox(label="backend", value="local"),
                ],
                outputs=gr.Textbox(label="Training Job Result"),
                title="start_training_job",
                description="Start a new AutoTrain training job",
            )
            gr.Interface(
                fn=get_run_details,
                inputs=gr.Textbox(
                    label="run_id", placeholder="Enter run ID or first 8 chars"
                ),
                outputs=gr.Textbox(label="Run Details"),
                title="get_run_details",
                description="Get detailed information about a training run",
            )
            gr.Interface(
                fn=get_task_recommendations,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="dataset_size", value="medium"),
                ],
                outputs=gr.Textbox(label="Recommendations"),
                title="get_task_recommendations",
                description="Get training recommendations for a task",
            )
    # Event handlers with proper function names (not lambdas); renamed so they
    # do not shadow the module-level refresh_data/load_initial_data tools above
    def refresh_dashboard():
        return fetch_runs_for_ui(), get_system_status()

    def load_dashboard():
        return fetch_runs_for_ui(), get_system_status()

    refresh_btn.click(
        fn=refresh_dashboard,
        outputs=[runs_table, stats],
    )
    submit_btn.click(
        fn=submit_training_job_ui,
        inputs=[
            task_dropdown,
            project_name,
            base_model,
            dataset_path,
            epochs,
            batch_size,
            learning_rate,
            backend,
        ],
        outputs=[submit_output, runs_table],
    )
    # Load initial data
    app.load(
        fn=load_dashboard,
        outputs=[runs_table, stats],
    )


# Helper to find an available port
def _find_available_port(start_port: int = 7860, max_tries: int = 20) -> int:
    """Return the first available port starting from `start_port`."""
    port = start_port
    for _ in range(max_tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("0.0.0.0", port))
                return port  # Port is free
            except OSError:
                port += 1  # Try next port
    # If no port found, let the OS pick one
    return 0
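
# For example, _find_available_port(7860) returns 7860 when it is free,
# otherwise the next free port up to 7879; returning 0 is intended to let
# the OS assign a port (the fallback used in the main block below).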
| if __name__ == "__main__": | |
| chosen_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860")) | |
| try: | |
| chosen_port = _find_available_port(chosen_port) | |
| except Exception: | |
| # Fallback to OS-assigned port if something goes wrong | |
| chosen_port = 0 | |
| app.launch( | |
| server_name="0.0.0.0", | |
| server_port=chosen_port, | |
| mcp_server=True, # Enable MCP server functionality | |
| ) | |
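
# A minimal MCP client sketch for smoke-testing the endpoint. This assumes
# the official `mcp` Python SDK (pip install mcp); the exact client API may
# differ between SDK versions:
#
#   import asyncio
#   from mcp import ClientSession
#   from mcp.client.sse import sse_client
#
#   async def main():
#       url = "http://localhost:7860/gradio_api/mcp/sse"
#       async with sse_client(url) as (read, write):
#           async with ClientSession(read, write) as session:
#               await session.initialize()
#               tools = await session.list_tools()
#               print([tool.name for tool in tools.tools])
#
#   asyncio.run(main())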